123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235 |
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- from tests.test_app import app # isort:skip
- import datetime
- from unittest import mock
- from superset.db_engine_specs import engines
- from superset.db_engine_specs.base import BaseEngineSpec, builtin_time_grains
- from superset.db_engine_specs.sqlite import SqliteEngineSpec
- from superset.utils.core import get_example_database
- from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
- from ..fixtures.pyodbcRow import Row
class DbEngineSpecsTests(DbEngineSpecTestCase):
    """Tests for ``BaseEngineSpec`` helpers and engine time grain configuration."""

    def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
        """``get_limit_from_sql`` reports the outer query's LIMIT row count.

        The cases below pin that only a top-level integer LIMIT is reported
        (``LIMIT offset, count`` and ``LIMIT count OFFSET n`` both yield
        ``count``); a LIMIT that appears only in a subquery, has no value, or
        is not an integer literal yields ``None``.
        """
        cases = [
            ("select * from table", None),
            ("select * from mytable limit 10", 10),
            (
                "select * from (select * from my_subquery limit 10) where col=1 limit 20",
                20,
            ),
            ("select * from (select * from my_subquery limit 10);", None),
            (
                "select * from (select * from my_subquery limit 10) where col=1 limit 20;",
                20,
            ),
            ("select * from mytable limit 20, 10", 10),
            ("select * from mytable limit 10 offset 20", 10),
            ("select * from mytable limit", None),
            ("select * from mytable limit 10.0", None),
            ("select * from mytable limit x", None),
            ("select * from mytable limit 20, x", None),
            ("select * from mytable limit x offset 20", None),
        ]
        for sql, expected in cases:
            # subTest reports every failing case instead of stopping at the first.
            with self.subTest(sql=sql):
                self.assertEqual(engine_spec_class.get_limit_from_sql(sql), expected)

    def test_wrapped_semi_tabs(self):
        """Trailing whitespace and a semicolon are dropped before LIMIT is added."""
        self.sql_limit_regex(
            "SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000"
        )

    def test_simple_limit_query(self):
        """A query without a LIMIT gets one appended on a new line."""
        self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000")

    def test_modify_limit_query(self):
        """An existing outer LIMIT is replaced by the configured limit."""
        self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000")

    def test_limit_query_with_limit_subquery(self):  # pylint: disable=invalid-name
        """Only the outer LIMIT is rewritten; the subquery's LIMIT is kept."""
        self.sql_limit_regex(
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
            "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000",
        )

    def test_limit_with_expr(self):
        """A 'LIMIT' inside a string literal is not mistaken for the clause."""
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 1000""",
        )

    def test_limit_expr_and_semicolon(self):
        """As above, with a trailing semicolon after the LIMIT value."""
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990 ;""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 1000""",
        )

    def test_get_datatype(self):
        """The base spec passes a string datatype through unchanged."""
        self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))

    def test_limit_with_implicit_offset(self):
        """In ``LIMIT offset, count`` only the count is replaced."""
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990, 999999""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990, 1000""",
        )

    def test_limit_with_explicit_offset(self):
        """In ``LIMIT count OFFSET n`` only the count is replaced."""
        self.sql_limit_regex(
            """
            SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 99990
            OFFSET 999999""",
            """SELECT
                'LIMIT 777' AS a
                , b
            FROM
            table
            LIMIT 1000
            OFFSET 999999""",
        )

    def test_limit_with_non_token_limit(self):
        """A query whose only 'LIMIT' is inside a string gets a real LIMIT added."""
        self.sql_limit_regex(
            """SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000"""
        )

    def test_time_grain_blacklist(self):
        """A duration listed in TIME_GRAIN_BLACKLIST is absent from the functions."""
        config_overrides = {"TIME_GRAIN_BLACKLIST": ["PT1M"]}
        # patch.dict restores app.config on exit; assigning directly would leak
        # the blacklist into every test that runs afterwards.
        with app.app_context(), mock.patch.dict(app.config, config_overrides):
            time_grain_functions = SqliteEngineSpec.get_time_grain_functions()
            self.assertNotIn("PT1M", time_grain_functions)

    def test_time_grain_addons(self):
        """A configured addon grain with a sqlite function appears last in grains."""
        config_overrides = {
            "TIME_GRAIN_ADDONS": {"PTXM": "x seconds"},
            "TIME_GRAIN_ADDON_FUNCTIONS": {"sqlite": {"PTXM": "ABC({col})"}},
        }
        # patch.dict restores app.config on exit so the addon grain does not
        # leak into later tests (e.g. ones checking grains against builtins).
        with app.app_context(), mock.patch.dict(app.config, config_overrides):
            time_grains = SqliteEngineSpec.get_time_grains()
            time_grain_addon = time_grains[-1]
            self.assertEqual("PTXM", time_grain_addon.duration)
            self.assertEqual("x seconds", time_grain_addon.label)

    def test_engine_time_grain_validity(self):
        """Each registered engine defines grains, all from ``builtin_time_grains``."""
        time_grains = set(builtin_time_grains.keys())
        # loop over all subclasses of BaseEngineSpec
        for engine in engines.values():
            if engine is BaseEngineSpec:
                continue
            # make sure time grain functions have been defined
            self.assertGreater(len(engine.get_time_grain_functions()), 0)
            # make sure all defined time grains are supported
            defined_grains = {grain.duration for grain in engine.get_time_grains()}
            intersection = time_grains.intersection(defined_grains)
            self.assertSetEqual(defined_grains, intersection, engine)

    def test_get_table_names(self):
        """The base spec strips the schema prefix from returned table names."""
        inspector = mock.Mock()
        inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
        inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])

        # Make sure base engine spec removes schema name from table name,
        # i.e. when try_remove_schema_from_table_name == True.
        base_result_expected = ["table", "table_2"]
        base_result = BaseEngineSpec.get_table_names(
            database=mock.ANY, schema="schema", inspector=inspector
        )
        self.assertListEqual(base_result_expected, base_result)

    def test_column_datatype_to_string(self):
        """Column types of the example ``energy_usage`` table render per backend."""
        example_db = get_example_database()
        sqla_table = example_db.get_table("energy_usage")
        dialect = example_db.get_dialect()
        col_names = [
            example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
            for c in sqla_table.columns
        ]
        if example_db.backend == "postgresql":
            expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
        else:
            expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
        self.assertEqual(col_names, expected)

    def test_convert_dttm(self):
        """The base spec returns None for datetime conversion."""
        dttm = self.get_dttm()
        self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm))

    def test_pyodbc_rows_to_tuples(self):
        """pyodbc ``Row`` objects (odbc driver) are converted to plain tuples."""
        data = [
            Row((1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000))),
            Row((2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000))),
        ]
        expected = [
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
        self.assertListEqual(result, expected)

    def test_pyodbc_rows_to_tuples_passthrough(self):
        """Rows that are already tuples are returned unchanged."""
        data = [
            (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
            (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
        ]
        result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
        self.assertListEqual(result, data)
|