celery_tests.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. # isort:skip_file
  18. """Unit tests for Superset Celery worker"""
  19. import datetime
  20. import json
  21. import subprocess
  22. import time
  23. import unittest
  24. import unittest.mock as mock
  25. import flask
  26. from flask import current_app
  27. from tests.test_app import app
  28. from superset import db, sql_lab
  29. from superset.result_set import SupersetResultSet
  30. from superset.db_engine_specs.base import BaseEngineSpec
  31. from superset.extensions import celery_app
  32. from superset.models.helpers import QueryStatus
  33. from superset.models.sql_lab import Query
  34. from superset.sql_parse import ParsedQuery
  35. from superset.utils.core import get_example_database
  36. from .base_tests import SupersetTestCase
  37. CELERY_SLEEP_TIME = 5
  38. class UtilityFunctionTests(SupersetTestCase):
  39. # TODO(bkyryliuk): support more cases in CTA function.
  40. def test_create_table_as(self):
  41. q = ParsedQuery("SELECT * FROM outer_space;")
  42. self.assertEqual(
  43. "CREATE TABLE tmp AS \nSELECT * FROM outer_space", q.as_create_table("tmp")
  44. )
  45. self.assertEqual(
  46. "DROP TABLE IF EXISTS tmp;\n"
  47. "CREATE TABLE tmp AS \nSELECT * FROM outer_space",
  48. q.as_create_table("tmp", overwrite=True),
  49. )
  50. # now without a semicolon
  51. q = ParsedQuery("SELECT * FROM outer_space")
  52. self.assertEqual(
  53. "CREATE TABLE tmp AS \nSELECT * FROM outer_space", q.as_create_table("tmp")
  54. )
  55. # now a multi-line query
  56. multi_line_query = "SELECT * FROM planets WHERE\n" "Luke_Father = 'Darth Vader'"
  57. q = ParsedQuery(multi_line_query)
  58. self.assertEqual(
  59. "CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\n"
  60. "Luke_Father = 'Darth Vader'",
  61. q.as_create_table("tmp"),
  62. )
  63. class AppContextTests(SupersetTestCase):
  64. def test_in_app_context(self):
  65. @celery_app.task()
  66. def my_task():
  67. self.assertTrue(current_app)
  68. # Make sure we can call tasks with an app already setup
  69. my_task()
  70. # Make sure the app gets pushed onto the stack properly
  71. try:
  72. popped_app = flask._app_ctx_stack.pop()
  73. my_task()
  74. finally:
  75. flask._app_ctx_stack.push(popped_app)
  76. class CeleryTestCase(SupersetTestCase):
  77. def get_query_by_name(self, sql):
  78. session = db.session
  79. query = session.query(Query).filter_by(sql=sql).first()
  80. session.close()
  81. return query
  82. def get_query_by_id(self, id):
  83. session = db.session
  84. query = session.query(Query).filter_by(id=id).first()
  85. session.close()
  86. return query
  87. @classmethod
  88. def setUpClass(cls):
  89. with app.app_context():
  90. class CeleryConfig(object):
  91. BROKER_URL = app.config["CELERY_CONFIG"].BROKER_URL
  92. CELERY_IMPORTS = ("superset.sql_lab",)
  93. CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}}
  94. CONCURRENCY = 1
  95. app.config["CELERY_CONFIG"] = CeleryConfig
  96. db.session.query(Query).delete()
  97. db.session.commit()
  98. base_dir = app.config["BASE_DIR"]
  99. worker_command = base_dir + "/bin/superset worker -w 2"
  100. subprocess.Popen(worker_command, shell=True, stdout=subprocess.PIPE)
  101. @classmethod
  102. def tearDownClass(cls):
  103. subprocess.call(
  104. "ps auxww | grep 'celeryd' | awk '{print $2}' | xargs kill -9", shell=True
  105. )
  106. subprocess.call(
  107. "ps auxww | grep 'superset worker' | awk '{print $2}' | xargs kill -9",
  108. shell=True,
  109. )
  110. def run_sql(
  111. self, db_id, sql, client_id=None, cta=False, tmp_table="tmp", async_=False
  112. ):
  113. self.login()
  114. resp = self.client.post(
  115. "/superset/sql_json/",
  116. json=dict(
  117. database_id=db_id,
  118. sql=sql,
  119. runAsync=async_,
  120. select_as_cta=cta,
  121. tmp_table_name=tmp_table,
  122. client_id=client_id,
  123. ),
  124. )
  125. self.logout()
  126. return json.loads(resp.data)
  127. def test_run_sync_query_dont_exist(self):
  128. main_db = get_example_database()
  129. db_id = main_db.id
  130. sql_dont_exist = "SELECT name FROM table_dont_exist"
  131. result1 = self.run_sql(db_id, sql_dont_exist, "1", cta=True)
  132. self.assertTrue("error" in result1)
  133. def test_run_sync_query_cta(self):
  134. main_db = get_example_database()
  135. backend = main_db.backend
  136. db_id = main_db.id
  137. tmp_table_name = "tmp_async_22"
  138. self.drop_table_if_exists(tmp_table_name, main_db)
  139. name = "James"
  140. sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1"
  141. result = self.run_sql(db_id, sql_where, "2", tmp_table=tmp_table_name, cta=True)
  142. self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"])
  143. self.assertEqual([], result["data"])
  144. self.assertEqual([], result["columns"])
  145. query2 = self.get_query_by_id(result["query"]["serverId"])
  146. # Check the data in the tmp table.
  147. if backend != "postgresql":
  148. # TODO This test won't work in Postgres
  149. results = self.run_sql(db_id, query2.select_sql, "sdf2134")
  150. self.assertEqual(results["status"], "success")
  151. self.assertGreater(len(results["data"]), 0)
  152. def test_run_sync_query_cta_no_data(self):
  153. main_db = get_example_database()
  154. db_id = main_db.id
  155. sql_empty_result = "SELECT * FROM birth_names WHERE name='random'"
  156. result3 = self.run_sql(db_id, sql_empty_result, "3")
  157. self.assertEqual(QueryStatus.SUCCESS, result3["query"]["state"])
  158. self.assertEqual([], result3["data"])
  159. self.assertEqual([], result3["columns"])
  160. query3 = self.get_query_by_id(result3["query"]["serverId"])
  161. self.assertEqual(QueryStatus.SUCCESS, query3.status)
  162. def drop_table_if_exists(self, table_name, database=None):
  163. """Drop table if it exists, works on any DB"""
  164. sql = "DROP TABLE {}".format(table_name)
  165. db_id = database.id
  166. if database:
  167. database.allow_dml = True
  168. db.session.flush()
  169. return self.run_sql(db_id, sql)
  170. def test_run_async_query(self):
  171. main_db = get_example_database()
  172. db_id = main_db.id
  173. self.drop_table_if_exists("tmp_async_1", main_db)
  174. sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
  175. result = self.run_sql(
  176. db_id, sql_where, "4", async_=True, tmp_table="tmp_async_1", cta=True
  177. )
  178. db.session.close()
  179. assert result["query"]["state"] in (
  180. QueryStatus.PENDING,
  181. QueryStatus.RUNNING,
  182. QueryStatus.SUCCESS,
  183. )
  184. time.sleep(CELERY_SLEEP_TIME)
  185. query = self.get_query_by_id(result["query"]["serverId"])
  186. self.assertEqual(QueryStatus.SUCCESS, query.status)
  187. self.assertTrue("FROM tmp_async_1" in query.select_sql)
  188. self.assertEqual(
  189. "CREATE TABLE tmp_async_1 AS \n"
  190. "SELECT name FROM birth_names "
  191. "WHERE name='James' "
  192. "LIMIT 10",
  193. query.executed_sql,
  194. )
  195. self.assertEqual(sql_where, query.sql)
  196. self.assertEqual(0, query.rows)
  197. self.assertEqual(True, query.select_as_cta)
  198. self.assertEqual(True, query.select_as_cta_used)
  199. def test_run_async_query_with_lower_limit(self):
  200. main_db = get_example_database()
  201. db_id = main_db.id
  202. tmp_table = "tmp_async_2"
  203. self.drop_table_if_exists(tmp_table, main_db)
  204. sql_where = "SELECT name FROM birth_names LIMIT 1"
  205. result = self.run_sql(
  206. db_id, sql_where, "5", async_=True, tmp_table=tmp_table, cta=True
  207. )
  208. db.session.close()
  209. assert result["query"]["state"] in (
  210. QueryStatus.PENDING,
  211. QueryStatus.RUNNING,
  212. QueryStatus.SUCCESS,
  213. )
  214. time.sleep(CELERY_SLEEP_TIME)
  215. query = self.get_query_by_id(result["query"]["serverId"])
  216. self.assertEqual(QueryStatus.SUCCESS, query.status)
  217. self.assertTrue(f"FROM {tmp_table}" in query.select_sql)
  218. self.assertEqual(
  219. f"CREATE TABLE {tmp_table} AS \n" "SELECT name FROM birth_names LIMIT 1",
  220. query.executed_sql,
  221. )
  222. self.assertEqual(sql_where, query.sql)
  223. self.assertEqual(0, query.rows)
  224. self.assertEqual(1, query.limit)
  225. self.assertEqual(True, query.select_as_cta)
  226. self.assertEqual(True, query.select_as_cta_used)
  227. def test_default_data_serialization(self):
  228. data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
  229. cursor_descr = (
  230. ("a", "string"),
  231. ("b", "int"),
  232. ("c", "float"),
  233. ("d", "datetime"),
  234. )
  235. db_engine_spec = BaseEngineSpec()
  236. results = SupersetResultSet(data, cursor_descr, db_engine_spec)
  237. with mock.patch.object(
  238. db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
  239. ) as expand_data:
  240. data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
  241. results, db_engine_spec, False, True
  242. )
  243. expand_data.assert_called_once()
  244. self.assertIsInstance(data, list)
  245. def test_new_data_serialization(self):
  246. data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
  247. cursor_descr = (
  248. ("a", "string"),
  249. ("b", "int"),
  250. ("c", "float"),
  251. ("d", "datetime"),
  252. )
  253. db_engine_spec = BaseEngineSpec()
  254. results = SupersetResultSet(data, cursor_descr, db_engine_spec)
  255. with mock.patch.object(
  256. db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
  257. ) as expand_data:
  258. data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
  259. results, db_engine_spec, True
  260. )
  261. expand_data.assert_not_called()
  262. self.assertIsInstance(data, bytes)
  263. def test_default_payload_serialization(self):
  264. use_new_deserialization = False
  265. data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
  266. cursor_descr = (
  267. ("a", "string"),
  268. ("b", "int"),
  269. ("c", "float"),
  270. ("d", "datetime"),
  271. )
  272. db_engine_spec = BaseEngineSpec()
  273. results = SupersetResultSet(data, cursor_descr, db_engine_spec)
  274. query = {
  275. "database_id": 1,
  276. "sql": "SELECT * FROM birth_names LIMIT 100",
  277. "status": QueryStatus.PENDING,
  278. }
  279. serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
  280. results, db_engine_spec, use_new_deserialization
  281. )
  282. payload = {
  283. "query_id": 1,
  284. "status": QueryStatus.SUCCESS,
  285. "state": QueryStatus.SUCCESS,
  286. "data": serialized_data,
  287. "columns": all_columns,
  288. "selected_columns": selected_columns,
  289. "expanded_columns": expanded_columns,
  290. "query": query,
  291. }
  292. serialized = sql_lab._serialize_payload(payload, use_new_deserialization)
  293. self.assertIsInstance(serialized, str)
  294. def test_msgpack_payload_serialization(self):
  295. use_new_deserialization = True
  296. data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
  297. cursor_descr = (
  298. ("a", "string"),
  299. ("b", "int"),
  300. ("c", "float"),
  301. ("d", "datetime"),
  302. )
  303. db_engine_spec = BaseEngineSpec()
  304. results = SupersetResultSet(data, cursor_descr, db_engine_spec)
  305. query = {
  306. "database_id": 1,
  307. "sql": "SELECT * FROM birth_names LIMIT 100",
  308. "status": QueryStatus.PENDING,
  309. }
  310. serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
  311. results, db_engine_spec, use_new_deserialization
  312. )
  313. payload = {
  314. "query_id": 1,
  315. "status": QueryStatus.SUCCESS,
  316. "state": QueryStatus.SUCCESS,
  317. "data": serialized_data,
  318. "columns": all_columns,
  319. "selected_columns": selected_columns,
  320. "expanded_columns": expanded_columns,
  321. "query": query,
  322. }
  323. serialized = sql_lab._serialize_payload(payload, use_new_deserialization)
  324. self.assertIsInstance(serialized, bytes)
  325. @staticmethod
  326. def de_unicode_dict(d):
  327. def str_if_basestring(o):
  328. if isinstance(o, str):
  329. return str(o)
  330. return o
  331. return {str_if_basestring(k): str_if_basestring(d[k]) for k in d}
  332. @classmethod
  333. def dictify_list_of_dicts(cls, l, k):
  334. return {str(o[k]): cls.de_unicode_dict(o) for o in l}
# Allow running this test module directly (outside the pytest runner).
if __name__ == "__main__":
    unittest.main()