presto_db.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. import logging
  18. import time
  19. from contextlib import closing
  20. from typing import Any, Dict, List, Optional
  21. from flask import g
  22. from superset import app, security_manager
  23. from superset.sql_parse import ParsedQuery
  24. from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
  25. from superset.utils.core import QuerySource
  26. MAX_ERROR_ROWS = 10
  27. config = app.config
  28. logger = logging.getLogger(__name__)
  29. class PrestoSQLValidationError(Exception):
  30. """Error in the process of asking Presto to validate SQL querytext"""
  31. class PrestoDBSQLValidator(BaseSQLValidator):
  32. """Validate SQL queries using Presto's built-in EXPLAIN subtype"""
  33. name = "PrestoDBSQLValidator"
  34. @classmethod
  35. def validate_statement(
  36. cls, statement, database, cursor, user_name
  37. ) -> Optional[SQLValidationAnnotation]:
  38. # pylint: disable=too-many-locals
  39. db_engine_spec = database.db_engine_spec
  40. parsed_query = ParsedQuery(statement)
  41. sql = parsed_query.stripped()
  42. # Hook to allow environment-specific mutation (usually comments) to the SQL
  43. # pylint: disable=invalid-name
  44. SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
  45. if SQL_QUERY_MUTATOR:
  46. sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
  47. # Transform the final statement to an explain call before sending it on
  48. # to presto to validate
  49. sql = f"EXPLAIN (TYPE VALIDATE) {sql}"
  50. # Invoke the query against presto. NB this deliberately doesn't use the
  51. # engine spec's handle_cursor implementation since we don't record
  52. # these EXPLAIN queries done in validation as proper Query objects
  53. # in the superset ORM.
  54. from pyhive.exc import DatabaseError
  55. try:
  56. db_engine_spec.execute(cursor, sql)
  57. polled = cursor.poll()
  58. while polled:
  59. logger.info("polling presto for validation progress")
  60. stats = polled.get("stats", {})
  61. if stats:
  62. state = stats.get("state")
  63. if state == "FINISHED":
  64. break
  65. time.sleep(0.2)
  66. polled = cursor.poll()
  67. db_engine_spec.fetch_data(cursor, MAX_ERROR_ROWS)
  68. return None
  69. except DatabaseError as db_error:
  70. # The pyhive presto client yields EXPLAIN (TYPE VALIDATE) responses
  71. # as though they were normal queries. In other words, it doesn't
  72. # know that errors here are not exceptional. To map this back to
  73. # ordinary control flow, we have to trap the category of exception
  74. # raised by the underlying client, match the exception arguments
  75. # pyhive provides against the shape of dictionary for a presto query
  76. # invalid error, and restructure that error as an annotation we can
  77. # return up.
  78. # If the first element in the DatabaseError is not a dictionary, but
  79. # is a string, return that message.
  80. if db_error.args and isinstance(db_error.args[0], str):
  81. raise PrestoSQLValidationError(db_error.args[0]) from db_error
  82. # Confirm the first element in the DatabaseError constructor is a
  83. # dictionary with error information. This is currently provided by
  84. # the pyhive client, but may break if their interface changes when
  85. # we update at some point in the future.
  86. if not db_error.args or not isinstance(db_error.args[0], dict):
  87. raise PrestoSQLValidationError(
  88. "The pyhive presto client returned an unhandled " "database error."
  89. ) from db_error
  90. error_args: Dict[str, Any] = db_error.args[0]
  91. # Confirm the two fields we need to be able to present an annotation
  92. # are present in the error response -- a message, and a location.
  93. if "message" not in error_args:
  94. raise PrestoSQLValidationError(
  95. "The pyhive presto client did not report an error message"
  96. ) from db_error
  97. if "errorLocation" not in error_args:
  98. # Pylint is confused about the type of error_args, despite the hints
  99. # and checks above.
  100. # pylint: disable=invalid-sequence-index
  101. message = error_args["message"] + "\n(Error location unknown)"
  102. # If we have a message but no error location, return the message and
  103. # set the location as the beginning.
  104. return SQLValidationAnnotation(
  105. message=message, line_number=1, start_column=1, end_column=1
  106. )
  107. # pylint: disable=invalid-sequence-index
  108. message = error_args["message"]
  109. err_loc = error_args["errorLocation"]
  110. line_number = err_loc.get("lineNumber", None)
  111. start_column = err_loc.get("columnNumber", None)
  112. end_column = err_loc.get("columnNumber", None)
  113. return SQLValidationAnnotation(
  114. message=message,
  115. line_number=line_number,
  116. start_column=start_column,
  117. end_column=end_column,
  118. )
  119. except Exception as e:
  120. logger.exception(f"Unexpected error running validation query: {e}")
  121. raise e
  122. @classmethod
  123. def validate(
  124. cls, sql: str, schema: str, database: Any
  125. ) -> List[SQLValidationAnnotation]:
  126. """
  127. Presto supports query-validation queries by running them with a
  128. prepended explain.
  129. For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
  130. VALIDATE) SELECT 1 FROM default.mytable.
  131. """
  132. user_name = g.user.username if g.user else None
  133. parsed_query = ParsedQuery(sql)
  134. statements = parsed_query.get_statements()
  135. logger.info(f"Validating {len(statements)} statement(s)")
  136. engine = database.get_sqla_engine(
  137. schema=schema,
  138. nullpool=True,
  139. user_name=user_name,
  140. source=QuerySource.SQL_LAB,
  141. )
  142. # Sharing a single connection and cursor across the
  143. # execution of all statements (if many)
  144. annotations: List[SQLValidationAnnotation] = []
  145. with closing(engine.raw_connection()) as conn:
  146. with closing(conn.cursor()) as cursor:
  147. for statement in parsed_query.get_statements():
  148. annotation = cls.validate_statement(
  149. statement, database, cursor, user_name
  150. )
  151. if annotation:
  152. annotations.append(annotation)
  153. logger.debug(f"Validation found {len(annotations)} error(s)")
  154. return annotations