model_tests.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import textwrap
import unittest

import pandas
from sqlalchemy.engine.url import make_url

import tests.test_app
from superset import app
from superset.models.core import Database
from superset.utils.core import get_example_database, QueryStatus

from .base_tests import SupersetTestCase


class DatabaseModelTestCase(SupersetTestCase):
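    # Presto URIs carry the catalog/schema in the URL's "database" component, so
    # requesting a schema should rewrite that component to "<catalog>/<schema>".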
    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("requests"), "requests not installed"
    )
    def test_database_schema_presto(self):
        sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive/default"
        model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEqual("hive/default", db)

        db = make_url(model.get_sqla_engine(schema="core_db").url).database
        self.assertEqual("hive/core_db", db)

        sqlalchemy_uri = "presto://presto.airbnb.io:8080/hive"
        model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEqual("hive", db)

        db = make_url(model.get_sqla_engine(schema="core_db").url).database
        self.assertEqual("hive/core_db", db)
    def test_database_schema_postgres(self):
        sqlalchemy_uri = "postgresql+psycopg2://postgres.airbnb.io:5439/prod"
        model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEqual("prod", db)

        db = make_url(model.get_sqla_engine(schema="foo").url).database
        self.assertEqual("prod", db)
    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("thrift"), "thrift not installed"
    )
    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
    )
    def test_database_schema_hive(self):
        sqlalchemy_uri = "hive://hive@hive.airbnb.io:10000/default?auth=NOSASL"
        model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEqual("default", db)

        db = make_url(model.get_sqla_engine(schema="core_db").url).database
        self.assertEqual("core_db", db)
    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
    )
    def test_database_schema_mysql(self):
        sqlalchemy_uri = "mysql://root@localhost/superset"
        model = Database(database_name="test_database", sqlalchemy_uri=sqlalchemy_uri)

        db = make_url(model.get_sqla_engine().url).database
        self.assertEqual("superset", db)

        db = make_url(model.get_sqla_engine(schema="staging").url).database
        self.assertEqual("staging", db)
    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed"
    )
    def test_database_impersonate_user(self):
        uri = "mysql://root@localhost"
        example_user = "giuseppe"
        model = Database(database_name="test_database", sqlalchemy_uri=uri)

        model.impersonate_user = True
        user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
        self.assertEqual(example_user, user_name)

        model.impersonate_user = False
        user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
        self.assertNotEqual(example_user, user_name)
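
    # select_star() should generate a LIMITed SELECT against the example table,
    # expanding the column list only when show_cols is set.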
    def test_select_star(self):
        db = get_example_database()
        table_name = "energy_usage"
        sql = db.select_star(table_name, show_cols=False, latest_partition=False)
        expected = textwrap.dedent(
            f"""\
        SELECT *
        FROM {table_name}
        LIMIT 100"""
        )
        assert sql.startswith(expected)

        sql = db.select_star(table_name, show_cols=True, latest_partition=False)
        expected = textwrap.dedent(
            f"""\
        SELECT source,
               target,
               value
        FROM energy_usage
        LIMIT 100"""
        )
        assert sql.startswith(expected)
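
    # Schema and table names containing special characters must be quoted, and
    # the expected quoting style differs per backend.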
    def test_select_star_fully_qualified_names(self):
        db = get_example_database()
        schema = "schema.name"
        table_name = "table/name"
        sql = db.select_star(
            table_name, schema=schema, show_cols=False, latest_partition=False
        )
        fully_qualified_names = {
            "sqlite": '"schema.name"."table/name"',
            "mysql": "`schema.name`.`table/name`",
            "postgres": '"schema.name"."table/name"',
        }
        fully_qualified_name = fully_qualified_names.get(db.db_engine_spec.engine)
        if fully_qualified_name:
            expected = textwrap.dedent(
                f"""\
            SELECT *
            FROM {fully_qualified_name}
            LIMIT 100"""
            )
            assert sql.startswith(expected)
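
    # get_df() should run a single statement with or without a trailing semicolon.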
    def test_single_statement(self):
        main_db = get_example_database()

        if main_db.backend == "mysql":
            df = main_db.get_df("SELECT 1", None)
            self.assertEqual(df.iat[0, 0], 1)

            df = main_db.get_df("SELECT 1;", None)
            self.assertEqual(df.iat[0, 0], 1)
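
    # With multiple statements, only the result of the last one is returned.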
    def test_multi_statement(self):
        main_db = get_example_database()

        if main_db.backend == "mysql":
            df = main_db.get_df("USE superset; SELECT 1", None)
            self.assertEqual(df.iat[0, 0], 1)

            df = main_db.get_df("USE superset; SELECT ';';", None)
            self.assertEqual(df.iat[0, 0], ";")


class SqlaTableModelTestCase(SupersetTestCase):
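    # A time grain such as "P1D" (one day, ISO 8601 duration) should wrap the
    # timestamp column in the backend's date-truncation function.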
    def test_get_timestamp_expression(self):
        tbl = self.get_table_by_name("birth_names")
        ds_col = tbl.get_column("ds")

        sqla_literal = ds_col.get_timestamp_expression(None)
        self.assertEqual(str(sqla_literal.compile()), "ds")

        sqla_literal = ds_col.get_timestamp_expression("P1D")
        compiled = "{}".format(sqla_literal.compile())
        if tbl.database.backend == "mysql":
            self.assertEqual(compiled, "DATE(ds)")

        prev_ds_expr = ds_col.expression
        ds_col.expression = "DATE_ADD(ds, 1)"
        sqla_literal = ds_col.get_timestamp_expression("P1D")
        compiled = "{}".format(sqla_literal.compile())
        if tbl.database.backend == "mysql":
            self.assertEqual(compiled, "DATE(DATE_ADD(ds, 1))")
        ds_col.expression = prev_ds_expr
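
    # Columns stored as epoch seconds are first converted to timestamps
    # (from_unixtime on MySQL) before the time grain is applied.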
    def test_get_timestamp_expression_epoch(self):
        tbl = self.get_table_by_name("birth_names")
        ds_col = tbl.get_column("ds")

        ds_col.expression = None
        ds_col.python_date_format = "epoch_s"
        sqla_literal = ds_col.get_timestamp_expression(None)
        compiled = "{}".format(sqla_literal.compile())
        if tbl.database.backend == "mysql":
            self.assertEqual(compiled, "from_unixtime(ds)")

        ds_col.python_date_format = "epoch_s"
        sqla_literal = ds_col.get_timestamp_expression("P1D")
        compiled = "{}".format(sqla_literal.compile())
        if tbl.database.backend == "mysql":
            self.assertEqual(compiled, "DATE(from_unixtime(ds))")

        prev_ds_expr = ds_col.expression
        ds_col.expression = "DATE_ADD(ds, 1)"
        sqla_literal = ds_col.get_timestamp_expression("P1D")
        compiled = "{}".format(sqla_literal.compile())
        if tbl.database.backend == "mysql":
            self.assertEqual(compiled, "DATE(from_unixtime(DATE_ADD(ds, 1)))")
        ds_col.expression = prev_ds_expr
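
    # Helper: run a query that groups by a free-form SQL expression plus an ad hoc
    # metric, optionally forcing or forbidding the inner-join code path, and return
    # the resulting DataFrame (or None when joins are unsupported).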
    def query_with_expr_helper(self, is_timeseries, inner_join=True):
        tbl = self.get_table_by_name("birth_names")
        ds_col = tbl.get_column("ds")
        ds_col.expression = None
        ds_col.python_date_format = None

        spec = self.get_database_by_id(tbl.database_id).db_engine_spec
        if not spec.allows_joins and inner_join:
            # if the db does not support inner joins, we cannot force one, so bail out
            return None
        old_inner_join = spec.allows_joins
        spec.allows_joins = inner_join

        arbitrary_gby = "state || gender || '_test'"
        arbitrary_metric = dict(
            label="arbitrary", expressionType="SQL", sqlExpression="COUNT(1)"
        )
        query_obj = dict(
            groupby=[arbitrary_gby, "name"],
            metrics=[arbitrary_metric],
            filter=[],
            is_timeseries=is_timeseries,
            columns=[],
            granularity="ds",
            from_dttm=None,
            to_dttm=None,
            extras=dict(time_grain_sqla="P1Y"),
        )
        qr = tbl.query(query_obj)
        self.assertEqual(qr.status, QueryStatus.SUCCESS)
        sql = qr.query
        self.assertIn(arbitrary_gby, sql)
        self.assertIn("name", sql)
        if inner_join and is_timeseries:
            self.assertIn("JOIN", sql.upper())
        else:
            self.assertNotIn("JOIN", sql.upper())
        spec.allows_joins = old_inner_join

        self.assertFalse(qr.df.empty)
        return qr.df
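
    # With and without inner joins, the time-series group-by should produce the
    # same rows (order-insensitive), provided the backend supports joins.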
    def test_query_with_expr_groupby_timeseries(self):
        def canonicalize_df(df):
            ret = df.sort_values(by=list(df.columns.values), inplace=False)
            ret.reset_index(inplace=True, drop=True)
            return ret

        df1 = self.query_with_expr_helper(is_timeseries=True, inner_join=True)
        df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False)
        self.assertFalse(df2.empty)
        # df1 is None when the db does not support joins; compare only when present
        if df1 is not None and not df1.empty:
            pandas.testing.assert_frame_equal(
                canonicalize_df(df1), canonicalize_df(df2)
            )

    def test_query_with_expr_groupby(self):
        self.query_with_expr_helper(is_timeseries=False)
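
    # SQL_QUERY_MUTATOR lets deployments rewrite generated SQL; the mutated prefix
    # should appear in the query string and disappear once the mutator is unset.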
    def test_sql_mutator(self):
        tbl = self.get_table_by_name("birth_names")
        query_obj = dict(
            groupby=[],
            metrics=[],
            filter=[],
            is_timeseries=False,
            columns=["name"],
            granularity=None,
            from_dttm=None,
            to_dttm=None,
            extras={},
        )
        sql = tbl.get_query_str(query_obj)
        self.assertNotIn("-- COMMENT", sql)

        def mutator(*args):
            return "-- COMMENT\n" + args[0]

        app.config["SQL_QUERY_MUTATOR"] = mutator
        sql = tbl.get_query_str(query_obj)
        self.assertIn("-- COMMENT", sql)

        app.config["SQL_QUERY_MUTATOR"] = None
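
    # Referencing a metric that is not defined on the table should raise.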
    def test_query_with_non_existent_metrics(self):
        tbl = self.get_table_by_name("birth_names")
        query_obj = dict(
            groupby=[],
            metrics=["invalid"],
            filter=[],
            is_timeseries=False,
            columns=["name"],
            granularity=None,
            from_dttm=None,
            to_dttm=None,
            extras={},
        )
        with self.assertRaises(Exception) as context:
            tbl.get_query_str(query_obj)

        # assertTrue(msg, ...) always passes; assert on the exception message instead
        self.assertIn("Metric 'invalid' does not exist", str(context.exception))