dict_import_export_tests.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
import unittest

import yaml

from tests.test_app import app
from superset import db
from superset.connectors.druid.models import (
    DruidColumn,
    DruidDatasource,
    DruidMetric,
    DruidCluster,
)
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.utils.core import get_example_database
from superset.utils.dict_import_export import export_to_dict

from .base_tests import SupersetTestCase

# DBREF is written into the params of every object these tests create so that
# delete_imports() can find and remove them during class set-up and tear-down.
DBREF = "dict_import__export_test"
NAME_PREFIX = "dict_"
ID_PREFIX = 20000


class DictImportExportTests(SupersetTestCase):
    """Testing export/import functionality for SQL tables and Druid datasources"""
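
    # Every test below follows the same round trip: build an in-memory
    # datasource together with the dict representation understood by
    # ``import_from_dict``, import that dict, commit, and compare the
    # persisted object against the original. A minimal sketch of the flow
    # (the name and id offset are illustrative only):
    #
    #     table, dict_rep = self.create_table("example", id=ID_PREFIX + 99)
    #     imported = SqlaTable.import_from_dict(db.session, dict_rep)
    #     db.session.commit()
    #     self.assert_table_equals(table, self.get_table(imported.id))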

    def __init__(self, *args, **kwargs):
        super(DictImportExportTests, self).__init__(*args, **kwargs)

    @classmethod
    def delete_imports(cls):
        with app.app_context():
            # Imported data clean up
            session = db.session
            for table in session.query(SqlaTable):
                if DBREF in table.params_dict:
                    session.delete(table)
            for datasource in session.query(DruidDatasource):
                if DBREF in datasource.params_dict:
                    session.delete(datasource)
            session.commit()

    @classmethod
    def setUpClass(cls):
        cls.delete_imports()

    @classmethod
    def tearDownClass(cls):
        cls.delete_imports()

    def create_table(self, name, schema="", id=0, cols_names=[], metric_names=[]):
        database_name = "main"
        name = "{0}{1}".format(NAME_PREFIX, name)
        params = {DBREF: id, "database_name": database_name}
        dict_rep = {
            "database_id": get_example_database().id,
            "table_name": name,
            "schema": schema,
            "id": id,
            "params": json.dumps(params),
            "columns": [{"column_name": c} for c in cols_names],
            "metrics": [{"metric_name": c, "expression": ""} for c in metric_names],
        }
        table = SqlaTable(
            id=id, schema=schema, table_name=name, params=json.dumps(params)
        )
        for col_name in cols_names:
            table.columns.append(TableColumn(column_name=col_name))
        for metric_name in metric_names:
            table.metrics.append(SqlMetric(metric_name=metric_name, expression=""))
        return table, dict_rep

    def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
        cluster_name = "druid_test"
        cluster = self.get_or_create(
            DruidCluster, {"cluster_name": cluster_name}, db.session
        )
        name = "{0}{1}".format(NAME_PREFIX, name)
        params = {DBREF: id, "database_name": cluster_name}
        dict_rep = {
            "cluster_id": cluster.id,
            "datasource_name": name,
            "id": id,
            "params": json.dumps(params),
            "columns": [{"column_name": c} for c in cols_names],
            "metrics": [{"metric_name": c, "json": "{}"} for c in metric_names],
        }
        datasource = DruidDatasource(
            id=id,
            datasource_name=name,
            cluster_id=cluster.id,
            params=json.dumps(params),
        )
        for col_name in cols_names:
            datasource.columns.append(DruidColumn(column_name=col_name))
        for metric_name in metric_names:
            datasource.metrics.append(DruidMetric(metric_name=metric_name))
        return datasource, dict_rep

    def get_datasource(self, datasource_id):
        return db.session.query(DruidDatasource).filter_by(id=datasource_id).first()

    def get_table(self, table_id):
        return db.session.query(SqlaTable).filter_by(id=table_id).first()

    def get_table_by_name(self, name):
        return db.session.query(SqlaTable).filter_by(table_name=name).first()

    def yaml_compare(self, obj_1, obj_2):
        obj_1_str = yaml.safe_dump(obj_1, default_flow_style=False)
        obj_2_str = yaml.safe_dump(obj_2, default_flow_style=False)
        self.assertEqual(obj_1_str, obj_2_str)

    def assert_table_equals(self, expected_ds, actual_ds):
        self.assertEqual(expected_ds.table_name, actual_ds.table_name)
        self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
        self.assertEqual(expected_ds.schema, actual_ds.schema)
        self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
        self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
        self.assertEqual(
            set([c.column_name for c in expected_ds.columns]),
            set([c.column_name for c in actual_ds.columns]),
        )
        self.assertEqual(
            set([m.metric_name for m in expected_ds.metrics]),
            set([m.metric_name for m in actual_ds.metrics]),
        )

    def assert_datasource_equals(self, expected_ds, actual_ds):
        self.assertEqual(expected_ds.datasource_name, actual_ds.datasource_name)
        self.assertEqual(expected_ds.main_dttm_col, actual_ds.main_dttm_col)
        self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
        self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
        self.assertEqual(
            set([c.column_name for c in expected_ds.columns]),
            set([c.column_name for c in actual_ds.columns]),
        )
        self.assertEqual(
            set([m.metric_name for m in expected_ds.metrics]),
            set([m.metric_name for m in actual_ds.metrics]),
        )

    def test_import_table_no_metadata(self):
        table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1)
        new_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        imported_id = new_table.id
        imported = self.get_table(imported_id)
        self.assert_table_equals(table, imported)
        self.yaml_compare(table.export_to_dict(), imported.export_to_dict())

    def test_import_table_1_col_1_met(self):
        table, dict_table = self.create_table(
            "table_1_col_1_met",
            id=ID_PREFIX + 2,
            cols_names=["col1"],
            metric_names=["metric1"],
        )
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        imported = self.get_table(imported_table.id)
        self.assert_table_equals(table, imported)
        self.assertEqual(
            {DBREF: ID_PREFIX + 2, "database_name": "main"}, json.loads(imported.params)
        )
        self.yaml_compare(table.export_to_dict(), imported.export_to_dict())

    def test_import_table_2_col_2_met(self):
        table, dict_table = self.create_table(
            "table_2_col_2_met",
            id=ID_PREFIX + 3,
            cols_names=["c1", "c2"],
            metric_names=["m1", "m2"],
        )
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        imported = self.get_table(imported_table.id)
        self.assert_table_equals(table, imported)
        self.yaml_compare(table.export_to_dict(), imported.export_to_dict())

    def test_import_table_override_append(self):
        table, dict_table = self.create_table(
            "table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
        )
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        table_over, dict_table_over = self.create_table(
            "table_override",
            id=ID_PREFIX + 3,
            cols_names=["new_col1", "col2", "col3"],
            metric_names=["new_metric1"],
        )
        imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over)
        db.session.commit()
        imported_over = self.get_table(imported_over_table.id)
        self.assertEqual(imported_table.id, imported_over.id)
        # Without sync, re-importing appends new columns/metrics to existing ones.
        expected_table, _ = self.create_table(
            "table_override",
            id=ID_PREFIX + 3,
            metric_names=["new_metric1", "m1"],
            cols_names=["col1", "new_col1", "col2", "col3"],
        )
        self.assert_table_equals(expected_table, imported_over)
        self.yaml_compare(
            expected_table.export_to_dict(), imported_over.export_to_dict()
        )

    def test_import_table_override_sync(self):
        table, dict_table = self.create_table(
            "table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
        )
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        table_over, dict_table_over = self.create_table(
            "table_override",
            id=ID_PREFIX + 3,
            cols_names=["new_col1", "col2", "col3"],
            metric_names=["new_metric1"],
        )
        imported_over_table = SqlaTable.import_from_dict(
            session=db.session, dict_rep=dict_table_over, sync=["metrics", "columns"]
        )
        db.session.commit()
        imported_over = self.get_table(imported_over_table.id)
        self.assertEqual(imported_table.id, imported_over.id)
        # With sync, columns/metrics missing from the dict are dropped rather
        # than kept, so only the new names should survive.
        expected_table, _ = self.create_table(
            "table_override",
            id=ID_PREFIX + 3,
            metric_names=["new_metric1"],
            cols_names=["new_col1", "col2", "col3"],
        )
        self.assert_table_equals(expected_table, imported_over)
        self.yaml_compare(
            expected_table.export_to_dict(), imported_over.export_to_dict()
        )

    def test_import_table_override_identical(self):
        table, dict_table = self.create_table(
            "copy_cat",
            id=ID_PREFIX + 4,
            cols_names=["new_col1", "col2", "col3"],
            metric_names=["new_metric1"],
        )
        imported_table = SqlaTable.import_from_dict(db.session, dict_table)
        db.session.commit()
        copy_table, dict_copy_table = self.create_table(
            "copy_cat",
            id=ID_PREFIX + 4,
            cols_names=["new_col1", "col2", "col3"],
            metric_names=["new_metric1"],
        )
        imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
        db.session.commit()
        self.assertEqual(imported_table.id, imported_copy_table.id)
        self.assert_table_equals(copy_table, self.get_table(imported_table.id))
        self.yaml_compare(
            imported_copy_table.export_to_dict(), imported_table.export_to_dict()
        )

    def test_export_datasource_ui_cli(self):
        cli_export = export_to_dict(
            session=db.session,
            recursive=True,
            back_references=False,
            include_defaults=False,
        )
        self.get_resp("/login/", data=dict(username="admin", password="general"))
        resp = self.get_resp(
            "/databaseview/action_post", {"action": "yaml_export", "rowid": 1}
        )
        ui_export = yaml.safe_load(resp)
        self.assertEqual(
            ui_export["databases"][0]["database_name"],
            cli_export["databases"][0]["database_name"],
        )
        self.assertEqual(
            ui_export["databases"][0]["tables"], cli_export["databases"][0]["tables"]
        )

    def test_import_druid_no_metadata(self):
        datasource, dict_datasource = self.create_druid_datasource(
            "pure_druid", id=ID_PREFIX + 1
        )
        imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
        db.session.commit()
        imported = self.get_datasource(imported_cluster.id)
        self.assert_datasource_equals(datasource, imported)

    def test_import_druid_1_col_1_met(self):
        datasource, dict_datasource = self.create_druid_datasource(
            "druid_1_col_1_met",
            id=ID_PREFIX + 2,
            cols_names=["col1"],
            metric_names=["metric1"],
        )
        imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
        db.session.commit()
        imported = self.get_datasource(imported_cluster.id)
        self.assert_datasource_equals(datasource, imported)
        self.assertEqual(
            {DBREF: ID_PREFIX + 2, "database_name": "druid_test"},
            json.loads(imported.params),
        )

    def test_import_druid_2_col_2_met(self):
        datasource, dict_datasource = self.create_druid_datasource(
            "druid_2_col_2_met",
            id=ID_PREFIX + 3,
            cols_names=["c1", "c2"],
            metric_names=["m1", "m2"],
        )
        imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
        db.session.commit()
        imported = self.get_datasource(imported_cluster.id)
        self.assert_datasource_equals(datasource, imported)

    def test_import_druid_override_append(self):
        datasource, dict_datasource = self.create_druid_datasource(
            "druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
        )
        imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
        db.session.commit()
        table_over, table_over_dict = self.create_druid_datasource(
            "druid_override",
            id=ID_PREFIX + 3,
            cols_names=["new_col1", "col2", "col3"],
            metric_names=["new_metric1"],
        )
        imported_over_cluster = DruidDatasource.import_from_dict(
            db.session, table_over_dict
        )
        db.session.commit()
        imported_over = self.get_datasource(imported_over_cluster.id)
        self.assertEqual(imported_cluster.id, imported_over.id)
        expected_datasource, _ = self.create_druid_datasource(
            "druid_override",
            id=ID_PREFIX + 3,
            metric_names=["new_metric1", "m1"],
            cols_names=["col1", "new_col1", "col2", "col3"],
        )
        self.assert_datasource_equals(expected_datasource, imported_over)

    def test_import_druid_override_sync(self):
        datasource, dict_datasource = self.create_druid_datasource(
            "druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
        )
        imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
        db.session.commit()
        table_over, table_over_dict = self.create_druid_datasource(
            "druid_override",
            id=ID_PREFIX + 3,
            cols_names=["new_col1", "col2", "col3"],
            metric_names=["new_metric1"],
        )
        imported_over_cluster = DruidDatasource.import_from_dict(
            session=db.session, dict_rep=table_over_dict, sync=["metrics", "columns"]
        )  # syncing metrics and columns
        db.session.commit()
        imported_over = self.get_datasource(imported_over_cluster.id)
        self.assertEqual(imported_cluster.id, imported_over.id)
        expected_datasource, _ = self.create_druid_datasource(
            "druid_override",
            id=ID_PREFIX + 3,
            metric_names=["new_metric1"],
            cols_names=["new_col1", "col2", "col3"],
        )
        self.assert_datasource_equals(expected_datasource, imported_over)

    def test_import_druid_override_identical(self):
        datasource, dict_datasource = self.create_druid_datasource(
            "copy_cat",
            id=ID_PREFIX + 4,
            cols_names=["new_col1", "col2", "col3"],
            metric_names=["new_metric1"],
        )
        imported = DruidDatasource.import_from_dict(
            session=db.session, dict_rep=dict_datasource
        )
        db.session.commit()
        copy_datasource, dict_cp_datasource = self.create_druid_datasource(
            "copy_cat",
            id=ID_PREFIX + 4,
            cols_names=["new_col1", "col2", "col3"],
            metric_names=["new_metric1"],
        )
        imported_copy = DruidDatasource.import_from_dict(db.session, dict_cp_datasource)
        db.session.commit()
        self.assertEqual(imported.id, imported_copy.id)
        self.assert_datasource_equals(copy_datasource, self.get_datasource(imported.id))


if __name__ == "__main__":
    unittest.main()