sql_parse.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # pylint: disable=C,R,W
  2. import logging
  3. import sqlparse
  4. from sqlparse.sql import Identifier, IdentifierList
  5. from sqlparse.tokens import Keyword, Name
  6. RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
  7. ON_KEYWORD = 'ON'
  8. PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
  9. class SupersetQuery(object):
  10. def __init__(self, sql_statement):
  11. self.sql = sql_statement
  12. self._table_names = set()
  13. self._alias_names = set()
  14. self._limit = None
  15. # TODO: multistatement support
  16. logging.info('Parsing with sqlparse statement {}'.format(self.sql))
  17. self._parsed = sqlparse.parse(self.sql)
  18. for statement in self._parsed:
  19. self.__extract_from_token(statement)
  20. self._limit = self._extract_limit_from_query(statement)
  21. self._table_names = self._table_names - self._alias_names
  22. @property
  23. def tables(self):
  24. return self._table_names
  25. @property
  26. def limit(self):
  27. return self._limit
  28. def is_select(self):
  29. return self._parsed[0].get_type() == 'SELECT'
  30. def is_explain(self):
  31. return self.sql.strip().upper().startswith('EXPLAIN')
  32. def is_readonly(self):
  33. """Pessimistic readonly, 100% sure statement won't mutate anything"""
  34. return self.is_select() or self.is_explain()
  35. def stripped(self):
  36. return self.sql.strip(' \t\n;')
  37. @staticmethod
  38. def __precedes_table_name(token_value):
  39. for keyword in PRECEDES_TABLE_NAME:
  40. if keyword in token_value:
  41. return True
  42. return False
  43. @staticmethod
  44. def __get_full_name(identifier):
  45. if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
  46. return '{}.{}'.format(identifier.tokens[0].value,
  47. identifier.tokens[2].value)
  48. return identifier.get_real_name()
  49. @staticmethod
  50. def __is_result_operation(keyword):
  51. for operation in RESULT_OPERATIONS:
  52. if operation in keyword.upper():
  53. return True
  54. return False
  55. @staticmethod
  56. def __is_identifier(token):
  57. return isinstance(token, (IdentifierList, Identifier))
  58. def __process_identifier(self, identifier):
  59. # exclude subselects
  60. if '(' not in '{}'.format(identifier):
  61. self._table_names.add(self.__get_full_name(identifier))
  62. return
  63. # store aliases
  64. if hasattr(identifier, 'get_alias'):
  65. self._alias_names.add(identifier.get_alias())
  66. if hasattr(identifier, 'tokens'):
  67. # some aliases are not parsed properly
  68. if identifier.tokens[0].ttype == Name:
  69. self._alias_names.add(identifier.tokens[0].value)
  70. self.__extract_from_token(identifier)
  71. def as_create_table(self, table_name, overwrite=False):
  72. """Reformats the query into the create table as query.
  73. Works only for the single select SQL statements, in all other cases
  74. the sql query is not modified.
  75. :param superset_query: string, sql query that will be executed
  76. :param table_name: string, will contain the results of the
  77. query execution
  78. :param overwrite, boolean, table table_name will be dropped if true
  79. :return: string, create table as query
  80. """
  81. exec_sql = ''
  82. sql = self.stripped()
  83. if overwrite:
  84. exec_sql = 'DROP TABLE IF EXISTS {table_name};\n'
  85. exec_sql += 'CREATE TABLE {table_name} AS \n{sql}'
  86. return exec_sql.format(**locals())
  87. def __extract_from_token(self, token):
  88. if not hasattr(token, 'tokens'):
  89. return
  90. table_name_preceding_token = False
  91. for item in token.tokens:
  92. if item.is_group and not self.__is_identifier(item):
  93. self.__extract_from_token(item)
  94. if item.ttype in Keyword:
  95. if self.__precedes_table_name(item.value.upper()):
  96. table_name_preceding_token = True
  97. continue
  98. if not table_name_preceding_token:
  99. continue
  100. if item.ttype in Keyword or item.value == ',':
  101. if (self.__is_result_operation(item.value) or
  102. item.value.upper() == ON_KEYWORD):
  103. table_name_preceding_token = False
  104. continue
  105. # FROM clause is over
  106. break
  107. if isinstance(item, Identifier):
  108. self.__process_identifier(item)
  109. if isinstance(item, IdentifierList):
  110. for token in item.tokens:
  111. if self.__is_identifier(token):
  112. self.__process_identifier(token)
  113. def _get_limit_from_token(self, token):
  114. if token.ttype == sqlparse.tokens.Literal.Number.Integer:
  115. return int(token.value)
  116. elif token.is_group:
  117. return int(token.get_token_at_offset(1).value)
  118. def _extract_limit_from_query(self, statement):
  119. limit_token = None
  120. for pos, item in enumerate(statement.tokens):
  121. if item.ttype in Keyword and item.value.lower() == 'limit':
  122. limit_token = statement.tokens[pos + 2]
  123. return self._get_limit_from_token(limit_token)
  124. def get_query_with_new_limit(self, new_limit):
  125. """returns the query with the specified limit"""
  126. """does not change the underlying query"""
  127. if not self._limit:
  128. return self.sql + ' LIMIT ' + str(new_limit)
  129. limit_pos = None
  130. tokens = self._parsed[0].tokens
  131. # Add all items to before_str until there is a limit
  132. for pos, item in enumerate(tokens):
  133. if item.ttype in Keyword and item.value.lower() == 'limit':
  134. limit_pos = pos
  135. break
  136. limit = tokens[limit_pos + 2]
  137. if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
  138. tokens[limit_pos + 2].value = new_limit
  139. elif limit.is_group:
  140. tokens[limit_pos + 2].value = (
  141. '{}, {}'.format(next(limit.get_identifiers()), new_limit)
  142. )
  143. str_res = ''
  144. for i in tokens:
  145. str_res += str(i.value)
  146. return str_res