core_tests.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. # isort:skip_file
  18. """Unit tests for Superset"""
  19. import cgi
  20. import csv
  21. import datetime
  22. import doctest
  23. import io
  24. import json
  25. import logging
  26. import os
  27. import pytz
  28. import random
  29. import re
  30. import string
  31. from typing import Any, Dict
  32. import unittest
  33. from unittest import mock, skipUnless
  34. import pandas as pd
  35. import sqlalchemy as sqla
  36. from tests.test_app import app
  37. from superset import (
  38. dataframe,
  39. db,
  40. jinja_context,
  41. security_manager,
  42. sql_lab,
  43. is_feature_enabled,
  44. )
  45. from superset.common.query_context import QueryContext
  46. from superset.connectors.connector_registry import ConnectorRegistry
  47. from superset.connectors.sqla.models import SqlaTable
  48. from superset.db_engine_specs.base import BaseEngineSpec
  49. from superset.db_engine_specs.mssql import MssqlEngineSpec
  50. from superset.models import core as models
  51. from superset.models.dashboard import Dashboard
  52. from superset.models.datasource_access_request import DatasourceAccessRequest
  53. from superset.models.slice import Slice
  54. from superset.models.sql_lab import Query
  55. from superset.result_set import SupersetResultSet
  56. from superset.utils import core as utils
  57. from superset.views import core as views
  58. from superset.views.database.views import DatabaseView
  59. from .base_tests import SupersetTestCase
  60. logger = logging.getLogger(__name__)
  61. class CoreTests(SupersetTestCase):
  62. def __init__(self, *args, **kwargs):
  63. super(CoreTests, self).__init__(*args, **kwargs)
  64. def setUp(self):
  65. db.session.query(Query).delete()
  66. db.session.query(DatasourceAccessRequest).delete()
  67. db.session.query(models.Log).delete()
  68. self.table_ids = {
  69. tbl.table_name: tbl.id for tbl in (db.session.query(SqlaTable).all())
  70. }
  71. def tearDown(self):
  72. db.session.query(Query).delete()
  73. def test_login(self):
  74. resp = self.get_resp("/login/", data=dict(username="admin", password="general"))
  75. self.assertNotIn("User confirmation needed", resp)
  76. resp = self.get_resp("/logout/", follow_redirects=True)
  77. self.assertIn("User confirmation needed", resp)
  78. resp = self.get_resp(
  79. "/login/", data=dict(username="admin", password="wrongPassword")
  80. )
  81. self.assertIn("User confirmation needed", resp)
  82. def test_dashboard_endpoint(self):
  83. resp = self.client.get("/superset/dashboard/-1/")
  84. assert resp.status_code == 404
  85. def test_slice_endpoint(self):
  86. self.login(username="admin")
  87. slc = self.get_slice("Girls", db.session)
  88. resp = self.get_resp("/superset/slice/{}/".format(slc.id))
  89. assert "Time Column" in resp
  90. assert "List Roles" in resp
  91. # Testing overrides
  92. resp = self.get_resp("/superset/slice/{}/?standalone=true".format(slc.id))
  93. assert '<div class="navbar' not in resp
  94. resp = self.client.get("/superset/slice/-1/")
  95. assert resp.status_code == 404
  96. def _get_query_context_dict(self) -> Dict[str, Any]:
  97. self.login(username="admin")
  98. slc = self.get_slice("Girl Name Cloud", db.session)
  99. return {
  100. "datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
  101. "queries": [
  102. {
  103. "granularity": "ds",
  104. "groupby": ["name"],
  105. "metrics": [{"label": "sum__num"}],
  106. "filters": [],
  107. "row_limit": 100,
  108. }
  109. ],
  110. }
  111. def test_viz_cache_key(self):
  112. self.login(username="admin")
  113. slc = self.get_slice("Girls", db.session)
  114. viz = slc.viz
  115. qobj = viz.query_obj()
  116. cache_key = viz.cache_key(qobj)
  117. self.assertEqual(cache_key, viz.cache_key(qobj))
  118. qobj["groupby"] = []
  119. self.assertNotEqual(cache_key, viz.cache_key(qobj))
  120. def test_cache_key_changes_when_datasource_is_updated(self):
  121. qc_dict = self._get_query_context_dict()
  122. # construct baseline cache_key
  123. query_context = QueryContext(**qc_dict)
  124. query_object = query_context.queries[0]
  125. cache_key_original = query_context.cache_key(query_object)
  126. # make temporary change and revert it to refresh the changed_on property
  127. datasource = ConnectorRegistry.get_datasource(
  128. datasource_type=qc_dict["datasource"]["type"],
  129. datasource_id=qc_dict["datasource"]["id"],
  130. session=db.session,
  131. )
  132. description_original = datasource.description
  133. datasource.description = "temporary description"
  134. db.session.commit()
  135. datasource.description = description_original
  136. db.session.commit()
  137. # create new QueryContext with unchanged attributes and extract new cache_key
  138. query_context = QueryContext(**qc_dict)
  139. query_object = query_context.queries[0]
  140. cache_key_new = query_context.cache_key(query_object)
  141. # the new cache_key should be different due to updated datasource
  142. self.assertNotEqual(cache_key_original, cache_key_new)
  143. def test_get_superset_tables_not_allowed(self):
  144. example_db = utils.get_example_database()
  145. schema_name = self.default_schema_backend_map[example_db.backend]
  146. self.login(username="gamma")
  147. uri = f"superset/tables/{example_db.id}/{schema_name}/undefined/"
  148. rv = self.client.get(uri)
  149. self.assertEqual(rv.status_code, 404)
  150. def test_get_superset_tables_substr(self):
  151. example_db = utils.get_example_database()
  152. self.login(username="admin")
  153. schema_name = self.default_schema_backend_map[example_db.backend]
  154. uri = f"superset/tables/{example_db.id}/{schema_name}/ab_role/"
  155. rv = self.client.get(uri)
  156. response = json.loads(rv.data.decode("utf-8"))
  157. self.assertEqual(rv.status_code, 200)
  158. expeted_response = {
  159. "options": [
  160. {
  161. "label": "ab_role",
  162. "schema": schema_name,
  163. "title": "ab_role",
  164. "type": "table",
  165. "value": "ab_role",
  166. }
  167. ],
  168. "tableLength": 1,
  169. }
  170. self.assertEqual(response, expeted_response)
  171. def test_get_superset_tables_not_found(self):
  172. self.login(username="admin")
  173. uri = f"superset/tables/invalid/public/undefined/"
  174. rv = self.client.get(uri)
  175. self.assertEqual(rv.status_code, 404)
  176. def test_api_v1_query_endpoint(self):
  177. self.login(username="admin")
  178. qc_dict = self._get_query_context_dict()
  179. data = json.dumps(qc_dict)
  180. resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
  181. self.assertEqual(resp[0]["rowcount"], 100)
  182. def test_old_slice_json_endpoint(self):
  183. self.login(username="admin")
  184. slc = self.get_slice("Girls", db.session)
  185. json_endpoint = "/superset/explore_json/{}/{}/".format(
  186. slc.datasource_type, slc.datasource_id
  187. )
  188. resp = self.get_resp(
  189. json_endpoint, {"form_data": json.dumps(slc.viz.form_data)}
  190. )
  191. assert '"Jennifer"' in resp
  192. def test_slice_json_endpoint(self):
  193. self.login(username="admin")
  194. slc = self.get_slice("Girls", db.session)
  195. resp = self.get_resp(slc.explore_json_url)
  196. assert '"Jennifer"' in resp
  197. def test_old_slice_csv_endpoint(self):
  198. self.login(username="admin")
  199. slc = self.get_slice("Girls", db.session)
  200. csv_endpoint = "/superset/explore_json/{}/{}/?csv=true".format(
  201. slc.datasource_type, slc.datasource_id
  202. )
  203. resp = self.get_resp(csv_endpoint, {"form_data": json.dumps(slc.viz.form_data)})
  204. assert "Jennifer," in resp
  205. def test_slice_csv_endpoint(self):
  206. self.login(username="admin")
  207. slc = self.get_slice("Girls", db.session)
  208. csv_endpoint = "/superset/explore_json/?csv=true"
  209. resp = self.get_resp(
  210. csv_endpoint, {"form_data": json.dumps({"slice_id": slc.id})}
  211. )
  212. assert "Jennifer," in resp
  213. def test_admin_only_permissions(self):
  214. def assert_admin_permission_in(role_name, assert_func):
  215. role = security_manager.find_role(role_name)
  216. permissions = [p.permission.name for p in role.permissions]
  217. assert_func("can_sync_druid_source", permissions)
  218. assert_func("can_approve", permissions)
  219. assert_admin_permission_in("Admin", self.assertIn)
  220. assert_admin_permission_in("Alpha", self.assertNotIn)
  221. assert_admin_permission_in("Gamma", self.assertNotIn)
  222. def test_admin_only_menu_views(self):
  223. def assert_admin_view_menus_in(role_name, assert_func):
  224. role = security_manager.find_role(role_name)
  225. view_menus = [p.view_menu.name for p in role.permissions]
  226. assert_func("ResetPasswordView", view_menus)
  227. assert_func("RoleModelView", view_menus)
  228. assert_func("Security", view_menus)
  229. assert_func("SQL Lab", view_menus)
  230. assert_admin_view_menus_in("Admin", self.assertIn)
  231. assert_admin_view_menus_in("Alpha", self.assertNotIn)
  232. assert_admin_view_menus_in("Gamma", self.assertNotIn)
  233. def test_save_slice(self):
  234. self.login(username="admin")
  235. slice_name = f"Energy Sankey"
  236. slice_id = self.get_slice(slice_name, db.session).id
  237. copy_name = f"Test Sankey Save_{random.random()}"
  238. tbl_id = self.table_ids.get("energy_usage")
  239. new_slice_name = f"Test Sankey Overwrite_{random.random()}"
  240. url = (
  241. "/superset/explore/table/{}/?slice_name={}&"
  242. "action={}&datasource_name=energy_usage"
  243. )
  244. form_data = {
  245. "viz_type": "sankey",
  246. "groupby": "target",
  247. "metric": "sum__value",
  248. "row_limit": 5000,
  249. "slice_id": slice_id,
  250. "time_range_endpoints": ["inclusive", "exclusive"],
  251. }
  252. # Changing name and save as a new slice
  253. resp = self.client.post(
  254. url.format(tbl_id, copy_name, "saveas"),
  255. data={"form_data": json.dumps(form_data)},
  256. )
  257. db.session.expunge_all()
  258. new_slice_id = resp.json["form_data"]["slice_id"]
  259. slc = db.session.query(Slice).filter_by(id=new_slice_id).one()
  260. self.assertEqual(slc.slice_name, copy_name)
  261. form_data.pop("slice_id") # We don't save the slice id when saving as
  262. self.assertEqual(slc.viz.form_data, form_data)
  263. form_data = {
  264. "viz_type": "sankey",
  265. "groupby": "source",
  266. "metric": "sum__value",
  267. "row_limit": 5000,
  268. "slice_id": new_slice_id,
  269. "time_range": "now",
  270. "time_range_endpoints": ["inclusive", "exclusive"],
  271. }
  272. # Setting the name back to its original name by overwriting new slice
  273. self.client.post(
  274. url.format(tbl_id, new_slice_name, "overwrite"),
  275. data={"form_data": json.dumps(form_data)},
  276. )
  277. db.session.expunge_all()
  278. slc = db.session.query(Slice).filter_by(id=new_slice_id).one()
  279. self.assertEqual(slc.slice_name, new_slice_name)
  280. self.assertEqual(slc.viz.form_data, form_data)
  281. # Cleanup
  282. db.session.delete(slc)
  283. db.session.commit()
  284. def test_filter_endpoint(self):
  285. self.login(username="admin")
  286. slice_name = "Energy Sankey"
  287. slice_id = self.get_slice(slice_name, db.session).id
  288. db.session.commit()
  289. tbl_id = self.table_ids.get("energy_usage")
  290. table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id)
  291. table.filter_select_enabled = True
  292. url = (
  293. "/superset/filter/table/{}/target/?viz_type=sankey&groupby=source"
  294. "&metric=sum__value&flt_col_0=source&flt_op_0=in&flt_eq_0=&"
  295. "slice_id={}&datasource_name=energy_usage&"
  296. "datasource_id=1&datasource_type=table"
  297. )
  298. # Changing name
  299. resp = self.get_resp(url.format(tbl_id, slice_id))
  300. assert len(resp) > 0
  301. assert "Carbon Dioxide" in resp
  302. def test_slice_data(self):
  303. # slice data should have some required attributes
  304. self.login(username="admin")
  305. slc = self.get_slice("Girls", db.session)
  306. slc_data_attributes = slc.data.keys()
  307. assert "changed_on" in slc_data_attributes
  308. assert "modified" in slc_data_attributes
  309. def test_slices(self):
  310. # Testing by hitting the two supported end points for all slices
  311. self.login(username="admin")
  312. Slc = Slice
  313. urls = []
  314. for slc in db.session.query(Slc).all():
  315. urls += [
  316. (slc.slice_name, "explore", slc.slice_url),
  317. (slc.slice_name, "explore_json", slc.explore_json_url),
  318. ]
  319. for name, method, url in urls:
  320. logger.info(f"[{name}]/[{method}]: {url}")
  321. print(f"[{name}]/[{method}]: {url}")
  322. resp = self.client.get(url)
  323. self.assertEqual(resp.status_code, 200)
  324. def test_tablemodelview_list(self):
  325. self.login(username="admin")
  326. url = "/tablemodelview/list/"
  327. resp = self.get_resp(url)
  328. # assert that a table is listed
  329. table = db.session.query(SqlaTable).first()
  330. assert table.name in resp
  331. assert "/superset/explore/table/{}".format(table.id) in resp
  332. def test_add_slice(self):
  333. self.login(username="admin")
  334. # assert that /chart/add responds with 200
  335. url = "/chart/add"
  336. resp = self.client.get(url)
  337. self.assertEqual(resp.status_code, 200)
  338. def test_get_user_slices(self):
  339. self.login(username="admin")
  340. userid = security_manager.find_user("admin").id
  341. url = f"/sliceasync/api/read?_flt_0_created_by={userid}"
  342. resp = self.client.get(url)
  343. self.assertEqual(resp.status_code, 200)
  344. def test_slices_V2(self):
  345. # Add explore-v2-beta role to admin user
  346. # Test all slice urls as user with with explore-v2-beta role
  347. security_manager.add_role("explore-v2-beta")
  348. security_manager.add_user(
  349. "explore_beta",
  350. "explore_beta",
  351. " user",
  352. "explore_beta@airbnb.com",
  353. security_manager.find_role("explore-v2-beta"),
  354. password="general",
  355. )
  356. self.login(username="explore_beta", password="general")
  357. Slc = Slice
  358. urls = []
  359. for slc in db.session.query(Slc).all():
  360. urls += [(slc.slice_name, "slice_url", slc.slice_url)]
  361. for name, method, url in urls:
  362. print(f"[{name}]/[{method}]: {url}")
  363. self.client.get(url)
  364. def test_doctests(self):
  365. modules = [utils, models, sql_lab]
  366. for mod in modules:
  367. failed, tests = doctest.testmod(mod)
  368. if failed:
  369. raise Exception("Failed a doctest")
  370. def test_misc(self):
  371. assert self.get_resp("/health") == "OK"
  372. assert self.get_resp("/healthcheck") == "OK"
  373. assert self.get_resp("/ping") == "OK"
  374. def test_testconn(self, username="admin"):
  375. self.login(username=username)
  376. database = utils.get_example_database()
  377. # validate that the endpoint works with the password-masked sqlalchemy uri
  378. data = json.dumps(
  379. {
  380. "uri": database.safe_sqlalchemy_uri(),
  381. "name": "examples",
  382. "impersonate_user": False,
  383. }
  384. )
  385. response = self.client.post(
  386. "/superset/testconn", data=data, content_type="application/json"
  387. )
  388. assert response.status_code == 200
  389. assert response.headers["Content-Type"] == "application/json"
  390. # validate that the endpoint works with the decrypted sqlalchemy uri
  391. data = json.dumps(
  392. {
  393. "uri": database.sqlalchemy_uri_decrypted,
  394. "name": "examples",
  395. "impersonate_user": False,
  396. }
  397. )
  398. response = self.client.post(
  399. "/superset/testconn", data=data, content_type="application/json"
  400. )
  401. assert response.status_code == 200
  402. assert response.headers["Content-Type"] == "application/json"
  403. def test_testconn_failed_conn(self, username="admin"):
  404. self.login(username=username)
  405. data = json.dumps(
  406. {"uri": "broken://url", "name": "examples", "impersonate_user": False}
  407. )
  408. response = self.client.post(
  409. "/superset/testconn", data=data, content_type="application/json"
  410. )
  411. assert response.status_code == 400
  412. assert response.headers["Content-Type"] == "application/json"
  413. response_body = json.loads(response.data.decode("utf-8"))
  414. expected_body = {
  415. "error": "Connection failed!\n\nThe error message returned was:\nCan't load plugin: sqlalchemy.dialects:broken"
  416. }
  417. assert response_body == expected_body, "%s != %s" % (
  418. response_body,
  419. expected_body,
  420. )
  421. def test_custom_password_store(self):
  422. database = utils.get_example_database()
  423. conn_pre = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted)
  424. def custom_password_store(uri):
  425. return "password_store_test"
  426. models.custom_password_store = custom_password_store
  427. conn = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted)
  428. if conn_pre.password:
  429. assert conn.password == "password_store_test"
  430. assert conn.password != conn_pre.password
  431. # Disable for password store for later tests
  432. models.custom_password_store = None
  433. def test_databaseview_edit(self, username="admin"):
  434. # validate that sending a password-masked uri does not over-write the decrypted
  435. # uri
  436. self.login(username=username)
  437. database = utils.get_example_database()
  438. sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted
  439. url = "databaseview/edit/{}".format(database.id)
  440. data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns}
  441. data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri()
  442. self.client.post(url, data=data)
  443. database = utils.get_example_database()
  444. self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted)
  445. # Need to clean up after ourselves
  446. database.impersonate_user = False
  447. database.allow_dml = False
  448. database.allow_run_async = False
  449. db.session.commit()
  450. def test_warm_up_cache(self):
  451. slc = self.get_slice("Girls", db.session)
  452. data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id))
  453. self.assertEqual(data, [{"slice_id": slc.id, "slice_name": slc.slice_name}])
  454. data = self.get_json_resp(
  455. "/superset/warm_up_cache?table_name=energy_usage&db_name=main"
  456. )
  457. assert len(data) > 0
  458. def test_shortner(self):
  459. self.login(username="admin")
  460. data = (
  461. "//superset/explore/table/1/?viz_type=sankey&groupby=source&"
  462. "groupby=target&metric=sum__value&row_limit=5000&where=&having=&"
  463. "flt_col_0=source&flt_op_0=in&flt_eq_0=&slice_id=78&slice_name="
  464. "Energy+Sankey&collapsed_fieldsets=&action=&datasource_name="
  465. "energy_usage&datasource_id=1&datasource_type=table&"
  466. "previous_viz_type=sankey"
  467. )
  468. resp = self.client.post("/r/shortner/", data=dict(data=data))
  469. assert re.search(r"\/r\/[0-9]+", resp.data.decode("utf-8"))
  470. @skipUnless(
  471. (is_feature_enabled("KV_STORE")), "skipping as /kv/ endpoints are not enabled"
  472. )
  473. def test_kv(self):
  474. self.login(username="admin")
  475. resp = self.client.get("/kv/10001/")
  476. self.assertEqual(404, resp.status_code)
  477. value = json.dumps({"data": "this is a test"})
  478. resp = self.client.post("/kv/store/", data=dict(data=value))
  479. self.assertEqual(resp.status_code, 200)
  480. kv = db.session.query(models.KeyValue).first()
  481. kv_value = kv.value
  482. self.assertEqual(json.loads(value), json.loads(kv_value))
  483. resp = self.client.get("/kv/{}/".format(kv.id))
  484. self.assertEqual(resp.status_code, 200)
  485. self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8")))
  486. def test_gamma(self):
  487. self.login(username="gamma")
  488. assert "Charts" in self.get_resp("/chart/list/")
  489. assert "Dashboards" in self.get_resp("/dashboard/list/")
  490. def test_csv_endpoint(self):
  491. self.login("admin")
  492. sql = """
  493. SELECT name
  494. FROM birth_names
  495. WHERE name = 'James'
  496. LIMIT 1
  497. """
  498. client_id = "{}".format(random.getrandbits(64))[:10]
  499. self.run_sql(sql, client_id, raise_on_error=True)
  500. resp = self.get_resp("/superset/csv/{}".format(client_id))
  501. data = csv.reader(io.StringIO(resp))
  502. expected_data = csv.reader(io.StringIO("name\nJames\n"))
  503. client_id = "{}".format(random.getrandbits(64))[:10]
  504. self.run_sql(sql, client_id, raise_on_error=True)
  505. resp = self.get_resp("/superset/csv/{}".format(client_id))
  506. data = csv.reader(io.StringIO(resp))
  507. expected_data = csv.reader(io.StringIO("name\nJames\n"))
  508. self.assertEqual(list(expected_data), list(data))
  509. self.logout()
  510. def test_extra_table_metadata(self):
  511. self.login("admin")
  512. dbid = utils.get_example_database().id
  513. self.get_json_resp(
  514. f"/superset/extra_table_metadata/{dbid}/birth_names/superset/"
  515. )
  516. def test_process_template(self):
  517. maindb = utils.get_example_database()
  518. sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
  519. tp = jinja_context.get_template_processor(database=maindb)
  520. rendered = tp.process_template(sql)
  521. self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered)
  522. def test_get_template_kwarg(self):
  523. maindb = utils.get_example_database()
  524. s = "{{ foo }}"
  525. tp = jinja_context.get_template_processor(database=maindb, foo="bar")
  526. rendered = tp.process_template(s)
  527. self.assertEqual("bar", rendered)
  528. def test_template_kwarg(self):
  529. maindb = utils.get_example_database()
  530. s = "{{ foo }}"
  531. tp = jinja_context.get_template_processor(database=maindb)
  532. rendered = tp.process_template(s, foo="bar")
  533. self.assertEqual("bar", rendered)
  534. def test_templated_sql_json(self):
  535. self.login("admin")
  536. sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}' as test"
  537. data = self.run_sql(sql, "fdaklj3ws")
  538. self.assertEqual(data["data"][0]["test"], "2017-01-01T00:00:00")
  539. def test_fetch_datasource_metadata(self):
  540. self.login(username="admin")
  541. url = "/superset/fetch_datasource_metadata?" "datasourceKey=1__table"
  542. resp = self.get_json_resp(url)
  543. keys = [
  544. "name",
  545. "type",
  546. "order_by_choices",
  547. "granularity_sqla",
  548. "time_grain_sqla",
  549. "id",
  550. ]
  551. for k in keys:
  552. self.assertIn(k, resp.keys())
  553. def test_user_profile(self, username="admin"):
  554. self.login(username=username)
  555. slc = self.get_slice("Girls", db.session)
  556. # Setting some faves
  557. url = "/superset/favstar/Slice/{}/select/".format(slc.id)
  558. resp = self.get_json_resp(url)
  559. self.assertEqual(resp["count"], 1)
  560. dash = db.session.query(Dashboard).filter_by(slug="births").first()
  561. url = "/superset/favstar/Dashboard/{}/select/".format(dash.id)
  562. resp = self.get_json_resp(url)
  563. self.assertEqual(resp["count"], 1)
  564. userid = security_manager.find_user("admin").id
  565. resp = self.get_resp("/superset/profile/admin/")
  566. self.assertIn('"app"', resp)
  567. data = self.get_json_resp("/superset/recent_activity/{}/".format(userid))
  568. self.assertNotIn("message", data)
  569. data = self.get_json_resp("/superset/created_slices/{}/".format(userid))
  570. self.assertNotIn("message", data)
  571. data = self.get_json_resp("/superset/created_dashboards/{}/".format(userid))
  572. self.assertNotIn("message", data)
  573. data = self.get_json_resp("/superset/fave_slices/{}/".format(userid))
  574. self.assertNotIn("message", data)
  575. data = self.get_json_resp("/superset/fave_dashboards/{}/".format(userid))
  576. self.assertNotIn("message", data)
  577. data = self.get_json_resp(
  578. "/superset/fave_dashboards_by_username/{}/".format(username)
  579. )
  580. self.assertNotIn("message", data)
  581. def test_slice_id_is_always_logged_correctly_on_web_request(self):
  582. # superset/explore case
  583. slc = db.session.query(Slice).filter_by(slice_name="Girls").one()
  584. qry = db.session.query(models.Log).filter_by(slice_id=slc.id)
  585. self.get_resp(slc.slice_url, {"form_data": json.dumps(slc.form_data)})
  586. self.assertEqual(1, qry.count())
  587. def test_slice_id_is_always_logged_correctly_on_ajax_request(self):
  588. # superset/explore_json case
  589. self.login(username="admin")
  590. slc = db.session.query(Slice).filter_by(slice_name="Girls").one()
  591. qry = db.session.query(models.Log).filter_by(slice_id=slc.id)
  592. slc_url = slc.slice_url.replace("explore", "explore_json")
  593. self.get_json_resp(slc_url, {"form_data": json.dumps(slc.form_data)})
  594. self.assertEqual(1, qry.count())
  595. def test_slice_query_endpoint(self):
  596. # API endpoint for query string
  597. self.login(username="admin")
  598. slc = self.get_slice("Girls", db.session)
  599. resp = self.get_resp("/superset/slice_query/{}/".format(slc.id))
  600. assert "query" in resp
  601. assert "language" in resp
  602. self.logout()
  603. def test_import_csv(self):
  604. self.login(username="admin")
  605. table_name = "".join(random.choice(string.ascii_uppercase) for _ in range(5))
  606. filename_1 = "testCSV.csv"
  607. test_file_1 = open(filename_1, "w+")
  608. test_file_1.write("a,b\n")
  609. test_file_1.write("john,1\n")
  610. test_file_1.write("paul,2\n")
  611. test_file_1.close()
  612. filename_2 = "testCSV2.csv"
  613. test_file_2 = open(filename_2, "w+")
  614. test_file_2.write("b,c,d\n")
  615. test_file_2.write("john,1,x\n")
  616. test_file_2.write("paul,2,y\n")
  617. test_file_2.close()
  618. example_db = utils.get_example_database()
  619. example_db.allow_csv_upload = True
  620. db_id = example_db.id
  621. db.session.commit()
  622. form_data = {
  623. "csv_file": open(filename_1, "rb"),
  624. "sep": ",",
  625. "name": table_name,
  626. "con": db_id,
  627. "if_exists": "fail",
  628. "index_label": "test_label",
  629. "mangle_dupe_cols": False,
  630. }
  631. url = "/databaseview/list/"
  632. add_datasource_page = self.get_resp(url)
  633. self.assertIn("Upload a CSV", add_datasource_page)
  634. url = "/csvtodatabaseview/form"
  635. form_get = self.get_resp(url)
  636. self.assertIn("CSV to Database configuration", form_get)
  637. try:
  638. # initial upload with fail mode
  639. resp = self.get_resp(url, data=form_data)
  640. self.assertIn(
  641. f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp
  642. )
  643. # upload again with fail mode; should fail
  644. form_data["csv_file"] = open(filename_1, "rb")
  645. resp = self.get_resp(url, data=form_data)
  646. self.assertIn(
  647. f'Unable to upload CSV file "{filename_1}" to table "{table_name}"',
  648. resp,
  649. )
  650. # upload again with append mode
  651. form_data["csv_file"] = open(filename_1, "rb")
  652. form_data["if_exists"] = "append"
  653. resp = self.get_resp(url, data=form_data)
  654. self.assertIn(
  655. f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp
  656. )
  657. # upload again with replace mode
  658. form_data["csv_file"] = open(filename_1, "rb")
  659. form_data["if_exists"] = "replace"
  660. resp = self.get_resp(url, data=form_data)
  661. self.assertIn(
  662. f'CSV file "{filename_1}" uploaded to table "{table_name}"', resp
  663. )
  664. # try to append to table from file with different schema
  665. form_data["csv_file"] = open(filename_2, "rb")
  666. form_data["if_exists"] = "append"
  667. resp = self.get_resp(url, data=form_data)
  668. self.assertIn(
  669. f'Unable to upload CSV file "{filename_2}" to table "{table_name}"',
  670. resp,
  671. )
  672. # replace table from file with different schema
  673. form_data["csv_file"] = open(filename_2, "rb")
  674. form_data["if_exists"] = "replace"
  675. resp = self.get_resp(url, data=form_data)
  676. self.assertIn(
  677. f'CSV file "{filename_2}" uploaded to table "{table_name}"', resp
  678. )
  679. table = (
  680. db.session.query(SqlaTable)
  681. .filter_by(table_name=table_name, database_id=db_id)
  682. .first()
  683. )
  684. # make sure the new column name is reflected in the table metadata
  685. self.assertIn("d", table.column_names)
  686. finally:
  687. os.remove(filename_1)
  688. os.remove(filename_2)
  689. def test_dataframe_timezone(self):
  690. tz = pytz.FixedOffset(60)
  691. data = [
  692. (datetime.datetime(2017, 11, 18, 21, 53, 0, 219225, tzinfo=tz),),
  693. (datetime.datetime(2017, 11, 18, 22, 6, 30, tzinfo=tz),),
  694. ]
  695. results = SupersetResultSet(list(data), [["data"]], BaseEngineSpec)
  696. df = results.to_pandas_df()
  697. data = dataframe.df_to_records(df)
  698. json_str = json.dumps(data, default=utils.pessimistic_json_iso_dttm_ser)
  699. self.assertDictEqual(
  700. data[0], {"data": pd.Timestamp("2017-11-18 21:53:00.219225+0100", tz=tz)}
  701. )
  702. self.assertDictEqual(
  703. data[1], {"data": pd.Timestamp("2017-11-18 22:06:30+0100", tz=tz)}
  704. )
  705. self.assertEqual(
  706. json_str,
  707. '[{"data": "2017-11-18T21:53:00.219225+01:00"}, {"data": "2017-11-18T22:06:30+01:00"}]',
  708. )
  709. def test_mssql_engine_spec_pymssql(self):
  710. # Test for case when tuple is returned (pymssql)
  711. data = [
  712. (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)),
  713. (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)),
  714. ]
  715. results = SupersetResultSet(
  716. list(data), [["col1"], ["col2"], ["col3"]], MssqlEngineSpec
  717. )
  718. df = results.to_pandas_df()
  719. data = dataframe.df_to_records(df)
  720. self.assertEqual(len(data), 2)
  721. self.assertEqual(
  722. data[0],
  723. {"col1": 1, "col2": 1, "col3": pd.Timestamp("2017-10-19 23:39:16.660000")},
  724. )
  725. def test_comments_in_sqlatable_query(self):
  726. clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl"
  727. commented_query = "/* comment 1 */" + clean_query + "-- comment 2"
  728. table = SqlaTable(
  729. table_name="test_comments_in_sqlatable_query_table", sql=commented_query
  730. )
  731. rendered_query = str(table.get_from_clause())
  732. self.assertEqual(clean_query, rendered_query)
  733. def test_slice_payload_no_data(self):
  734. self.login(username="admin")
  735. slc = self.get_slice("Girls", db.session)
  736. json_endpoint = "/superset/explore_json/"
  737. form_data = slc.form_data
  738. form_data.update(
  739. {
  740. "adhoc_filters": [
  741. {
  742. "clause": "WHERE",
  743. "comparator": "NA",
  744. "expressionType": "SIMPLE",
  745. "operator": "==",
  746. "subject": "gender",
  747. }
  748. ]
  749. }
  750. )
  751. data = self.get_json_resp(json_endpoint, {"form_data": json.dumps(form_data)})
  752. self.assertEqual(data["status"], utils.QueryStatus.SUCCESS)
  753. self.assertEqual(data["error"], "No data")
  754. def test_slice_payload_invalid_query(self):
  755. self.login(username="admin")
  756. slc = self.get_slice("Girls", db.session)
  757. form_data = slc.form_data
  758. form_data.update({"groupby": ["N/A"]})
  759. data = self.get_json_resp(
  760. "/superset/explore_json/", {"form_data": json.dumps(form_data)}
  761. )
  762. self.assertEqual(data["status"], utils.QueryStatus.FAILED)
  763. def test_slice_payload_no_datasource(self):
  764. self.login(username="admin")
  765. data = self.get_json_resp("/superset/explore_json/", raise_on_error=False)
  766. self.assertEqual(
  767. data["error"], "The datasource associated with this chart no longer exists"
  768. )
  769. @mock.patch("superset.security.SupersetSecurityManager.schemas_accessible_by_user")
  770. @mock.patch("superset.security.SupersetSecurityManager.database_access")
  771. @mock.patch("superset.security.SupersetSecurityManager.all_datasource_access")
  772. def test_schemas_access_for_csv_upload_endpoint(
  773. self, mock_all_datasource_access, mock_database_access, mock_schemas_accessible
  774. ):
  775. self.login(username="admin")
  776. dbobj = self.create_fake_db()
  777. mock_all_datasource_access.return_value = False
  778. mock_database_access.return_value = False
  779. mock_schemas_accessible.return_value = ["this_schema_is_allowed_too"]
  780. data = self.get_json_resp(
  781. url="/superset/schemas_access_for_csv_upload?db_id={db_id}".format(
  782. db_id=dbobj.id
  783. )
  784. )
  785. assert data == ["this_schema_is_allowed_too"]
  786. self.delete_fake_db()
  787. def test_select_star(self):
  788. self.login(username="admin")
  789. examples_db = utils.get_example_database()
  790. resp = self.get_resp(f"/superset/select_star/{examples_db.id}/birth_names")
  791. self.assertIn("gender", resp)
  792. def test_get_select_star_not_allowed(self):
  793. """
  794. Database API: Test get select star not allowed
  795. """
  796. self.login(username="gamma")
  797. example_db = utils.get_example_database()
  798. resp = self.client.get(f"/superset/select_star/{example_db.id}/birth_names")
  799. self.assertEqual(resp.status_code, 404)
  800. @mock.patch("superset.views.core.results_backend_use_msgpack", False)
  801. @mock.patch("superset.views.core.results_backend")
  802. @mock.patch("superset.views.core.db")
  803. def test_display_limit(self, mock_superset_db, mock_results_backend):
  804. query_mock = mock.Mock()
  805. query_mock.sql = "SELECT *"
  806. query_mock.database = 1
  807. query_mock.schema = "superset"
  808. mock_superset_db.session.query().filter_by().one_or_none.return_value = (
  809. query_mock
  810. )
  811. data = [{"col_0": i} for i in range(100)]
  812. payload = {
  813. "status": utils.QueryStatus.SUCCESS,
  814. "query": {"rows": 100},
  815. "data": data,
  816. }
  817. # do not apply msgpack serialization
  818. use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"]
  819. app.config["RESULTS_BACKEND_USE_MSGPACK"] = False
  820. serialized_payload = sql_lab._serialize_payload(payload, False)
  821. compressed = utils.zlib_compress(serialized_payload)
  822. mock_results_backend.get.return_value = compressed
  823. # get all results
  824. result = json.loads(self.get_resp("/superset/results/key/"))
  825. expected = {"status": "success", "query": {"rows": 100}, "data": data}
  826. self.assertEqual(result, expected)
  827. # limit results to 1
  828. limited_data = data[:1]
  829. result = json.loads(self.get_resp("/superset/results/key/?rows=1"))
  830. expected = {
  831. "status": "success",
  832. "query": {"rows": 100},
  833. "data": limited_data,
  834. "displayLimitReached": True,
  835. }
  836. self.assertEqual(result, expected)
  837. app.config["RESULTS_BACKEND_USE_MSGPACK"] = use_msgpack
  838. def test_results_default_deserialization(self):
  839. use_new_deserialization = False
  840. data = [("a", 4, 4.0, "2019-08-18T16:39:16.660000")]
  841. cursor_descr = (
  842. ("a", "string"),
  843. ("b", "int"),
  844. ("c", "float"),
  845. ("d", "datetime"),
  846. )
  847. db_engine_spec = BaseEngineSpec()
  848. results = SupersetResultSet(data, cursor_descr, db_engine_spec)
  849. query = {
  850. "database_id": 1,
  851. "sql": "SELECT * FROM birth_names LIMIT 100",
  852. "status": utils.QueryStatus.PENDING,
  853. }
  854. serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
  855. results, db_engine_spec, use_new_deserialization
  856. )
  857. payload = {
  858. "query_id": 1,
  859. "status": utils.QueryStatus.SUCCESS,
  860. "state": utils.QueryStatus.SUCCESS,
  861. "data": serialized_data,
  862. "columns": all_columns,
  863. "selected_columns": selected_columns,
  864. "expanded_columns": expanded_columns,
  865. "query": query,
  866. }
  867. serialized_payload = sql_lab._serialize_payload(
  868. payload, use_new_deserialization
  869. )
  870. self.assertIsInstance(serialized_payload, str)
  871. query_mock = mock.Mock()
  872. deserialized_payload = views._deserialize_results_payload(
  873. serialized_payload, query_mock, use_new_deserialization
  874. )
  875. self.assertDictEqual(deserialized_payload, payload)
  876. query_mock.assert_not_called()
    def test_results_msgpack_deserialization(self):
        """Round-trip a results payload through the msgpack serializer.

        With the new deserialization enabled the serialized payload is
        ``bytes``, and deserializing re-expands the data through the engine
        spec's ``expand_data`` exactly once.
        """
        use_new_deserialization = True
        data = [("a", 4, 4.0, "2019-08-18T16:39:16.660000")]
        cursor_descr = (
            ("a", "string"),
            ("b", "int"),
            ("c", "float"),
            ("d", "datetime"),
        )
        db_engine_spec = BaseEngineSpec()
        results = SupersetResultSet(data, cursor_descr, db_engine_spec)
        query = {
            "database_id": 1,
            "sql": "SELECT * FROM birth_names LIMIT 100",
            "status": utils.QueryStatus.PENDING,
        }
        serialized_data, selected_columns, all_columns, expanded_columns = sql_lab._serialize_and_expand_data(
            results, db_engine_spec, use_new_deserialization
        )
        payload = {
            "query_id": 1,
            "status": utils.QueryStatus.SUCCESS,
            "state": utils.QueryStatus.SUCCESS,
            "data": serialized_data,
            "columns": all_columns,
            "selected_columns": selected_columns,
            "expanded_columns": expanded_columns,
            "query": query,
        }
        serialized_payload = sql_lab._serialize_payload(
            payload, use_new_deserialization
        )
        # msgpack serializes to raw bytes, unlike the JSON path's str.
        self.assertIsInstance(serialized_payload, bytes)
        # Wrap expand_data so it still runs for real but calls can be counted.
        with mock.patch.object(
            db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
        ) as expand_data:
            query_mock = mock.Mock()
            query_mock.database.db_engine_spec.expand_data = expand_data
            deserialized_payload = views._deserialize_results_payload(
                serialized_payload, query_mock, use_new_deserialization
            )
            # Deserialization rebuilds records from the dataframe, so compare
            # against freshly re-derived records rather than serialized_data.
            df = results.to_pandas_df()
            payload["data"] = dataframe.df_to_records(df)
            self.assertDictEqual(deserialized_payload, payload)
            expand_data.assert_called_once()
  922. @mock.patch.dict(
  923. "superset.extensions.feature_flag_manager._feature_flags",
  924. {"FOO": lambda x: 1},
  925. clear=True,
  926. )
  927. def test_feature_flag_serialization(self):
  928. """
  929. Functions in feature flags don't break bootstrap data serialization.
  930. """
  931. self.login()
  932. encoded = json.dumps(
  933. {"FOO": lambda x: 1, "super": "set"},
  934. default=utils.pessimistic_json_iso_dttm_ser,
  935. )
  936. html = cgi.escape(encoded).replace("'", "&#39;").replace('"', "&#34;")
  937. urls = [
  938. "/superset/sqllab",
  939. "/superset/welcome",
  940. "/superset/dashboard/1/",
  941. "/superset/profile/admin/",
  942. "/superset/explore/table/1",
  943. ]
  944. for url in urls:
  945. data = self.get_resp(url)
  946. self.assertTrue(html in data)
    @mock.patch.dict(
        "superset.extensions.feature_flag_manager._feature_flags",
        {"SQLLAB_BACKEND_PERSISTENCE": True},
        clear=True,
    )
    def test_sqllab_backend_persistence_payload(self):
        """With backend persistence enabled, the SQL Lab bootstrap payload
        includes only queries associated with one of the user's saved tabs."""
        username = "admin"
        self.login(username)
        user_id = security_manager.find_user(username).id
        # create a tab
        data = {
            "queryEditor": json.dumps(
                {
                    "title": "Untitled Query 1",
                    "dbId": 1,
                    "schema": None,
                    "autorun": False,
                    "sql": "SELECT ...",
                    "queryLimit": 1000,
                }
            )
        }
        resp = self.get_json_resp("/tabstateview/", data=data)
        tab_state_id = resp["id"]
        # run a query in the created tab
        self.run_sql(
            "SELECT name FROM birth_names",
            "client_id_1",
            user_name=username,
            raise_on_error=True,
            sql_editor_id=tab_state_id,
        )
        # run an orphan query (no tab)
        self.run_sql(
            "SELECT name FROM birth_names",
            "client_id_2",
            user_name=username,
            raise_on_error=True,
        )
        # we should have only 1 query returned, since the second one is not
        # associated with any tabs
        payload = views.Superset._get_sqllab_payload(user_id=user_id)
        self.assertEqual(len(payload["queries"]), 1)
if __name__ == "__main__":
    # Allow running this test module directly, e.g. `python core_tests.py`.
    unittest.main()