123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507 |
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- import unittest
- from superset import sql_parse
class SupersetTestCase(unittest.TestCase):
    """Tests for ``sql_parse.ParsedQuery``.

    Most tests check which tables ``ParsedQuery.tables`` reports for a given
    SQL string; the remainder cover statement classification
    (``is_select`` / ``is_explain`` / ``is_readonly``), LIMIT rewriting
    (``get_query_with_new_limit``) and statement splitting
    (``get_statements``).
    """

    # Shared by the subquery-join tests; ``{join_type}`` is the only
    # placeholder, so ``str.format`` is safe on this text.
    _SUBQUERY_JOIN_SQL = """
        SELECT a.date, b.name FROM
            left_table a
            {join_type} (
                SELECT
                    CAST((b.year) as VARCHAR) date,
                    name
                FROM right_table
            ) b
            ON a.date = b.date
    """

    def extract_tables(self, query):
        """Return the ``tables`` attribute of ``query`` parsed as a ``ParsedQuery``."""
        sq = sql_parse.ParsedQuery(query)
        return sq.tables

    def _assert_tables(self, expected, *queries):
        """Assert that each query in ``queries`` yields exactly ``expected``.

        Every query runs in its own ``subTest`` so that one failing query
        neither hides the remaining ones nor loses the SQL that failed.
        """
        for query in queries:
            with self.subTest(query=query):
                self.assertEqual(expected, self.extract_tables(query))

    def test_simple_select(self):
        """Plain SELECTs: aliases, quoting, schema/cluster prefixes, bad names."""
        self._assert_tables(
            {"tbname"},
            "SELECT * FROM tbname",
            "SELECT * FROM tbname foo",
            "SELECT * FROM tbname AS foo",
        )

        # underscores
        self._assert_tables({"tb_name"}, "SELECT * FROM tb_name")

        # quotes
        self._assert_tables({"tbname"}, 'SELECT * FROM "tbname"')

        # unicode encoding
        self._assert_tables(
            {"tb_name"}, 'SELECT * FROM "tb_name" WHERE city = "Lübeck"'
        )

        # schema
        self._assert_tables(
            {"schemaname.tbname"},
            "SELECT * FROM schemaname.tbname",
            'SELECT * FROM "schemaname"."tbname"',
            "SELECT * FROM schemaname.tbname foo",
            "SELECT * FROM schemaname.tbname AS foo",
        )

        # cluster
        self._assert_tables(
            {"clustername.schemaname.tbname"},
            "SELECT * FROM clustername.schemaname.tbname",
        )

        # Ill-defined cluster/schema/table.
        self._assert_tables(
            set(),
            "SELECT * FROM schemaname.",
            "SELECT * FROM clustername.schemaname.",
            "SELECT * FROM clustername..",
            "SELECT * FROM clustername..tbname",
        )

        # explicit column lists and comma-separated tables
        self._assert_tables({"tb_name"}, "SELECT field1, field2 FROM tb_name")
        self._assert_tables({"t1", "t2"}, "SELECT t1.f1, t2.f2 FROM t1, t2")

    def test_select_named_table(self):
        """An aliased table followed by LIMIT is reported by its real name."""
        self._assert_tables(
            {"left_table"}, "SELECT a.date, a.field FROM left_table a LIMIT 10"
        )

    def test_reverse_select(self):
        """A statement written FROM-first still yields its table."""
        self._assert_tables({"t1"}, "FROM t1 SELECT field")

    def test_subselect(self):
        """Tables inside FROM-clause and nested WHERE-clause subqueries."""
        query = """
        SELECT sub.*
        FROM (
            SELECT *
            FROM s1.t1
            WHERE day_of_week = 'Friday'
        ) sub, s2.t2
        WHERE sub.resolution = 'NONE'
        """
        self._assert_tables({"s1.t1", "s2.t2"}, query)

        query = """
        SELECT sub.*
        FROM (
            SELECT *
            FROM s1.t1
            WHERE day_of_week = 'Friday'
        ) sub
        WHERE sub.resolution = 'NONE'
        """
        self._assert_tables({"s1.t1"}, query)

        query = """
        SELECT * FROM t1
        WHERE s11 > ANY
            (SELECT COUNT(*) /* no hint */ FROM t2
            WHERE NOT EXISTS
                (SELECT * FROM t3
                WHERE ROW(5*t2.s1,77)=
                    (SELECT 50,11*s1 FROM t4)));
        """
        self._assert_tables({"t1", "t2", "t3", "t4"}, query)

    def test_select_in_expression(self):
        """A scalar subquery in the select list contributes its table."""
        self._assert_tables(
            {"t1", "t2"}, "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
        )

    def test_union(self):
        """Both sides of UNION / UNION ALL / INTERSECT ALL are reported."""
        self._assert_tables(
            {"t1", "t2"},
            "SELECT * FROM t1 UNION SELECT * FROM t2",
            "SELECT * FROM t1 UNION ALL SELECT * FROM t2",
            "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2",
        )

    def test_select_from_values(self):
        """Selecting from a VALUES list reports no tables."""
        query = "SELECT * FROM VALUES (13, 42)"
        self.assertFalse(self.extract_tables(query))

    def test_select_array(self):
        """An ARRAY literal in the select list does not disturb extraction."""
        query = """
        SELECT ARRAY[1, 2, 3] AS my_array
        FROM t1 LIMIT 10
        """
        self._assert_tables({"t1"}, query)

    def test_select_if(self):
        """Function calls and array indexing in the select list are ignored."""
        query = """
        SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
        FROM t1 LIMIT 10
        """
        self._assert_tables({"t1"}, query)

    # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
    def test_show_tables(self):
        """SHOW TABLES currently reports the name that follows FROM."""
        query = "SHOW TABLES FROM s1 like '%order%'"
        # TODO: figure out what the code should do here
        self._assert_tables({"s1"}, query)

    # SHOW COLUMNS (FROM | IN) qualifiedName
    def test_show_columns(self):
        """SHOW COLUMNS reports the table it inspects."""
        self._assert_tables({"t1"}, "SHOW COLUMNS FROM t1")

    def test_where_subquery(self):
        """Subqueries compared with ``=``, ``IN`` or ``EXISTS`` in WHERE."""
        equals_query = """
        SELECT name
        FROM t1
        WHERE regionkey = (SELECT max(regionkey) FROM t2)
        """
        in_query = """
        SELECT name
        FROM t1
        WHERE regionkey IN (SELECT regionkey FROM t2)
        """
        exists_query = """
        SELECT name
        FROM t1
        WHERE regionkey EXISTS (SELECT regionkey FROM t2)
        """
        self._assert_tables({"t1", "t2"}, equals_query, in_query, exists_query)

    # DESCRIBE | DESC qualifiedName
    def test_describe(self):
        """DESCRIBE reports the described table."""
        self._assert_tables({"t1"}, "DESCRIBE t1")

    # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
    # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
    def test_show_partitions(self):
        """SHOW PARTITIONS with WHERE / ORDER BY reports the table."""
        query = """
        SHOW PARTITIONS FROM orders
        WHERE ds >= '2013-01-01' ORDER BY ds DESC;
        """
        self._assert_tables({"orders"}, query)

    def test_join(self):
        """Plain joins and joins against a subquery, for several join types."""
        self._assert_tables(
            {"t1", "t2"}, "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
        )

        # subquery + join
        join_types = ("JOIN", "LEFT INNER JOIN", "RIGHT OUTER JOIN", "FULL OUTER JOIN")
        join_queries = [
            self._SUBQUERY_JOIN_SQL.format(join_type=join_type)
            for join_type in join_types
        ]
        self._assert_tables({"left_table", "right_table"}, *join_queries)

    @unittest.skip("TODO: add SEMI join support, SQL Parse does not handle it.")
    def test_semi_join(self):
        """LEFT SEMI JOIN against a subquery (skipped until it is supported)."""
        self._assert_tables(
            {"left_table", "right_table"},
            self._SUBQUERY_JOIN_SQL.format(join_type="LEFT SEMI JOIN"),
        )

    def test_combinations(self):
        """Unions, joins and nested subqueries mixed in a single statement."""
        query = """
        SELECT * FROM t1
        WHERE s11 > ANY
            (SELECT * FROM t1 UNION ALL SELECT * FROM (
                SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a) tmp_join
            WHERE NOT EXISTS
                (SELECT * FROM t3
                WHERE ROW(5*t3.s1,77)=
                    (SELECT 50,11*s1 FROM t4)));
        """
        self._assert_tables({"t1", "t3", "t4", "t6"}, query)

        query = """
        SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
            AS S1) AS S2) AS S3;
        """
        self._assert_tables({"EmployeeS"}, query)

    def test_with(self):
        """CTE names are not reported; only the tables they read from are."""
        query = """
        WITH
            x AS (SELECT a FROM t1),
            y AS (SELECT a AS b FROM t2),
            z AS (SELECT b AS c FROM t3)
        SELECT c FROM z;
        """
        self._assert_tables({"t1", "t2", "t3"}, query)

        query = """
        WITH
            x AS (SELECT a FROM t1),
            y AS (SELECT a AS b FROM x),
            z AS (SELECT b AS c FROM y)
        SELECT c FROM z;
        """
        self._assert_tables({"t1"}, query)

    def test_reusing_aliases(self):
        """A CTE that reads from a later-defined CTE is still not a table."""
        query = """
        with q1 as ( select key from q2 where key = '5'),
        q2 as ( select key from src where key = '5')
        select * from (select key from q1) a;
        """
        self._assert_tables({"src"}, query)

    def test_multistatement(self):
        """Tables from every statement are collected, with or without a final ';'."""
        self._assert_tables(
            {"t1", "t2"},
            "SELECT * FROM t1; SELECT * FROM t2",
            "SELECT * FROM t1; SELECT * FROM t2;",
        )

    def test_update_not_select(self):
        """UPDATE is neither a select nor read-only."""
        sql = sql_parse.ParsedQuery("UPDATE t1 SET col1 = NULL")
        self.assertEqual(False, sql.is_select())
        self.assertEqual(False, sql.is_readonly())

    def test_explain(self):
        """EXPLAIN is read-only and an explain, but not a select."""
        sql = sql_parse.ParsedQuery("EXPLAIN SELECT 1")
        self.assertEqual(True, sql.is_explain())
        self.assertEqual(False, sql.is_select())
        self.assertEqual(True, sql.is_readonly())

    def test_complex_extract_tables(self):
        """Aggregating outer query over a joined, filtered subquery."""
        query = """SELECT sum(m_examples) AS "sum__m_example"
            FROM
              (SELECT COUNT(DISTINCT id_userid) AS m_examples,
                      some_more_info
               FROM my_b_table b
               JOIN my_t_table t ON b.ds=t.ds
               JOIN my_l_table l ON b.uid=l.uid
               WHERE b.rid IN
                   (SELECT other_col
                    FROM inner_table)
                 AND l.bla IN ('x', 'y')
               GROUP BY 2
               ORDER BY 2 ASC) AS "meh"
            ORDER BY "sum__m_example" DESC
            LIMIT 10;"""
        self._assert_tables(
            {"my_l_table", "my_b_table", "my_t_table", "inner_table"}, query
        )

    def test_complex_extract_tables2(self):
        """Comma-separated FROM list where every table carries an alias."""
        query = """SELECT *
            FROM table_a AS a, table_b AS b, table_c as c
            WHERE a.id = b.id and b.id = c.id"""
        self._assert_tables({"table_a", "table_b", "table_c"}, query)

    def test_mixed_from_clause(self):
        """FROM list mixing aliased tables with an aliased subquery."""
        query = """SELECT *
            FROM table_a AS a, (select * from table_b) AS b, table_c as c
            WHERE a.id = b.id and b.id = c.id"""
        self._assert_tables({"table_a", "table_b", "table_c"}, query)

    def test_nested_selects(self):
        """Selects buried inside nested function-call arguments are found."""
        schema_query = """
        select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
        from INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
        """
        column_query = """
        select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
        from INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
        """
        self._assert_tables(
            {"INFORMATION_SCHEMA.COLUMNS"}, schema_query, column_query
        )

    def test_complex_extract_tables3(self):
        """WITH clause nested in a FROM subquery, followed by chained UNION ALLs."""
        query = """SELECT somecol AS somecol
            FROM
              (WITH bla AS
                 (SELECT col_a
                  FROM a
                  WHERE 1=1
                    AND column_of_choice NOT IN
                      ( SELECT interesting_col
                       FROM b ) ),
                    rb AS
                 ( SELECT yet_another_column
                  FROM
                    ( SELECT a
                     FROM c
                     GROUP BY the_other_col ) not_table
                  LEFT JOIN bla foo ON foo.prop = not_table.bad_col0
                  WHERE 1=1
                  GROUP BY not_table.bad_col1 ,
                           not_table.bad_col2 ,
                  ORDER BY not_table.bad_col_3 DESC ,
                           not_table.bad_col4 ,
                           not_table.bad_col5) SELECT random_col
               FROM d
               WHERE 1=1
               UNION ALL SELECT even_more_cols
               FROM e
               WHERE 1=1
               UNION ALL SELECT lets_go_deeper
               FROM f
               WHERE 1=1
            WHERE 2=2
            GROUP BY last_col
            LIMIT 50000;"""
        self._assert_tables({"a", "b", "c", "d", "e", "f"}, query)

    def test_complex_cte_with_prefix(self):
        """A CTE declared with an explicit column list is not reported as a table."""
        query = """
        WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
        AS (
            SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
            FROM SalesOrderHeader
            WHERE SalesPersonID IS NOT NULL
        )
        SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
        FROM CTE__test
        GROUP BY SalesYear, SalesPersonID
        ORDER BY SalesPersonID, SalesYear;
        """
        self._assert_tables({"SalesOrderHeader"}, query)

    def test_get_query_with_new_limit_comment(self):
        """With a trailing ``--`` comment the LIMIT is appended on a new line."""
        sql = "SELECT * FROM birth_names -- SOME COMMENT"
        parsed = sql_parse.ParsedQuery(sql)
        newsql = parsed.get_query_with_new_limit(1000)
        self.assertEqual(newsql, sql + "\nLIMIT 1000")

    def test_get_query_with_new_limit_comment_with_limit(self):
        """The word LIMIT inside a comment is not treated as a real LIMIT."""
        sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
        parsed = sql_parse.ParsedQuery(sql)
        newsql = parsed.get_query_with_new_limit(1000)
        self.assertEqual(newsql, sql + "\nLIMIT 1000")

    def test_get_query_with_new_limit(self):
        """An existing LIMIT value is replaced in place."""
        sql = "SELECT * FROM birth_names LIMIT 555"
        parsed = sql_parse.ParsedQuery(sql)
        newsql = parsed.get_query_with_new_limit(1000)
        expected = "SELECT * FROM birth_names LIMIT 1000"
        self.assertEqual(newsql, expected)

    def test_basic_breakdown_statements(self):
        """Two ';'-terminated statements are split and returned without the ';'."""
        multi_sql = """
        SELECT * FROM birth_names;
        SELECT * FROM birth_names LIMIT 1;
        """
        parsed = sql_parse.ParsedQuery(multi_sql)
        statements = parsed.get_statements()
        self.assertEqual(len(statements), 2)
        expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"]
        self.assertEqual(statements, expected)

    def test_messy_breakdown_statements(self):
        """Stray whitespace and repeated ';' do not produce extra statements."""
        multi_sql = """
        SELECT 1;\t\n\n\n \t
        \t\nSELECT 2;
        SELECT * FROM birth_names;;;
        SELECT * FROM birth_names LIMIT 1
        """
        parsed = sql_parse.ParsedQuery(multi_sql)
        statements = parsed.get_statements()
        self.assertEqual(len(statements), 4)
        expected = [
            "SELECT 1",
            "SELECT 2",
            "SELECT * FROM birth_names",
            "SELECT * FROM birth_names LIMIT 1",
        ]
        self.assertEqual(statements, expected)

    def test_identifier_list_with_keyword_as_alias(self):
        """A CTE named with a keyword (``match``) is still recognised as a CTE."""
        query = """
        WITH
            f AS (SELECT * FROM foo),
            match AS (SELECT * FROM f)
        SELECT * FROM match
        """
        self._assert_tables({"foo"}, query)
|