# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import json
import unittest
from unittest.mock import Mock

import tests.test_app  # noqa: F401 (imported for its side effect of setting up the test app)
import superset.connectors.druid.models as models
from superset.connectors.druid.models import DruidColumn, DruidDatasource, DruidMetric
from superset.exceptions import SupersetException

from .base_tests import SupersetTestCase

try:
    from pydruid.utils.dimensions import (
        MapLookupExtraction,
        RegexExtraction,
        RegisteredLookupExtraction,
        TimeFormatExtraction,
    )
    import pydruid.utils.postaggregator as postaggs
except ImportError:
    pass
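
# pydruid is an optional dependency: if the import fails, the names above stay
# undefined and every test below is skipped via its @unittest.skipUnless guard.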


def mock_metric(metric_name, is_postagg=False):
    metric = Mock()
    metric.metric_name = metric_name
    metric.metric_type = "postagg" if is_postagg else "metric"
    return metric


def emplace(metrics_dict, metric_name, is_postagg=False):
    metrics_dict[metric_name] = mock_metric(metric_name, is_postagg)
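
# Illustrative usage (not executed here): emplace(d, "m1") registers a mock
# metric with metric_type "metric"; emplace(d, "pa1", True) registers one with
# metric_type "postagg", which is how the tests below tell the two apart.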


# Unit tests that can be run without initializing base tests
class DruidFuncTestCase(SupersetTestCase):
    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_extraction_fn_map(self):
        filters = [{"col": "deviceName", "val": ["iPhone X"], "op": "in"}]
        dimension_spec = {
            "type": "extraction",
            "dimension": "device",
            "outputName": "deviceName",
            "outputType": "STRING",
            "extractionFn": {
                "type": "lookup",
                "dimension": "dimensionName",
                "outputName": "dimensionOutputName",
                "replaceMissingValueWith": "missing_value",
                "retainMissingValue": False,
                "lookup": {
                    "type": "map",
                    "map": {
                        "iPhone10,1": "iPhone 8",
                        "iPhone10,4": "iPhone 8",
                        "iPhone10,2": "iPhone 8 Plus",
                        "iPhone10,5": "iPhone 8 Plus",
                        "iPhone10,3": "iPhone X",
                        "iPhone10,6": "iPhone X",
                    },
                    "isOneToOne": False,
                },
            },
        }
        spec_json = json.dumps(dimension_spec)
        col = DruidColumn(column_name="deviceName", dimension_spec_json=spec_json)
        column_dict = {"deviceName": col}
        f = DruidDatasource.get_filters(filters, [], column_dict)
        assert isinstance(f.extraction_function, MapLookupExtraction)
        dim_ext_fn = dimension_spec["extractionFn"]
        f_ext_fn = f.extraction_function
        self.assertEqual(dim_ext_fn["lookup"]["map"], f_ext_fn._mapping)
        self.assertEqual(dim_ext_fn["lookup"]["isOneToOne"], f_ext_fn._injective)
        self.assertEqual(
            dim_ext_fn["replaceMissingValueWith"], f_ext_fn._replace_missing_values
        )
        self.assertEqual(
            dim_ext_fn["retainMissingValue"], f_ext_fn._retain_missing_values
        )

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_extraction_fn_regex(self):
        filters = [{"col": "buildPrefix", "val": ["22B"], "op": "in"}]
        dimension_spec = {
            "type": "extraction",
            "dimension": "build",
            "outputName": "buildPrefix",
            "outputType": "STRING",
            "extractionFn": {"type": "regex", "expr": "(^[0-9A-Za-z]{3})"},
        }
        spec_json = json.dumps(dimension_spec)
        col = DruidColumn(column_name="buildPrefix", dimension_spec_json=spec_json)
        column_dict = {"buildPrefix": col}
        f = DruidDatasource.get_filters(filters, [], column_dict)
        assert isinstance(f.extraction_function, RegexExtraction)
        dim_ext_fn = dimension_spec["extractionFn"]
        f_ext_fn = f.extraction_function
        self.assertEqual(dim_ext_fn["expr"], f_ext_fn._expr)

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_extraction_fn_registered_lookup_extraction(self):
        filters = [{"col": "country", "val": ["Spain"], "op": "in"}]
        dimension_spec = {
            "type": "extraction",
            "dimension": "country_name",
            "outputName": "country",
            "outputType": "STRING",
            "extractionFn": {"type": "registeredLookup", "lookup": "country_name"},
        }
        spec_json = json.dumps(dimension_spec)
        col = DruidColumn(column_name="country", dimension_spec_json=spec_json)
        column_dict = {"country": col}
        f = DruidDatasource.get_filters(filters, [], column_dict)
        assert isinstance(f.extraction_function, RegisteredLookupExtraction)
        dim_ext_fn = dimension_spec["extractionFn"]
        self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type)
        self.assertEqual(dim_ext_fn["lookup"], f.extraction_function._lookup)

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_extraction_fn_time_format(self):
        filters = [{"col": "dayOfMonth", "val": ["1", "20"], "op": "in"}]
        dimension_spec = {
            "type": "extraction",
            "dimension": "__time",
            "outputName": "dayOfMonth",
            "extractionFn": {
                "type": "timeFormat",
                "format": "d",
                "timeZone": "Asia/Kolkata",
                "locale": "en",
            },
        }
        spec_json = json.dumps(dimension_spec)
        col = DruidColumn(column_name="dayOfMonth", dimension_spec_json=spec_json)
        column_dict = {"dayOfMonth": col}
        f = DruidDatasource.get_filters(filters, [], column_dict)
        assert isinstance(f.extraction_function, TimeFormatExtraction)
        dim_ext_fn = dimension_spec["extractionFn"]
        self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type)
        self.assertEqual(dim_ext_fn["format"], f.extraction_function._format)
        self.assertEqual(dim_ext_fn["timeZone"], f.extraction_function._time_zone)
        self.assertEqual(dim_ext_fn["locale"], f.extraction_function._locale)

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_ignores_invalid_filter_objects(self):
        filtr = {"col": "col1", "op": "=="}
        filters = [filtr]
        col = DruidColumn(column_name="col1")
        column_dict = {"col1": col}
        self.assertIsNone(DruidDatasource.get_filters(filters, [], column_dict))
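
    # For orientation: a multi-value "in" filter compiles to an OR of Druid
    # selector filters, roughly
    #   {"type": "or", "fields": [
    #       {"type": "selector", "dimension": "A", "value": "a"}, ...]}
    # and "not in" wraps that OR in a "not" filter, as the next two tests assert.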

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_constructs_filter_in(self):
        filtr = {"col": "A", "op": "in", "val": ["a", "b", "c"]}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertIn("filter", res.filter)
        self.assertIn("fields", res.filter["filter"])
        self.assertEqual("or", res.filter["filter"]["type"])
        self.assertEqual(3, len(res.filter["filter"]["fields"]))

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_constructs_filter_not_in(self):
        filtr = {"col": "A", "op": "not in", "val": ["a", "b", "c"]}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertIn("filter", res.filter)
        self.assertIn("type", res.filter["filter"])
        self.assertEqual("not", res.filter["filter"]["type"])
        self.assertIn("field", res.filter["filter"])
        self.assertEqual(
            3, len(res.filter["filter"]["field"].filter["filter"]["fields"])
        )

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_constructs_filter_equals(self):
        filtr = {"col": "A", "op": "==", "val": "h"}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertEqual("selector", res.filter["filter"]["type"])
        self.assertEqual("A", res.filter["filter"]["dimension"])
        self.assertEqual("h", res.filter["filter"]["value"])

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_constructs_filter_not_equals(self):
        filtr = {"col": "A", "op": "!=", "val": "h"}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertEqual("not", res.filter["filter"]["type"])
        self.assertEqual("h", res.filter["filter"]["field"].filter["filter"]["value"])
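
    # For orientation: inequality operators compile to a Druid bound filter,
    # roughly {"dimension": "A", "lower": "h", "lowerStrict": false,
    # "ordering": "lexicographic"}; the ordering switches to "numeric" when
    # the column is listed among the numeric columns.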

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_constructs_bounds_filter(self):
        filtr = {"col": "A", "op": ">=", "val": "h"}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertFalse(res.filter["filter"]["lowerStrict"])
        self.assertEqual("A", res.filter["filter"]["dimension"])
        self.assertEqual("h", res.filter["filter"]["lower"])
        self.assertEqual("lexicographic", res.filter["filter"]["ordering"])
        filtr["op"] = ">"
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertTrue(res.filter["filter"]["lowerStrict"])
        filtr["op"] = "<="
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertFalse(res.filter["filter"]["upperStrict"])
        self.assertEqual("h", res.filter["filter"]["upper"])
        filtr["op"] = "<"
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertTrue(res.filter["filter"]["upperStrict"])
        filtr["val"] = 1
        res = DruidDatasource.get_filters([filtr], ["A"], column_dict)
        self.assertEqual("numeric", res.filter["filter"]["ordering"])
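
    # Druid's classic null handling stores NULL as the empty string, so
    # IS NULL / IS NOT NULL are expressed as a selector on "" (negated for
    # IS NOT NULL).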

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_is_null_filter(self):
        filtr = {"col": "A", "op": "IS NULL"}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertEqual("selector", res.filter["filter"]["type"])
        self.assertEqual("", res.filter["filter"]["value"])

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_is_not_null_filter(self):
        filtr = {"col": "A", "op": "IS NOT NULL"}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertEqual("not", res.filter["filter"]["type"])
        self.assertIn("field", res.filter["filter"])
        self.assertEqual(
            "selector", res.filter["filter"]["field"].filter["filter"]["type"]
        )
        self.assertEqual("", res.filter["filter"]["field"].filter["filter"]["value"])

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_constructs_regex_filter(self):
        filtr = {"col": "A", "op": "regex", "val": "[abc]"}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertEqual("regex", res.filter["filter"]["type"])
        self.assertEqual("[abc]", res.filter["filter"]["pattern"])
        self.assertEqual("A", res.filter["filter"]["dimension"])

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_composes_multiple_filters(self):
        filtr1 = {"col": "A", "op": "!=", "val": "y"}
        filtr2 = {"col": "B", "op": "in", "val": ["a", "b", "c"]}
        cola = DruidColumn(column_name="A")
        colb = DruidColumn(column_name="B")
        column_dict = {"A": cola, "B": colb}
        res = DruidDatasource.get_filters([filtr1, filtr2], [], column_dict)
        self.assertEqual("and", res.filter["filter"]["type"])
        self.assertEqual(2, len(res.filter["filter"]["fields"]))

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_ignores_in_not_in_with_empty_value(self):
        filtr1 = {"col": "A", "op": "in", "val": []}
        filtr2 = {"col": "A", "op": "not in", "val": []}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr1, filtr2], [], column_dict)
        self.assertIsNone(res)

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_constructs_equals_for_in_not_in_single_value(self):
        filtr = {"col": "A", "op": "in", "val": ["a"]}
        cola = DruidColumn(column_name="A")
        colb = DruidColumn(column_name="B")
        column_dict = {"A": cola, "B": colb}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertEqual("selector", res.filter["filter"]["type"])

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_handles_arrays_for_string_types(self):
        filtr = {"col": "A", "op": "==", "val": ["a", "b"]}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertEqual("a", res.filter["filter"]["value"])
        filtr = {"col": "A", "op": "==", "val": []}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertIsNone(res.filter["filter"]["value"])

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_handles_none_for_string_types(self):
        filtr = {"col": "A", "op": "==", "val": None}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertIsNone(res)

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_extracts_values_in_quotes(self):
        filtr = {"col": "A", "op": "in", "val": ['"a"']}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertEqual("a", res.filter["filter"]["value"])

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_keeps_trailing_spaces(self):
        filtr = {"col": "A", "op": "in", "val": ["a "]}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], [], column_dict)
        self.assertEqual("a ", res.filter["filter"]["value"])

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_filters_converts_strings_to_num(self):
        filtr = {"col": "A", "op": "in", "val": ["6"]}
        col = DruidColumn(column_name="A")
        column_dict = {"A": col}
        res = DruidDatasource.get_filters([filtr], ["A"], column_dict)
        self.assertEqual(6, res.filter["filter"]["value"])
        filtr = {"col": "A", "op": "==", "val": "6"}
        res = DruidDatasource.get_filters([filtr], ["A"], column_dict)
        self.assertEqual(6, res.filter["filter"]["value"])
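
    # The run_query tests below pin down how the datasource picks a pydruid
    # query type: no groupby hits client.timeseries; a single groupby with
    # order_desc hits client.topn twice (a pre-query plus the final query);
    # a single groupby without order_desc, or multiple groupby columns,
    # hits client.groupby.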

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_run_query_no_groupby(self):
        client = Mock()
        from_dttm = Mock()
        to_dttm = Mock()
        from_dttm.replace = Mock(return_value=from_dttm)
        to_dttm.replace = Mock(return_value=to_dttm)
        from_dttm.isoformat = Mock(return_value="from")
        to_dttm.isoformat = Mock(return_value="to")
        timezone = "timezone"
        from_dttm.tzname = Mock(return_value=timezone)
        ds = DruidDatasource(datasource_name="datasource")
        metric1 = DruidMetric(metric_name="metric1")
        metric2 = DruidMetric(metric_name="metric2")
        ds.metrics = [metric1, metric2]
        col1 = DruidColumn(column_name="col1")
        col2 = DruidColumn(column_name="col2")
        ds.columns = [col1, col2]
        aggs = []
        post_aggs = ["some_agg"]
        ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
        groupby = []
        metrics = ["metric1"]
        ds.get_having_filters = Mock(return_value=[])
        client.query_builder = Mock()
        client.query_builder.last_query = Mock()
        client.query_builder.last_query.query_dict = {"mock": 0}
        # no groupby calls client.timeseries
        ds.run_query(
            groupby,
            metrics,
            None,
            from_dttm,
            to_dttm,
            client=client,
            filter=[],
            row_limit=100,
        )
        self.assertEqual(0, len(client.topn.call_args_list))
        self.assertEqual(0, len(client.groupby.call_args_list))
        self.assertEqual(1, len(client.timeseries.call_args_list))
        # check that there is no dimensions entry
        called_args = client.timeseries.call_args_list[0][1]
        self.assertNotIn("dimensions", called_args)
        self.assertIn("post_aggregations", called_args)

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_run_query_with_adhoc_metric(self):
        client = Mock()
        from_dttm = Mock()
        to_dttm = Mock()
        from_dttm.replace = Mock(return_value=from_dttm)
        to_dttm.replace = Mock(return_value=to_dttm)
        from_dttm.isoformat = Mock(return_value="from")
        to_dttm.isoformat = Mock(return_value="to")
        timezone = "timezone"
        from_dttm.tzname = Mock(return_value=timezone)
        ds = DruidDatasource(datasource_name="datasource")
        metric1 = DruidMetric(metric_name="metric1")
        metric2 = DruidMetric(metric_name="metric2")
        ds.metrics = [metric1, metric2]
        col1 = DruidColumn(column_name="col1")
        col2 = DruidColumn(column_name="col2")
        ds.columns = [col1, col2]
        all_metrics = []
        post_aggs = ["some_agg"]
        ds._metrics_and_post_aggs = Mock(return_value=(all_metrics, post_aggs))
        groupby = []
        metrics = [
            {
                "expressionType": "SIMPLE",
                "column": {"type": "DOUBLE", "column_name": "col1"},
                "aggregate": "SUM",
                "label": "My Adhoc Metric",
            }
        ]
        ds.get_having_filters = Mock(return_value=[])
        client.query_builder = Mock()
        client.query_builder.last_query = Mock()
        client.query_builder.last_query.query_dict = {"mock": 0}
        # no groupby calls client.timeseries
        ds.run_query(
            groupby,
            metrics,
            None,
            from_dttm,
            to_dttm,
            client=client,
            filter=[],
            row_limit=100,
        )
        self.assertEqual(0, len(client.topn.call_args_list))
        self.assertEqual(0, len(client.groupby.call_args_list))
        self.assertEqual(1, len(client.timeseries.call_args_list))
        # check that there is no dimensions entry
        called_args = client.timeseries.call_args_list[0][1]
        self.assertNotIn("dimensions", called_args)
        self.assertIn("post_aggregations", called_args)

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_run_query_single_groupby(self):
        client = Mock()
        from_dttm = Mock()
        to_dttm = Mock()
        from_dttm.replace = Mock(return_value=from_dttm)
        to_dttm.replace = Mock(return_value=to_dttm)
        from_dttm.isoformat = Mock(return_value="from")
        to_dttm.isoformat = Mock(return_value="to")
        timezone = "timezone"
        from_dttm.tzname = Mock(return_value=timezone)
        ds = DruidDatasource(datasource_name="datasource")
        metric1 = DruidMetric(metric_name="metric1")
        metric2 = DruidMetric(metric_name="metric2")
        ds.metrics = [metric1, metric2]
        col1 = DruidColumn(column_name="col1")
        col2 = DruidColumn(column_name="col2")
        ds.columns = [col1, col2]
        aggs = ["metric1"]
        post_aggs = ["some_agg"]
        ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
        groupby = ["col1"]
        metrics = ["metric1"]
        ds.get_having_filters = Mock(return_value=[])
        client.query_builder.last_query.query_dict = {"mock": 0}
        # client.topn is called twice
        ds.run_query(
            groupby,
            metrics,
            None,
            from_dttm,
            to_dttm,
            timeseries_limit=100,
            client=client,
            order_desc=True,
            filter=[],
        )
        self.assertEqual(2, len(client.topn.call_args_list))
        self.assertEqual(0, len(client.groupby.call_args_list))
        self.assertEqual(0, len(client.timeseries.call_args_list))
        # check that there is no dimensions entry
        called_args_pre = client.topn.call_args_list[0][1]
        self.assertNotIn("dimensions", called_args_pre)
        self.assertIn("dimension", called_args_pre)
        called_args = client.topn.call_args_list[1][1]
        self.assertIn("dimension", called_args)
        self.assertEqual("col1", called_args["dimension"])
        # not order_desc
        client = Mock()
        client.query_builder.last_query.query_dict = {"mock": 0}
        ds.run_query(
            groupby,
            metrics,
            None,
            from_dttm,
            to_dttm,
            client=client,
            order_desc=False,
            filter=[],
            row_limit=100,
        )
        self.assertEqual(0, len(client.topn.call_args_list))
        self.assertEqual(1, len(client.groupby.call_args_list))
        self.assertEqual(0, len(client.timeseries.call_args_list))
        self.assertIn("dimensions", client.groupby.call_args_list[0][1])
        self.assertEqual(["col1"], client.groupby.call_args_list[0][1]["dimensions"])
        # order_desc but timeseries and dimension spec
        # calls topn with single dimension spec 'dimension'
        spec = {"outputName": "hello", "dimension": "matcho"}
        spec_json = json.dumps(spec)
        col3 = DruidColumn(column_name="col3", dimension_spec_json=spec_json)
        ds.columns.append(col3)
        groupby = ["col3"]
        client = Mock()
        client.query_builder.last_query.query_dict = {"mock": 0}
        ds.run_query(
            groupby,
            metrics,
            None,
            from_dttm,
            to_dttm,
            client=client,
            order_desc=True,
            timeseries_limit=5,
            filter=[],
            row_limit=100,
        )
        self.assertEqual(2, len(client.topn.call_args_list))
        self.assertEqual(0, len(client.groupby.call_args_list))
        self.assertEqual(0, len(client.timeseries.call_args_list))
        self.assertIn("dimension", client.topn.call_args_list[0][1])
        self.assertIn("dimension", client.topn.call_args_list[1][1])
        # uses dimension for pre query and full spec for final query
        self.assertEqual("matcho", client.topn.call_args_list[0][1]["dimension"])
        self.assertEqual(spec, client.topn.call_args_list[1][1]["dimension"])

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_run_query_multiple_groupby(self):
        client = Mock()
        from_dttm = Mock()
        to_dttm = Mock()
        from_dttm.replace = Mock(return_value=from_dttm)
        to_dttm.replace = Mock(return_value=to_dttm)
        from_dttm.isoformat = Mock(return_value="from")
        to_dttm.isoformat = Mock(return_value="to")
        timezone = "timezone"
        from_dttm.tzname = Mock(return_value=timezone)
        ds = DruidDatasource(datasource_name="datasource")
        metric1 = DruidMetric(metric_name="metric1")
        metric2 = DruidMetric(metric_name="metric2")
        ds.metrics = [metric1, metric2]
        col1 = DruidColumn(column_name="col1")
        col2 = DruidColumn(column_name="col2")
        ds.columns = [col1, col2]
        aggs = []
        post_aggs = ["some_agg"]
        ds._metrics_and_post_aggs = Mock(return_value=(aggs, post_aggs))
        groupby = ["col1", "col2"]
        metrics = ["metric1"]
        ds.get_having_filters = Mock(return_value=[])
        client.query_builder = Mock()
        client.query_builder.last_query = Mock()
        client.query_builder.last_query.query_dict = {"mock": 0}
        # multiple groupby columns call client.groupby
        ds.run_query(
            groupby,
            metrics,
            None,
            from_dttm,
            to_dttm,
            client=client,
            row_limit=100,
            filter=[],
        )
        self.assertEqual(0, len(client.topn.call_args_list))
        self.assertEqual(1, len(client.groupby.call_args_list))
        self.assertEqual(0, len(client.timeseries.call_args_list))
        # check that the query carries a dimensions entry
        called_args = client.groupby.call_args_list[0][1]
        self.assertIn("dimensions", called_args)
        self.assertEqual(["col1", "col2"], called_args["dimensions"])
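
    # get_post_agg maps the "type" of a post-aggregator conf to a pydruid
    # class: javascript -> models.JavascriptPostAggregator, quantile ->
    # postaggs.Quantile, quantiles -> postaggs.Quantiles, fieldAccess ->
    # postaggs.Field, constant -> postaggs.Const, hyperUniqueCardinality ->
    # postaggs.HyperUniqueCardinality, arithmetic -> postaggs.Postaggregator,
    # and custom types -> models.CustomPostAggregator.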

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_post_agg_returns_correct_agg_type(self):
        get_post_agg = DruidDatasource.get_post_agg
        # javascript PostAggregators
        function = "function(field1, field2) { return field1 + field2; }"
        conf = {
            "type": "javascript",
            "name": "postagg_name",
            "fieldNames": ["field1", "field2"],
            "function": function,
        }
        postagg = get_post_agg(conf)
        self.assertTrue(isinstance(postagg, models.JavascriptPostAggregator))
        self.assertEqual(postagg.name, "postagg_name")
        self.assertEqual(postagg.post_aggregator["type"], "javascript")
        self.assertEqual(postagg.post_aggregator["fieldNames"], ["field1", "field2"])
        self.assertEqual(postagg.post_aggregator["name"], "postagg_name")
        self.assertEqual(postagg.post_aggregator["function"], function)
        # Quantile
        conf = {"type": "quantile", "name": "postagg_name", "probability": "0.5"}
        postagg = get_post_agg(conf)
        self.assertTrue(isinstance(postagg, postaggs.Quantile))
        self.assertEqual(postagg.name, "postagg_name")
        self.assertEqual(postagg.post_aggregator["probability"], "0.5")
        # Quantiles
        conf = {
            "type": "quantiles",
            "name": "postagg_name",
            "probabilities": "0.4,0.5,0.6",
        }
        postagg = get_post_agg(conf)
        self.assertTrue(isinstance(postagg, postaggs.Quantiles))
        self.assertEqual(postagg.name, "postagg_name")
        self.assertEqual(postagg.post_aggregator["probabilities"], "0.4,0.5,0.6")
        # FieldAccess
        conf = {"type": "fieldAccess", "name": "field_name"}
        postagg = get_post_agg(conf)
        self.assertTrue(isinstance(postagg, postaggs.Field))
        self.assertEqual(postagg.name, "field_name")
        # constant
        conf = {"type": "constant", "value": 1234, "name": "postagg_name"}
        postagg = get_post_agg(conf)
        self.assertTrue(isinstance(postagg, postaggs.Const))
        self.assertEqual(postagg.name, "postagg_name")
        self.assertEqual(postagg.post_aggregator["value"], 1234)
        # hyperUniqueCardinality
        conf = {"type": "hyperUniqueCardinality", "name": "unique_name"}
        postagg = get_post_agg(conf)
        self.assertTrue(isinstance(postagg, postaggs.HyperUniqueCardinality))
        self.assertEqual(postagg.name, "unique_name")
        # arithmetic
        conf = {
            "type": "arithmetic",
            "fn": "+",
            "fields": ["field1", "field2"],
            "name": "postagg_name",
        }
        postagg = get_post_agg(conf)
        self.assertTrue(isinstance(postagg, postaggs.Postaggregator))
        self.assertEqual(postagg.name, "postagg_name")
        self.assertEqual(postagg.post_aggregator["fn"], "+")
        self.assertEqual(postagg.post_aggregator["fields"], ["field1", "field2"])
        # custom post aggregator
        conf = {"type": "custom", "name": "custom_name", "stuff": "more_stuff"}
        postagg = get_post_agg(conf)
        self.assertTrue(isinstance(postagg, models.CustomPostAggregator))
        self.assertEqual(postagg.name, "custom_name")
        self.assertEqual(postagg.post_aggregator["stuff"], "more_stuff")

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_find_postaggs_for_returns_postaggs_and_removes(self):
        find_postaggs_for = DruidDatasource.find_postaggs_for
        postagg_names = set(["pa2", "pa3", "pa4", "m1", "m2", "m3", "m4"])
        metrics = {}
        for i in range(1, 6):
            emplace(metrics, "pa" + str(i), True)
            emplace(metrics, "m" + str(i), False)
        postagg_list = find_postaggs_for(postagg_names, metrics)
        self.assertEqual(3, len(postagg_list))
        self.assertEqual(4, len(postagg_names))
        expected_metrics = ["m1", "m2", "m3", "m4"]
        expected_postaggs = set(["pa2", "pa3", "pa4"])
        for postagg in postagg_list:
            expected_postaggs.remove(postagg.metric_name)
        for metric in expected_metrics:
            postagg_names.remove(metric)
        self.assertEqual(0, len(expected_postaggs))
        self.assertEqual(0, len(postagg_names))
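
    # recursive_get_fields collects the metric names a post-aggregator conf
    # depends on, following "fieldName" entries and recursing through nested
    # "field"/"fields" structures; the fixture below exercises both paths.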

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_recursive_get_fields(self):
        conf = {
            "type": "quantile",
            "fieldName": "f1",
            "field": {
                "type": "custom",
                "fields": [
                    {"type": "fieldAccess", "fieldName": "f2"},
                    {"type": "fieldAccess", "fieldName": "f3"},
                    {
                        "type": "quantiles",
                        "fieldName": "f4",
                        "field": {"type": "custom"},
                    },
                    {
                        "type": "custom",
                        "fields": [
                            {"type": "fieldAccess", "fieldName": "f5"},
                            {
                                "type": "fieldAccess",
                                "fieldName": "f2",
                                "fields": [
                                    {"type": "fieldAccess", "fieldName": "f3"},
                                    {"type": "fieldIgnoreMe", "fieldName": "f6"},
                                ],
                            },
                        ],
                    },
                ],
            },
        }
        fields = DruidDatasource.recursive_get_fields(conf)
        expected = set(["f1", "f2", "f3", "f4", "f5"])
        self.assertEqual(5, len(fields))
        for field in fields:
            expected.remove(field)
        self.assertEqual(0, len(expected))
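
    # Builds a dependency graph in which post-aggregators A..K reference one
    # another plus raw metrics m1..m9. metrics_and_post_aggs must resolve the
    # transitive closure: all 9 raw aggregations and all 11 post-aggregations.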

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_metrics_and_post_aggs_tree(self):
        metrics = ["A", "B", "m1", "m2"]
        metrics_dict = {}
        for i in range(ord("A"), ord("K") + 1):
            emplace(metrics_dict, chr(i), True)
        for i in range(1, 10):
            emplace(metrics_dict, "m" + str(i), False)

        def depends_on(index, fields):
            dependents = fields if isinstance(fields, list) else [fields]
            metrics_dict[index].json_obj = {"fieldNames": dependents}

        depends_on("A", ["m1", "D", "C"])
        depends_on("B", ["B", "C", "E", "F", "m3"])
        depends_on("C", ["H", "I"])
        depends_on("D", ["m2", "m5", "G", "C"])
        depends_on("E", ["H", "I", "J"])
        depends_on("F", ["J", "m5"])
        depends_on("G", ["m4", "m7", "m6", "A"])
        depends_on("H", ["A", "m4", "I"])
        depends_on("I", ["H", "K"])
        depends_on("J", "K")
        depends_on("K", ["m8", "m9"])
        aggs, postaggs = DruidDatasource.metrics_and_post_aggs(metrics, metrics_dict)
        expected_metrics = set(aggs.keys())
        self.assertEqual(9, len(aggs))
        for i in range(1, 10):
            expected_metrics.remove("m" + str(i))
        self.assertEqual(0, len(expected_metrics))
        self.assertEqual(11, len(postaggs))
        for i in range(ord("A"), ord("K") + 1):
            del postaggs[chr(i)]
        self.assertEqual(0, len(postaggs))

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_metrics_and_post_aggs(self):
        """
        Test generation of metrics and post-aggregations from an initial list
        of Superset metrics (which may include the results of either). This
        primarily tests that specifying a post-aggregator metric also pulls in
        the raw aggregation of the associated Druid metric column.
        """
        metrics_dict = {
            "unused_count": DruidMetric(
                metric_name="unused_count",
                verbose_name="COUNT(*)",
                metric_type="count",
                json=json.dumps({"type": "count", "name": "unused_count"}),
            ),
            "some_sum": DruidMetric(
                metric_name="some_sum",
                verbose_name="SUM(*)",
                metric_type="sum",
                json=json.dumps({"type": "sum", "name": "sum"}),
            ),
            "a_histogram": DruidMetric(
                metric_name="a_histogram",
                verbose_name="APPROXIMATE_HISTOGRAM(*)",
                metric_type="approxHistogramFold",
                json=json.dumps({"type": "approxHistogramFold", "name": "a_histogram"}),
            ),
            "aCustomMetric": DruidMetric(
                metric_name="aCustomMetric",
                verbose_name="MY_AWESOME_METRIC(*)",
                metric_type="aCustomType",
                json=json.dumps({"type": "customMetric", "name": "aCustomMetric"}),
            ),
            "quantile_p95": DruidMetric(
                metric_name="quantile_p95",
                verbose_name="P95(*)",
                metric_type="postagg",
                json=json.dumps(
                    {
                        "type": "quantile",
                        "probability": 0.95,
                        "name": "p95",
                        "fieldName": "a_histogram",
                    }
                ),
            ),
            "aCustomPostAgg": DruidMetric(
                metric_name="aCustomPostAgg",
                verbose_name="CUSTOM_POST_AGG(*)",
                metric_type="postagg",
                json=json.dumps(
                    {
                        "type": "customPostAgg",
                        "name": "aCustomPostAgg",
                        "field": {"type": "fieldAccess", "fieldName": "aCustomMetric"},
                    }
                ),
            ),
        }
        adhoc_metric = {
            "expressionType": "SIMPLE",
            "column": {"type": "DOUBLE", "column_name": "value"},
            "aggregate": "SUM",
            "label": "My Adhoc Metric",
        }

        metrics = ["some_sum"]
        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
            metrics, metrics_dict
        )
        assert set(saved_metrics.keys()) == {"some_sum"}
        assert post_aggs == {}

        metrics = [adhoc_metric]
        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
            metrics, metrics_dict
        )
        assert set(saved_metrics.keys()) == set([adhoc_metric["label"]])
        assert post_aggs == {}

        metrics = ["some_sum", adhoc_metric]
        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
            metrics, metrics_dict
        )
        assert set(saved_metrics.keys()) == {"some_sum", adhoc_metric["label"]}
        assert post_aggs == {}

        metrics = ["quantile_p95"]
        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
            metrics, metrics_dict
        )
        result_postaggs = set(["quantile_p95"])
        assert set(saved_metrics.keys()) == {"a_histogram"}
        assert set(post_aggs.keys()) == result_postaggs

        metrics = ["aCustomPostAgg"]
        saved_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
            metrics, metrics_dict
        )
        result_postaggs = set(["aCustomPostAgg"])
        assert set(saved_metrics.keys()) == {"aCustomMetric"}
        assert set(post_aggs.keys()) == result_postaggs
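
    # The Druid aggregator type for an adhoc metric combines the column type
    # with the aggregate: DOUBLE + SUM -> doubleSum, LONG + MAX -> longMax;
    # COUNT_DISTINCT maps to "cardinality", or to "hyperUnique" when the
    # column is itself a hyperUnique sketch.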

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_druid_type_from_adhoc_metric(self):
        druid_type = DruidDatasource.druid_type_from_adhoc_metric(
            {
                "column": {"type": "DOUBLE", "column_name": "value"},
                "aggregate": "SUM",
                "label": "My Adhoc Metric",
            }
        )
        assert druid_type == "doubleSum"
        druid_type = DruidDatasource.druid_type_from_adhoc_metric(
            {
                "column": {"type": "LONG", "column_name": "value"},
                "aggregate": "MAX",
                "label": "My Adhoc Metric",
            }
        )
        assert druid_type == "longMax"
        druid_type = DruidDatasource.druid_type_from_adhoc_metric(
            {
                "column": {"type": "VARCHAR(255)", "column_name": "value"},
                "aggregate": "COUNT",
                "label": "My Adhoc Metric",
            }
        )
        assert druid_type == "count"
        druid_type = DruidDatasource.druid_type_from_adhoc_metric(
            {
                "column": {"type": "VARCHAR(255)", "column_name": "value"},
                "aggregate": "COUNT_DISTINCT",
                "label": "My Adhoc Metric",
            }
        )
        assert druid_type == "cardinality"
        druid_type = DruidDatasource.druid_type_from_adhoc_metric(
            {
                "column": {"type": "hyperUnique", "column_name": "value"},
                "aggregate": "COUNT_DISTINCT",
                "label": "My Adhoc Metric",
            }
        )
        assert druid_type == "hyperUnique"
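
    # Ordering by a post-aggregation (div1 = sum1 / sum2) must also ship the
    # aggregations it depends on (sum1, sum2) alongside the metrics actually
    # selected, for both topn and groupby queries.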

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_run_query_order_by_metrics(self):
        client = Mock()
        client.query_builder.last_query.query_dict = {"mock": 0}
        from_dttm = Mock()
        to_dttm = Mock()
        ds = DruidDatasource(datasource_name="datasource")
        ds.get_having_filters = Mock(return_value=[])
        dim1 = DruidColumn(column_name="dim1")
        dim2 = DruidColumn(column_name="dim2")
        metrics_dict = {
            "count1": DruidMetric(
                metric_name="count1",
                metric_type="count",
                json=json.dumps({"type": "count", "name": "count1"}),
            ),
            "sum1": DruidMetric(
                metric_name="sum1",
                metric_type="doubleSum",
                json=json.dumps({"type": "doubleSum", "name": "sum1"}),
            ),
            "sum2": DruidMetric(
                metric_name="sum2",
                metric_type="doubleSum",
                json=json.dumps({"type": "doubleSum", "name": "sum2"}),
            ),
            "div1": DruidMetric(
                metric_name="div1",
                metric_type="postagg",
                json=json.dumps(
                    {
                        "fn": "/",
                        "type": "arithmetic",
                        "name": "div1",
                        "fields": [
                            {"fieldName": "sum1", "type": "fieldAccess"},
                            {"fieldName": "sum2", "type": "fieldAccess"},
                        ],
                    }
                ),
            ),
        }
        ds.columns = [dim1, dim2]
        ds.metrics = list(metrics_dict.values())
        groupby = ["dim1"]
        metrics = ["count1"]
        granularity = "all"
        # get the counts of the top 5 'dim1's, order by 'sum1'
        ds.run_query(
            groupby,
            metrics,
            granularity,
            from_dttm,
            to_dttm,
            timeseries_limit=5,
            timeseries_limit_metric="sum1",
            client=client,
            order_desc=True,
            filter=[],
        )
        qry_obj = client.topn.call_args_list[0][1]
        self.assertEqual("dim1", qry_obj["dimension"])
        self.assertEqual("sum1", qry_obj["metric"])
        aggregations = qry_obj["aggregations"]
        post_aggregations = qry_obj["post_aggregations"]
        self.assertEqual({"count1", "sum1"}, set(aggregations.keys()))
        self.assertEqual(set(), set(post_aggregations.keys()))
        # get the counts of the top 5 'dim1's, order by 'div1'
        ds.run_query(
            groupby,
            metrics,
            granularity,
            from_dttm,
            to_dttm,
            timeseries_limit=5,
            timeseries_limit_metric="div1",
            client=client,
            order_desc=True,
            filter=[],
        )
        qry_obj = client.topn.call_args_list[1][1]
        self.assertEqual("dim1", qry_obj["dimension"])
        self.assertEqual("div1", qry_obj["metric"])
        aggregations = qry_obj["aggregations"]
        post_aggregations = qry_obj["post_aggregations"]
        self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys()))
        self.assertEqual({"div1"}, set(post_aggregations.keys()))
        groupby = ["dim1", "dim2"]
        # get the counts of the top 5 ['dim1', 'dim2']s, order by 'sum1'
        ds.run_query(
            groupby,
            metrics,
            granularity,
            from_dttm,
            to_dttm,
            timeseries_limit=5,
            timeseries_limit_metric="sum1",
            client=client,
            order_desc=True,
            filter=[],
        )
        qry_obj = client.groupby.call_args_list[0][1]
        self.assertEqual({"dim1", "dim2"}, set(qry_obj["dimensions"]))
        self.assertEqual("sum1", qry_obj["limit_spec"]["columns"][0]["dimension"])
        aggregations = qry_obj["aggregations"]
        post_aggregations = qry_obj["post_aggregations"]
        self.assertEqual({"count1", "sum1"}, set(aggregations.keys()))
        self.assertEqual(set(), set(post_aggregations.keys()))
        # get the counts of the top 5 ['dim1', 'dim2']s, order by 'div1'
        ds.run_query(
            groupby,
            metrics,
            granularity,
            from_dttm,
            to_dttm,
            timeseries_limit=5,
            timeseries_limit_metric="div1",
            client=client,
            order_desc=True,
            filter=[],
        )
        qry_obj = client.groupby.call_args_list[1][1]
        self.assertEqual({"dim1", "dim2"}, set(qry_obj["dimensions"]))
        self.assertEqual("div1", qry_obj["limit_spec"]["columns"][0]["dimension"])
        aggregations = qry_obj["aggregations"]
        post_aggregations = qry_obj["post_aggregations"]
        self.assertEqual({"count1", "sum1", "sum2"}, set(aggregations.keys()))
        self.assertEqual({"div1"}, set(post_aggregations.keys()))
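
    # get_aggregations rejects names that are unknown or that refer to
    # post-aggregations, raising SupersetException in both cases.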

    @unittest.skipUnless(
        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
    )
    def test_get_aggregations(self):
        ds = DruidDatasource(datasource_name="datasource")
        metrics_dict = {
            "sum1": DruidMetric(
                metric_name="sum1",
                metric_type="doubleSum",
                json=json.dumps({"type": "doubleSum", "name": "sum1"}),
            ),
            "sum2": DruidMetric(
                metric_name="sum2",
                metric_type="doubleSum",
                json=json.dumps({"type": "doubleSum", "name": "sum2"}),
            ),
            "div1": DruidMetric(
                metric_name="div1",
                metric_type="postagg",
                json=json.dumps(
                    {
                        "fn": "/",
                        "type": "arithmetic",
                        "name": "div1",
                        "fields": [
                            {"fieldName": "sum1", "type": "fieldAccess"},
                            {"fieldName": "sum2", "type": "fieldAccess"},
                        ],
                    }
                ),
            ),
        }
        metric_names = ["sum1", "sum2"]
        aggs = ds.get_aggregations(metrics_dict, metric_names)
        expected_agg = {name: metrics_dict[name].json_obj for name in metric_names}
        self.assertEqual(expected_agg, aggs)
        metric_names = ["sum1", "col1"]
        self.assertRaises(
            SupersetException, ds.get_aggregations, metrics_dict, metric_names
        )
        metric_names = ["sum1", "div1"]
        self.assertRaises(
            SupersetException, ds.get_aggregations, metrics_dict, metric_names
        )