sql_parse_tests.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. import unittest
  18. from superset import sql_parse
  19. class SupersetTestCase(unittest.TestCase):
  20. def extract_tables(self, query):
  21. sq = sql_parse.ParsedQuery(query)
  22. return sq.tables
  23. def test_simple_select(self):
  24. query = "SELECT * FROM tbname"
  25. self.assertEqual({"tbname"}, self.extract_tables(query))
  26. query = "SELECT * FROM tbname foo"
  27. self.assertEqual({"tbname"}, self.extract_tables(query))
  28. query = "SELECT * FROM tbname AS foo"
  29. self.assertEqual({"tbname"}, self.extract_tables(query))
  30. # underscores
  31. query = "SELECT * FROM tb_name"
  32. self.assertEqual({"tb_name"}, self.extract_tables(query))
  33. # quotes
  34. query = 'SELECT * FROM "tbname"'
  35. self.assertEqual({"tbname"}, self.extract_tables(query))
  36. # unicode encoding
  37. query = 'SELECT * FROM "tb_name" WHERE city = "Lübeck"'
  38. self.assertEqual({"tb_name"}, self.extract_tables(query))
  39. # schema
  40. self.assertEqual(
  41. {"schemaname.tbname"},
  42. self.extract_tables("SELECT * FROM schemaname.tbname"),
  43. )
  44. self.assertEqual(
  45. {"schemaname.tbname"},
  46. self.extract_tables('SELECT * FROM "schemaname"."tbname"'),
  47. )
  48. self.assertEqual(
  49. {"schemaname.tbname"},
  50. self.extract_tables("SELECT * FROM schemaname.tbname foo"),
  51. )
  52. self.assertEqual(
  53. {"schemaname.tbname"},
  54. self.extract_tables("SELECT * FROM schemaname.tbname AS foo"),
  55. )
  56. # cluster
  57. self.assertEqual(
  58. {"clustername.schemaname.tbname"},
  59. self.extract_tables("SELECT * FROM clustername.schemaname.tbname"),
  60. )
  61. # Ill-defined cluster/schema/table.
  62. self.assertEqual(set(), self.extract_tables("SELECT * FROM schemaname."))
  63. self.assertEqual(
  64. set(), self.extract_tables("SELECT * FROM clustername.schemaname.")
  65. )
  66. self.assertEqual(set(), self.extract_tables("SELECT * FROM clustername.."))
  67. self.assertEqual(
  68. set(), self.extract_tables("SELECT * FROM clustername..tbname")
  69. )
  70. # quotes
  71. query = "SELECT field1, field2 FROM tb_name"
  72. self.assertEqual({"tb_name"}, self.extract_tables(query))
  73. query = "SELECT t1.f1, t2.f2 FROM t1, t2"
  74. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  75. def test_select_named_table(self):
  76. query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
  77. self.assertEqual({"left_table"}, self.extract_tables(query))
  78. def test_reverse_select(self):
  79. query = "FROM t1 SELECT field"
  80. self.assertEqual({"t1"}, self.extract_tables(query))
  81. def test_subselect(self):
  82. query = """
  83. SELECT sub.*
  84. FROM (
  85. SELECT *
  86. FROM s1.t1
  87. WHERE day_of_week = 'Friday'
  88. ) sub, s2.t2
  89. WHERE sub.resolution = 'NONE'
  90. """
  91. self.assertEqual({"s1.t1", "s2.t2"}, self.extract_tables(query))
  92. query = """
  93. SELECT sub.*
  94. FROM (
  95. SELECT *
  96. FROM s1.t1
  97. WHERE day_of_week = 'Friday'
  98. ) sub
  99. WHERE sub.resolution = 'NONE'
  100. """
  101. self.assertEqual({"s1.t1"}, self.extract_tables(query))
  102. query = """
  103. SELECT * FROM t1
  104. WHERE s11 > ANY
  105. (SELECT COUNT(*) /* no hint */ FROM t2
  106. WHERE NOT EXISTS
  107. (SELECT * FROM t3
  108. WHERE ROW(5*t2.s1,77)=
  109. (SELECT 50,11*s1 FROM t4)));
  110. """
  111. self.assertEqual({"t1", "t2", "t3", "t4"}, self.extract_tables(query))
  112. def test_select_in_expression(self):
  113. query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
  114. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  115. def test_union(self):
  116. query = "SELECT * FROM t1 UNION SELECT * FROM t2"
  117. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  118. query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
  119. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  120. query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
  121. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  122. def test_select_from_values(self):
  123. query = "SELECT * FROM VALUES (13, 42)"
  124. self.assertFalse(self.extract_tables(query))
  125. def test_select_array(self):
  126. query = """
  127. SELECT ARRAY[1, 2, 3] AS my_array
  128. FROM t1 LIMIT 10
  129. """
  130. self.assertEqual({"t1"}, self.extract_tables(query))
  131. def test_select_if(self):
  132. query = """
  133. SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
  134. FROM t1 LIMIT 10
  135. """
  136. self.assertEqual({"t1"}, self.extract_tables(query))
  137. # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
  138. def test_show_tables(self):
  139. query = "SHOW TABLES FROM s1 like '%order%'"
  140. # TODO: figure out what should code do here
  141. self.assertEqual({"s1"}, self.extract_tables(query))
  142. # SHOW COLUMNS (FROM | IN) qualifiedName
  143. def test_show_columns(self):
  144. query = "SHOW COLUMNS FROM t1"
  145. self.assertEqual({"t1"}, self.extract_tables(query))
  146. def test_where_subquery(self):
  147. query = """
  148. SELECT name
  149. FROM t1
  150. WHERE regionkey = (SELECT max(regionkey) FROM t2)
  151. """
  152. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  153. query = """
  154. SELECT name
  155. FROM t1
  156. WHERE regionkey IN (SELECT regionkey FROM t2)
  157. """
  158. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  159. query = """
  160. SELECT name
  161. FROM t1
  162. WHERE regionkey EXISTS (SELECT regionkey FROM t2)
  163. """
  164. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  165. # DESCRIBE | DESC qualifiedName
  166. def test_describe(self):
  167. self.assertEqual({"t1"}, self.extract_tables("DESCRIBE t1"))
  168. # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
  169. # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
  170. def test_show_partitions(self):
  171. query = """
  172. SHOW PARTITIONS FROM orders
  173. WHERE ds >= '2013-01-01' ORDER BY ds DESC;
  174. """
  175. self.assertEqual({"orders"}, self.extract_tables(query))
  176. def test_join(self):
  177. query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
  178. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  179. # subquery + join
  180. query = """
  181. SELECT a.date, b.name FROM
  182. left_table a
  183. JOIN (
  184. SELECT
  185. CAST((b.year) as VARCHAR) date,
  186. name
  187. FROM right_table
  188. ) b
  189. ON a.date = b.date
  190. """
  191. self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
  192. query = """
  193. SELECT a.date, b.name FROM
  194. left_table a
  195. LEFT INNER JOIN (
  196. SELECT
  197. CAST((b.year) as VARCHAR) date,
  198. name
  199. FROM right_table
  200. ) b
  201. ON a.date = b.date
  202. """
  203. self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
  204. query = """
  205. SELECT a.date, b.name FROM
  206. left_table a
  207. RIGHT OUTER JOIN (
  208. SELECT
  209. CAST((b.year) as VARCHAR) date,
  210. name
  211. FROM right_table
  212. ) b
  213. ON a.date = b.date
  214. """
  215. self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
  216. query = """
  217. SELECT a.date, b.name FROM
  218. left_table a
  219. FULL OUTER JOIN (
  220. SELECT
  221. CAST((b.year) as VARCHAR) date,
  222. name
  223. FROM right_table
  224. ) b
  225. ON a.date = b.date
  226. """
  227. self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
  228. # TODO: add SEMI join support, SQL Parse does not handle it.
  229. # query = """
  230. # SELECT a.date, b.name FROM
  231. # left_table a
  232. # LEFT SEMI JOIN (
  233. # SELECT
  234. # CAST((b.year) as VARCHAR) date,
  235. # name
  236. # FROM right_table
  237. # ) b
  238. # ON a.date = b.date
  239. # """
  240. # self.assertEqual({'left_table', 'right_table'},
  241. # sql_parse.extract_tables(query))
  242. def test_combinations(self):
  243. query = """
  244. SELECT * FROM t1
  245. WHERE s11 > ANY
  246. (SELECT * FROM t1 UNION ALL SELECT * FROM (
  247. SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a) tmp_join
  248. WHERE NOT EXISTS
  249. (SELECT * FROM t3
  250. WHERE ROW(5*t3.s1,77)=
  251. (SELECT 50,11*s1 FROM t4)));
  252. """
  253. self.assertEqual({"t1", "t3", "t4", "t6"}, self.extract_tables(query))
  254. query = """
  255. SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
  256. AS S1) AS S2) AS S3;
  257. """
  258. self.assertEqual({"EmployeeS"}, self.extract_tables(query))
  259. def test_with(self):
  260. query = """
  261. WITH
  262. x AS (SELECT a FROM t1),
  263. y AS (SELECT a AS b FROM t2),
  264. z AS (SELECT b AS c FROM t3)
  265. SELECT c FROM z;
  266. """
  267. self.assertEqual({"t1", "t2", "t3"}, self.extract_tables(query))
  268. query = """
  269. WITH
  270. x AS (SELECT a FROM t1),
  271. y AS (SELECT a AS b FROM x),
  272. z AS (SELECT b AS c FROM y)
  273. SELECT c FROM z;
  274. """
  275. self.assertEqual({"t1"}, self.extract_tables(query))
  276. def test_reusing_aliases(self):
  277. query = """
  278. with q1 as ( select key from q2 where key = '5'),
  279. q2 as ( select key from src where key = '5')
  280. select * from (select key from q1) a;
  281. """
  282. self.assertEqual({"src"}, self.extract_tables(query))
  283. def test_multistatement(self):
  284. query = "SELECT * FROM t1; SELECT * FROM t2"
  285. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  286. query = "SELECT * FROM t1; SELECT * FROM t2;"
  287. self.assertEqual({"t1", "t2"}, self.extract_tables(query))
  288. def test_update_not_select(self):
  289. sql = sql_parse.ParsedQuery("UPDATE t1 SET col1 = NULL")
  290. self.assertEqual(False, sql.is_select())
  291. self.assertEqual(False, sql.is_readonly())
  292. def test_explain(self):
  293. sql = sql_parse.ParsedQuery("EXPLAIN SELECT 1")
  294. self.assertEqual(True, sql.is_explain())
  295. self.assertEqual(False, sql.is_select())
  296. self.assertEqual(True, sql.is_readonly())
  297. def test_complex_extract_tables(self):
  298. query = """SELECT sum(m_examples) AS "sum__m_example"
  299. FROM
  300. (SELECT COUNT(DISTINCT id_userid) AS m_examples,
  301. some_more_info
  302. FROM my_b_table b
  303. JOIN my_t_table t ON b.ds=t.ds
  304. JOIN my_l_table l ON b.uid=l.uid
  305. WHERE b.rid IN
  306. (SELECT other_col
  307. FROM inner_table)
  308. AND l.bla IN ('x', 'y')
  309. GROUP BY 2
  310. ORDER BY 2 ASC) AS "meh"
  311. ORDER BY "sum__m_example" DESC
  312. LIMIT 10;"""
  313. self.assertEqual(
  314. {"my_l_table", "my_b_table", "my_t_table", "inner_table"},
  315. self.extract_tables(query),
  316. )
  317. def test_complex_extract_tables2(self):
  318. query = """SELECT *
  319. FROM table_a AS a, table_b AS b, table_c as c
  320. WHERE a.id = b.id and b.id = c.id"""
  321. self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query))
  322. def test_mixed_from_clause(self):
  323. query = """SELECT *
  324. FROM table_a AS a, (select * from table_b) AS b, table_c as c
  325. WHERE a.id = b.id and b.id = c.id"""
  326. self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query))
  327. def test_nested_selects(self):
  328. query = """
  329. select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
  330. from INFORMATION_SCHEMA.COLUMNS
  331. WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
  332. """
  333. self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query))
  334. query = """
  335. select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
  336. from INFORMATION_SCHEMA.COLUMNS
  337. WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
  338. """
  339. self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query))
  340. def test_complex_extract_tables3(self):
  341. query = """SELECT somecol AS somecol
  342. FROM
  343. (WITH bla AS
  344. (SELECT col_a
  345. FROM a
  346. WHERE 1=1
  347. AND column_of_choice NOT IN
  348. ( SELECT interesting_col
  349. FROM b ) ),
  350. rb AS
  351. ( SELECT yet_another_column
  352. FROM
  353. ( SELECT a
  354. FROM c
  355. GROUP BY the_other_col ) not_table
  356. LEFT JOIN bla foo ON foo.prop = not_table.bad_col0
  357. WHERE 1=1
  358. GROUP BY not_table.bad_col1 ,
  359. not_table.bad_col2 ,
  360. ORDER BY not_table.bad_col_3 DESC ,
  361. not_table.bad_col4 ,
  362. not_table.bad_col5) SELECT random_col
  363. FROM d
  364. WHERE 1=1
  365. UNION ALL SELECT even_more_cols
  366. FROM e
  367. WHERE 1=1
  368. UNION ALL SELECT lets_go_deeper
  369. FROM f
  370. WHERE 1=1
  371. WHERE 2=2
  372. GROUP BY last_col
  373. LIMIT 50000;"""
  374. self.assertEqual({"a", "b", "c", "d", "e", "f"}, self.extract_tables(query))
  375. def test_complex_cte_with_prefix(self):
  376. query = """
  377. WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
  378. AS (
  379. SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
  380. FROM SalesOrderHeader
  381. WHERE SalesPersonID IS NOT NULL
  382. )
  383. SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
  384. FROM CTE__test
  385. GROUP BY SalesYear, SalesPersonID
  386. ORDER BY SalesPersonID, SalesYear;
  387. """
  388. self.assertEqual({"SalesOrderHeader"}, self.extract_tables(query))
  389. def test_get_query_with_new_limit_comment(self):
  390. sql = "SELECT * FROM birth_names -- SOME COMMENT"
  391. parsed = sql_parse.ParsedQuery(sql)
  392. newsql = parsed.get_query_with_new_limit(1000)
  393. self.assertEqual(newsql, sql + "\nLIMIT 1000")
  394. def test_get_query_with_new_limit_comment_with_limit(self):
  395. sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
  396. parsed = sql_parse.ParsedQuery(sql)
  397. newsql = parsed.get_query_with_new_limit(1000)
  398. self.assertEqual(newsql, sql + "\nLIMIT 1000")
  399. def test_get_query_with_new_limit(self):
  400. sql = "SELECT * FROM birth_names LIMIT 555"
  401. parsed = sql_parse.ParsedQuery(sql)
  402. newsql = parsed.get_query_with_new_limit(1000)
  403. expected = "SELECT * FROM birth_names LIMIT 1000"
  404. self.assertEqual(newsql, expected)
  405. def test_basic_breakdown_statements(self):
  406. multi_sql = """
  407. SELECT * FROM birth_names;
  408. SELECT * FROM birth_names LIMIT 1;
  409. """
  410. parsed = sql_parse.ParsedQuery(multi_sql)
  411. statements = parsed.get_statements()
  412. self.assertEqual(len(statements), 2)
  413. expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"]
  414. self.assertEqual(statements, expected)
  415. def test_messy_breakdown_statements(self):
  416. multi_sql = """
  417. SELECT 1;\t\n\n\n \t
  418. \t\nSELECT 2;
  419. SELECT * FROM birth_names;;;
  420. SELECT * FROM birth_names LIMIT 1
  421. """
  422. parsed = sql_parse.ParsedQuery(multi_sql)
  423. statements = parsed.get_statements()
  424. self.assertEqual(len(statements), 4)
  425. expected = [
  426. "SELECT 1",
  427. "SELECT 2",
  428. "SELECT * FROM birth_names",
  429. "SELECT * FROM birth_names LIMIT 1",
  430. ]
  431. self.assertEqual(statements, expected)
  432. def test_identifier_list_with_keyword_as_alias(self):
  433. query = """
  434. WITH
  435. f AS (SELECT * FROM foo),
  436. match AS (SELECT * FROM f)
  437. SELECT * FROM match
  438. """
  439. self.assertEqual({"foo"}, self.extract_tables(query))