base_tests.py 10 KB


# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import imp
import importlib.util
import json
from typing import Union
from unittest.mock import Mock

import pandas as pd
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase

from tests.test_app import app  # isort:skip
from superset import db, security_manager
from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.slice import Slice
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.utils.core import get_example_database
  36. FAKE_DB_NAME = "fake_db_100"
  37. class SupersetTestCase(TestCase):
  38. default_schema_backend_map = {
  39. "sqlite": "main",
  40. "mysql": "superset",
  41. "postgresql": "public",
  42. }
  43. def __init__(self, *args, **kwargs):
  44. super(SupersetTestCase, self).__init__(*args, **kwargs)
  45. self.maxDiff = None
  46. def create_app(self):
  47. return app
  48. @staticmethod
  49. def create_user(
  50. username: str,
  51. password: str,
  52. role_name: str,
  53. first_name: str = "admin",
  54. last_name: str = "user",
  55. email: str = "admin@fab.org",
  56. ) -> Union[ab_models.User, bool]:
  57. role_admin = security_manager.find_role(role_name)
  58. return security_manager.add_user(
  59. username, first_name, last_name, email, role_admin, password
  60. )
  61. @staticmethod
  62. def get_user(username: str) -> ab_models.User:
  63. user = (
  64. db.session.query(security_manager.user_model)
  65. .filter_by(username=username)
  66. .one_or_none()
  67. )
  68. return user
  69. @classmethod
  70. def create_druid_test_objects(cls):
  71. # create druid cluster and druid datasources
  72. with app.app_context():
  73. session = db.session
  74. cluster = (
  75. session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
  76. )
  77. if not cluster:
  78. cluster = DruidCluster(cluster_name="druid_test")
  79. session.add(cluster)
  80. session.commit()
  81. druid_datasource1 = DruidDatasource(
  82. datasource_name="druid_ds_1", cluster=cluster
  83. )
  84. session.add(druid_datasource1)
  85. druid_datasource2 = DruidDatasource(
  86. datasource_name="druid_ds_2", cluster=cluster
  87. )
  88. session.add(druid_datasource2)
  89. session.commit()
  90. def get_table(self, table_id):
  91. return db.session.query(SqlaTable).filter_by(id=table_id).one()
  92. @staticmethod
  93. def is_module_installed(module_name):
  94. try:
  95. imp.find_module(module_name)
  96. return True
  97. except ImportError:
  98. return False
  99. def get_or_create(self, cls, criteria, session, **kwargs):
  100. obj = session.query(cls).filter_by(**criteria).first()
  101. if not obj:
  102. obj = cls(**criteria)
  103. obj.__dict__.update(**kwargs)
  104. session.add(obj)
  105. session.commit()
  106. return obj
  107. def login(self, username="admin", password="general"):
  108. resp = self.get_resp("/login/", data=dict(username=username, password=password))
  109. self.assertNotIn("User confirmation needed", resp)
  110. def get_slice(self, slice_name, session):
  111. slc = session.query(Slice).filter_by(slice_name=slice_name).one()
  112. session.expunge_all()
  113. return slc
  114. def get_table_by_name(self, name):
  115. return db.session.query(SqlaTable).filter_by(table_name=name).one()
  116. def get_database_by_id(self, db_id):
  117. return db.session.query(Database).filter_by(id=db_id).one()
  118. def get_druid_ds_by_name(self, name):
  119. return db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
  120. def get_datasource_mock(self):
  121. datasource = Mock()
  122. results = Mock()
  123. results.query = Mock()
  124. results.status = Mock()
  125. results.error_message = None
  126. results.df = pd.DataFrame()
  127. datasource.type = "table"
  128. datasource.query = Mock(return_value=results)
  129. mock_dttm_col = Mock()
  130. datasource.get_col = Mock(return_value=mock_dttm_col)
  131. datasource.query = Mock(return_value=results)
  132. datasource.database = Mock()
  133. datasource.database.db_engine_spec = Mock()
  134. datasource.database.db_engine_spec.mutate_expression_label = lambda x: x
  135. return datasource
  136. def get_resp(
  137. self, url, data=None, follow_redirects=True, raise_on_error=True, json_=None
  138. ):
  139. """Shortcut to get the parsed results while following redirects"""
  140. if data:
  141. resp = self.client.post(url, data=data, follow_redirects=follow_redirects)
  142. elif json_:
  143. resp = self.client.post(url, json=json_, follow_redirects=follow_redirects)
  144. else:
  145. resp = self.client.get(url, follow_redirects=follow_redirects)
  146. if raise_on_error and resp.status_code > 400:
  147. raise Exception("http request failed with code {}".format(resp.status_code))
  148. return resp.data.decode("utf-8")
  149. def get_json_resp(
  150. self, url, data=None, follow_redirects=True, raise_on_error=True, json_=None
  151. ):
  152. """Shortcut to get the parsed results while following redirects"""
  153. resp = self.get_resp(url, data, follow_redirects, raise_on_error, json_)
  154. return json.loads(resp)
  155. def get_access_requests(self, username, ds_type, ds_id):
  156. DAR = DatasourceAccessRequest
  157. return (
  158. db.session.query(DAR)
  159. .filter(
  160. DAR.created_by == security_manager.find_user(username=username),
  161. DAR.datasource_type == ds_type,
  162. DAR.datasource_id == ds_id,
  163. )
  164. .first()
  165. )
  166. def logout(self):
  167. self.client.get("/logout/", follow_redirects=True)
  168. def grant_public_access_to_table(self, table):
  169. public_role = security_manager.find_role("Public")
  170. perms = db.session.query(ab_models.PermissionView).all()
  171. for perm in perms:
  172. if (
  173. perm.permission.name == "datasource_access"
  174. and perm.view_menu
  175. and table.perm in perm.view_menu.name
  176. ):
  177. security_manager.add_permission_role(public_role, perm)
  178. def revoke_public_access_to_table(self, table):
  179. public_role = security_manager.find_role("Public")
  180. perms = db.session.query(ab_models.PermissionView).all()
  181. for perm in perms:
  182. if (
  183. perm.permission.name == "datasource_access"
  184. and perm.view_menu
  185. and table.perm in perm.view_menu.name
  186. ):
  187. security_manager.del_permission_role(public_role, perm)
  188. def _get_database_by_name(self, database_name="main"):
  189. if database_name == "examples":
  190. return get_example_database()
  191. else:
  192. raise ValueError("Database doesn't exist")
  193. def run_sql(
  194. self,
  195. sql,
  196. client_id=None,
  197. user_name=None,
  198. raise_on_error=False,
  199. query_limit=None,
  200. database_name="examples",
  201. sql_editor_id=None,
  202. ):
  203. if user_name:
  204. self.logout()
  205. self.login(username=(user_name or "admin"))
  206. dbid = self._get_database_by_name(database_name).id
  207. resp = self.get_json_resp(
  208. "/superset/sql_json/",
  209. raise_on_error=False,
  210. json_=dict(
  211. database_id=dbid,
  212. sql=sql,
  213. select_as_create_as=False,
  214. client_id=client_id,
  215. queryLimit=query_limit,
  216. sql_editor_id=sql_editor_id,
  217. ),
  218. )
  219. if raise_on_error and "error" in resp:
  220. raise Exception("run_sql failed")
  221. return resp
  222. def create_fake_db(self):
  223. self.login(username="admin")
  224. database_name = FAKE_DB_NAME
  225. db_id = 100
  226. extra = """{
  227. "schemas_allowed_for_csv_upload":
  228. ["this_schema_is_allowed", "this_schema_is_allowed_too"]
  229. }"""
  230. return self.get_or_create(
  231. cls=models.Database,
  232. criteria={"database_name": database_name},
  233. session=db.session,
  234. sqlalchemy_uri="sqlite://test",
  235. id=db_id,
  236. extra=extra,
  237. )
  238. def delete_fake_db(self):
  239. database = (
  240. db.session.query(Database)
  241. .filter(Database.database_name == FAKE_DB_NAME)
  242. .scalar()
  243. )
  244. if database:
  245. db.session.delete(database)
  246. def validate_sql(
  247. self,
  248. sql,
  249. client_id=None,
  250. user_name=None,
  251. raise_on_error=False,
  252. database_name="examples",
  253. ):
  254. if user_name:
  255. self.logout()
  256. self.login(username=(user_name if user_name else "admin"))
  257. dbid = self._get_database_by_name(database_name).id
  258. resp = self.get_json_resp(
  259. "/superset/validate_sql_json/",
  260. raise_on_error=False,
  261. data=dict(database_id=dbid, sql=sql, client_id=client_id),
  262. )
  263. if raise_on_error and "error" in resp:
  264. raise Exception("validate_sql failed")
  265. return resp
  266. def get_dash_by_slug(self, dash_slug):
  267. sesh = db.session()
  268. return sesh.query(Dashboard).filter_by(slug=dash_slug).first()