base_engine_spec_tests.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. from tests.test_app import app # isort:skip
  18. import datetime
  19. from unittest import mock
  20. from superset.db_engine_specs import engines
  21. from superset.db_engine_specs.base import BaseEngineSpec, builtin_time_grains
  22. from superset.db_engine_specs.sqlite import SqliteEngineSpec
  23. from superset.utils.core import get_example_database
  24. from tests.db_engine_specs.base_tests import DbEngineSpecTestCase
  25. from ..fixtures.pyodbcRow import Row
  class DbEngineSpecsTests(DbEngineSpecTestCase):
      """Tests for the generic behaviour exposed by ``BaseEngineSpec``.

      Covers LIMIT extraction and rewriting, time grain configuration, table
      name listing, column type rendering and pyodbc row conversion.
      """

      def test_extract_limit_from_query(self, engine_spec_class=BaseEngineSpec):
          """Check ``get_limit_from_sql`` against a table of queries.

          ``engine_spec_class`` defaults to ``BaseEngineSpec``; presumably it is
          a parameter so that other spec test cases can reuse this check.
          Each case is ``(query, expected_limit)``; ``None`` means no usable
          outer limit is expected.
          """
          cases = [
              ("select * from table", None),
              ("select * from mytable limit 10", 10),
              (
                  "select * from (select * from my_subquery limit 10) where col=1 limit 20",
                  20,
              ),
              ("select * from (select * from my_subquery limit 10);", None),
              (
                  "select * from (select * from my_subquery limit 10) where col=1 limit 20;",
                  20,
              ),
              ("select * from mytable limit 20, 10", 10),
              ("select * from mytable limit 10 offset 20", 10),
              ("select * from mytable limit", None),
              ("select * from mytable limit 10.0", None),
              ("select * from mytable limit x", None),
              ("select * from mytable limit 20, x", None),
              ("select * from mytable limit x offset 20", None),
          ]
          for query, expected in cases:
              # subTest reports every failing query instead of stopping at the
              # first mismatch.
              with self.subTest(query=query):
                  self.assertEqual(
                      engine_spec_class.get_limit_from_sql(query), expected
                  )

      def test_wrapped_semi_tabs(self):
          """Trailing whitespace and semicolons are dropped before LIMIT is added."""
          self.sql_limit_regex(
              "SELECT * FROM a \t \n ; \t \n ", "SELECT * FROM a\nLIMIT 1000"
          )

      def test_simple_limit_query(self):
          """A query without LIMIT gets ``LIMIT 1000`` appended on a new line."""
          self.sql_limit_regex("SELECT * FROM a", "SELECT * FROM a\nLIMIT 1000")

      def test_modify_limit_query(self):
          """An existing outer LIMIT is replaced by ``LIMIT 1000``."""
          self.sql_limit_regex("SELECT * FROM a LIMIT 9999", "SELECT * FROM a LIMIT 1000")

      def test_limit_query_with_limit_subquery(self):  # pylint: disable=invalid-name
          """Only the outer LIMIT is rewritten; the subquery LIMIT is kept."""
          self.sql_limit_regex(
              "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
              "SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 1000",
          )

      def test_limit_with_expr(self):
          """A 'LIMIT' inside a string literal does not affect the rewrite."""
          self.sql_limit_regex(
              """
              SELECT
                  'LIMIT 777' AS a
                  , b
              FROM
              table
              LIMIT 99990""",
              """SELECT
                  'LIMIT 777' AS a
                  , b
              FROM
              table
              LIMIT 1000""",
          )

      def test_limit_expr_and_semicolon(self):
          """Same as above, with a trailing semicolon that must be dropped."""
          self.sql_limit_regex(
              """
              SELECT
                  'LIMIT 777' AS a
                  , b
              FROM
              table
              LIMIT 99990 ;""",
              """SELECT
                  'LIMIT 777' AS a
                  , b
              FROM
              table
              LIMIT 1000""",
          )

      def test_get_datatype(self):
          """``get_datatype`` returns a plain type string unchanged."""
          self.assertEqual("VARCHAR", BaseEngineSpec.get_datatype("VARCHAR"))

      def test_limit_with_implicit_offset(self):
          """In ``LIMIT offset, count`` only the count is rewritten."""
          self.sql_limit_regex(
              """
              SELECT
                  'LIMIT 777' AS a
                  , b
              FROM
              table
              LIMIT 99990, 999999""",
              """SELECT
                  'LIMIT 777' AS a
                  , b
              FROM
              table
              LIMIT 99990, 1000""",
          )

      def test_limit_with_explicit_offset(self):
          """With ``LIMIT n OFFSET m`` the limit is rewritten, the offset kept."""
          self.sql_limit_regex(
              """
              SELECT
                  'LIMIT 777' AS a
                  , b
              FROM
              table
              LIMIT 99990
              OFFSET 999999""",
              """SELECT
                  'LIMIT 777' AS a
                  , b
              FROM
              table
              LIMIT 1000
              OFFSET 999999""",
          )

      def test_limit_with_non_token_limit(self):
          """A 'LIMIT' that only appears in a string literal still gets a limit."""
          self.sql_limit_regex(
              """SELECT 'LIMIT 777'""", """SELECT 'LIMIT 777'\nLIMIT 1000"""
          )

      def test_time_grain_blacklist(self):
          """Blacklisted grains are removed from the sqlite grain functions."""
          # patch.dict restores app.config in place on exit, so the blacklist
          # does not leak into other tests that share the same app object.
          with app.app_context(), mock.patch.dict(
              app.config, {"TIME_GRAIN_BLACKLIST": ["PT1M"]}
          ):
              time_grain_functions = SqliteEngineSpec.get_time_grain_functions()
              self.assertNotIn("PT1M", time_grain_functions)

      def test_time_grain_addons(self):
          """A configured addon grain shows up as the last sqlite time grain."""
          overrides = {
              "TIME_GRAIN_ADDONS": {"PTXM": "x seconds"},
              "TIME_GRAIN_ADDON_FUNCTIONS": {"sqlite": {"PTXM": "ABC({col})"}},
          }
          # Restore the config afterwards; otherwise the non-builtin PTXM grain
          # would be visible to later tests (e.g. the grain validity check).
          with app.app_context(), mock.patch.dict(app.config, overrides):
              time_grains = SqliteEngineSpec.get_time_grains()
              time_grain_addon = time_grains[-1]
              self.assertEqual("PTXM", time_grain_addon.duration)
              self.assertEqual("x seconds", time_grain_addon.label)

      def test_engine_time_grain_validity(self):
          """Every non-base engine spec defines only builtin time grains."""
          time_grains = set(builtin_time_grains.keys())
          # loop over all subclasses of BaseEngineSpec
          for engine in engines.values():
              if engine is BaseEngineSpec:
                  continue
              with self.subTest(engine=engine):
                  # make sure time grain functions have been defined
                  self.assertGreater(len(engine.get_time_grain_functions()), 0)
                  # make sure all defined time grains are supported
                  defined_grains = {
                      grain.duration for grain in engine.get_time_grains()
                  }
                  intersection = time_grains.intersection(defined_grains)
                  self.assertSetEqual(defined_grains, intersection, engine)

      def test_get_table_names(self):
          """The base spec strips the schema prefix from returned table names."""
          inspector = mock.Mock()
          inspector.get_table_names = mock.Mock(return_value=["schema.table", "table_2"])
          inspector.get_foreign_table_names = mock.Mock(return_value=["table_3"])
          # Make sure the base engine spec removes the schema name from the
          # table name, i.e. when try_remove_schema_from_table_name == True.
          base_result_expected = ["table", "table_2"]
          base_result = BaseEngineSpec.get_table_names(
              database=mock.ANY, schema="schema", inspector=inspector
          )
          self.assertListEqual(base_result_expected, base_result)

      def test_column_datatype_to_string(self):
          """Column types of the example ``energy_usage`` table render per backend."""
          example_db = get_example_database()
          sqla_table = example_db.get_table("energy_usage")
          dialect = example_db.get_dialect()
          col_names = [
              example_db.db_engine_spec.column_datatype_to_string(c.type, dialect)
              for c in sqla_table.columns
          ]
          if example_db.backend == "postgresql":
              expected = ["VARCHAR(255)", "VARCHAR(255)", "DOUBLE PRECISION"]
          else:
              expected = ["VARCHAR(255)", "VARCHAR(255)", "FLOAT"]
          self.assertEqual(col_names, expected)

      def test_convert_dttm(self):
          """``convert_dttm`` returns None for an empty target type."""
          dttm = self.get_dttm()
          self.assertIsNone(BaseEngineSpec.convert_dttm("", dttm))

      def test_pyodbc_rows_to_tuples(self):
          """pyodbc ``Row`` objects (ODBC drivers) are converted to plain tuples."""
          data = [
              Row((1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000))),
              Row((2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000))),
          ]
          expected = [
              (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
              (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
          ]
          result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
          self.assertListEqual(result, expected)

      def test_pyodbc_rows_to_tuples_passthrough(self):
          """Rows that are already tuples are returned unchanged."""
          data = [
              (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
              (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
          ]
          result = BaseEngineSpec.pyodbc_rows_to_tuples(data)
          self.assertListEqual(result, data)