# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import re
import textwrap
import time
from collections import defaultdict, deque
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING
from urllib import parse

import pandas as pd
import simplejson as json
from sqlalchemy import Column, literal_column
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select

from superset import app, cache, is_feature_enabled, security_manager
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils

if TYPE_CHECKING:
    # prevent circular imports
    from superset.models.core import Database  # pylint: disable=unused-import

QueryStatus = utils.QueryStatus
config = app.config
logger = logging.getLogger(__name__)


def get_children(column: Dict[str, str]) -> List[Dict[str, str]]:
    """
    Get the children of a complex Presto type (row or array).

    For arrays, we return a single list with the base type:

    >>> get_children(dict(name="a", type="ARRAY(BIGINT)"))
    [{"name": "a", "type": "BIGINT"}]

    For rows, we return a list of the columns:

    >>> get_children(dict(name="a", type="ROW(BIGINT,FOO VARCHAR)"))
    [{'name': 'a._col0', 'type': 'BIGINT'}, {'name': 'a.foo', 'type': 'VARCHAR'}]

    :param column: dictionary representing a Presto column
    :return: list of dictionaries representing children columns
    """
    pattern = re.compile(r"(?P<type>\w+)\((?P<children>.*)\)")
    match = pattern.match(column["type"])
    if not match:
        raise Exception(f"Unable to parse column type {column['type']}")

    group = match.groupdict()
    type_ = group["type"].upper()
    children_type = group["children"]
    if type_ == "ARRAY":
        return [{"name": column["name"], "type": children_type}]
    elif type_ == "ROW":
        nameless_columns = 0
        columns = []
        for child in utils.split(children_type, ","):
            parts = list(utils.split(child.strip(), " "))
            if len(parts) == 2:
                name, type_ = parts
                name = name.strip('"')
            else:
                name = f"_col{nameless_columns}"
                type_ = parts[0]
                nameless_columns += 1
            columns.append({"name": f"{column['name']}.{name.lower()}", "type": type_})
        return columns
    else:
        raise Exception(f"Unknown type {type_}!")


class PrestoEngineSpec(BaseEngineSpec):
    engine = "presto"

    _time_grain_functions = {
        None: "{col}",
        "PT1S": "date_trunc('second', CAST({col} AS TIMESTAMP))",
        "PT1M": "date_trunc('minute', CAST({col} AS TIMESTAMP))",
        "PT1H": "date_trunc('hour', CAST({col} AS TIMESTAMP))",
        "P1D": "date_trunc('day', CAST({col} AS TIMESTAMP))",
        "P1W": "date_trunc('week', CAST({col} AS TIMESTAMP))",
        "P1M": "date_trunc('month', CAST({col} AS TIMESTAMP))",
        "P0.25Y": "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
        "P1Y": "date_trunc('year', CAST({col} AS TIMESTAMP))",
        "P1W/1970-01-03T00:00:00Z": "date_add('day', 5, date_trunc('week', "
        "date_add('day', 1, CAST({col} AS TIMESTAMP))))",
        "1969-12-28T00:00:00Z/P1W": "date_add('day', -1, date_trunc('week', "
        "date_add('day', 1, CAST({col} AS TIMESTAMP))))",
    }

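    # Editorial note (example added for clarity; not in the original source):
    # each value above is a Python format template keyed by an ISO 8601
    # duration, and Superset substitutes the time column into ``{col}``, e.g.
    #
    #   _time_grain_functions["P1D"].format(col="ts")
    #   # -> "date_trunc('day', CAST(ts AS TIMESTAMP))"
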
    @classmethod
    def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool:
        return version is not None and StrictVersion(version) >= StrictVersion("0.319")

    @classmethod
    def get_table_names(
        cls, database: "Database", inspector: Inspector, schema: Optional[str]
    ) -> List[str]:
        tables = super().get_table_names(database, inspector, schema)
        if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
            return tables

        views = set(cls.get_view_names(database, inspector, schema))
        actual_tables = set(tables) - views
        return list(actual_tables)

    @classmethod
    def get_view_names(
        cls, database: "Database", inspector: Inspector, schema: Optional[str]
    ) -> List[str]:
        """Returns view names

        By default (when the ``PRESTO_SPLIT_VIEWS_FROM_TABLES`` feature flag is
        off) this returns an empty list, since PyHive's get_table_names()
        returns all table and view names combined and get_view_names() is not
        implemented in sqlalchemy_presto.py
        https://github.com/dropbox/PyHive/blob/e25fc8440a0686bbb7a5db5de7cb1a77bdb4167a/pyhive/sqlalchemy_presto.py
        When the flag is on, views are fetched from information_schema.views.
        """
        if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
            return []

        if schema:
            sql = (
                "SELECT table_name FROM information_schema.views "
                "WHERE table_schema=%(schema)s"
            )
            params = {"schema": schema}
        else:
            sql = "SELECT table_name FROM information_schema.views"
            params = {}

        engine = cls.get_engine(database, schema=schema)
        with closing(engine.raw_connection()) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(sql, params)
                results = cursor.fetchall()

        return [row[0] for row in results]

    @classmethod
    def _create_column_info(cls, name: str, data_type: str) -> dict:
        """
        Create column info object
        :param name: column name
        :param data_type: column data type
        :return: column info object
        """
        return {"name": name, "type": f"{data_type}"}

    @classmethod
    def _get_full_name(cls, names: List[Tuple[str, str]]) -> str:
        """
        Get the full column name
        :param names: list of all individual column names
        :return: full column name
        """
        return ".".join(column[0] for column in names if column[0])

    @classmethod
    def _has_nested_data_types(cls, component_type: str) -> bool:
        """
        Check if a string contains nested data types. Nesting is detected by
        unquoted whitespace (separating a field name from its type) or by
        unquoted commas (separating multiple fields)
        :param component_type: data type
        :return: boolean
        """
        comma_regex = r",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
        white_space_regex = r"\s(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
        return (
            re.search(comma_regex, component_type) is not None
            or re.search(white_space_regex, component_type) is not None
        )

    @classmethod
    def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]:
        """
        Split data type based on the given delimiter. Do not split the string if
        the delimiter is enclosed in quotes
        :param data_type: data type
        :param delimiter: string separator (i.e. open parenthesis, closed
            parenthesis, comma, whitespace)
        :return: list of strings after breaking it by the delimiter
        """
        return re.split(
            r"{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)".format(delimiter), data_type
        )

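    # Editorial note (example added for clarity; not in the original source):
    # the lookahead keeps quoted identifiers intact, e.g.
    #
    #   cls._split_data_type('row("a,b" varchar, c int)', r",")
    #   # -> ['row("a,b" varchar', ' c int)']
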
    @classmethod
    def _parse_structural_column(  # pylint: disable=too-many-locals,too-many-branches
        cls, parent_column_name: str, parent_data_type: str, result: List[dict]
    ) -> None:
        """
        Parse a row or array column
        :param result: list tracking the results
        """
        formatted_parent_column_name = parent_column_name
        # Quote the column name if there is a space
        if " " in parent_column_name:
            formatted_parent_column_name = f'"{parent_column_name}"'
        full_data_type = f"{formatted_parent_column_name} {parent_data_type}"
        original_result_len = len(result)
        # split on open parenthesis ( to get the structural
        # data type and its component types
        data_types = cls._split_data_type(full_data_type, r"\(")
        stack: List[Tuple[str, str]] = []
        for data_type in data_types:
            # split on closed parenthesis ) to track which component
            # types belong to what structural data type
            inner_types = cls._split_data_type(data_type, r"\)")
            for inner_type in inner_types:
                # We have finished parsing multiple structural data types
                if not inner_type and stack:
                    stack.pop()
                elif cls._has_nested_data_types(inner_type):
                    # split on comma , to get individual data types
                    single_fields = cls._split_data_type(inner_type, ",")
                    for single_field in single_fields:
                        single_field = single_field.strip()
                        # If the component type starts with a comma, the first
                        # single field will be an empty string. Disregard it.
                        if not single_field:
                            continue
                        # split on whitespace to get field name and data type
                        field_info = cls._split_data_type(single_field, r"\s")
                        # check if there is a structural data type within
                        # the overall structural data type
                        if field_info[1] == "array" or field_info[1] == "row":
                            stack.append((field_info[0], field_info[1]))
                            full_parent_path = cls._get_full_name(stack)
                            result.append(
                                cls._create_column_info(
                                    full_parent_path, presto_type_map[field_info[1]]()
                                )
                            )
                        else:  # otherwise this field is a basic data type
                            full_parent_path = cls._get_full_name(stack)
                            column_name = "{}.{}".format(
                                full_parent_path, field_info[0]
                            )
                            result.append(
                                cls._create_column_info(
                                    column_name, presto_type_map[field_info[1]]()
                                )
                            )
                    # If the component type ends with a structural data type, do
                    # not pop the stack: we have run across a structural data
                    # type within the overall structural data type. Otherwise,
                    # we have completely parsed through the entire structural
                    # data type and can move on.
                    if not (
                        inner_type.endswith("array") or inner_type.endswith("row")
                    ):
                        stack.pop()
                # We have an array of row objects (i.e. array(row(...)))
                elif inner_type == "array" or inner_type == "row":
                    # Push a dummy object to represent the structural data type
                    stack.append(("", inner_type))
                # We have an array of a basic data type (i.e. array(varchar))
                elif stack:
                    # Because it is an array of a basic data type, we have
                    # finished parsing the structural data type and can move on.
                    stack.pop()
        # Unquote the column name if necessary
        if formatted_parent_column_name != parent_column_name:
            for index in range(original_result_len, len(result)):
                result[index]["name"] = result[index]["name"].replace(
                    formatted_parent_column_name, parent_column_name
                )

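    # Editorial note (example added for clarity; not in the original source):
    # for a parent column "a" of type row(b varchar, c array(int)), the method
    # appends roughly
    #   {"name": "a", "type": ROW}, {"name": "a.b", "type": VARCHAR},
    #   {"name": "a.c", "type": ARRAY}
    # to ``result`` (types instantiated from presto_type_map).
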
    @classmethod
    def _show_columns(
        cls, inspector: Inspector, table_name: str, schema: Optional[str]
    ) -> List[RowProxy]:
        """
        Show presto column names
        :param inspector: object that performs database schema inspection
        :param table_name: table name
        :param schema: schema name
        :return: list of column objects
        """
        quote = inspector.engine.dialect.identifier_preparer.quote_identifier
        full_table = quote(table_name)
        if schema:
            full_table = "{}.{}".format(quote(schema), full_table)
        columns = inspector.bind.execute("SHOW COLUMNS FROM {}".format(full_table))
        return columns

    @classmethod
    def get_columns(
        cls, inspector: Inspector, table_name: str, schema: Optional[str]
    ) -> List[Dict[str, Any]]:
        """
        Get columns from a Presto data source. This includes handling row and
        array data types
        :param inspector: object that performs database schema inspection
        :param table_name: table name
        :param schema: schema name
        :return: a list of results that contain column info
            (i.e. column name and data type)
        """
        columns = cls._show_columns(inspector, table_name, schema)
        result: List[dict] = []
        for column in columns:
            try:
                # parse column if it is a row or array
                if is_feature_enabled("PRESTO_EXPAND_DATA") and (
                    "array" in column.Type or "row" in column.Type
                ):
                    structural_column_index = len(result)
                    cls._parse_structural_column(column.Column, column.Type, result)
                    result[structural_column_index]["nullable"] = getattr(
                        column, "Null", True
                    )
                    result[structural_column_index]["default"] = None
                    continue
                else:  # otherwise column is a basic data type
                    column_type = presto_type_map[column.Type]()
            except KeyError:
                logger.info(
                    "Did not recognize type {} of column {}".format(  # pylint: disable=logging-format-interpolation
                        column.Type, column.Column
                    )
                )
                column_type = "OTHER"
            column_info = cls._create_column_info(column.Column, column_type)
            column_info["nullable"] = getattr(column, "Null", True)
            column_info["default"] = None
            result.append(column_info)
        return result

    @classmethod
    def _is_column_name_quoted(cls, column_name: str) -> bool:
        """
        Check if column name is in quotes
        :param column_name: column name
        :return: boolean
        """
        return column_name.startswith('"') and column_name.endswith('"')

    @classmethod
    def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
        """
        Format column clauses where names are in quotes and labels are specified
        :param cols: columns
        :return: column clauses
        """
        column_clauses = []
        # Column names are separated by periods. This regex will find periods in
        # a string if they are not enclosed in quotes, because a period enclosed
        # in quotes is part of a column name.
        dot_pattern = r"""\.                # split on period
                          (?=               # look ahead
                          (?:               # create non-capture group
                          [^\"]*\"[^\"]*\"  # two quotes
                          )*[^\"]*$)        # end regex"""
        dot_regex = re.compile(dot_pattern, re.VERBOSE)
        for col in cols:
            # get individual column names
            col_names = re.split(dot_regex, col["name"])
            # quote each column name if it is not already quoted
            quoted_col_name = ".".join(
                col_name if cls._is_column_name_quoted(col_name) else f'"{col_name}"'
                for col_name in col_names
            )
            # create column clause in the format "name"."name" AS "name.name"
            column_clause = literal_column(quoted_col_name).label(col["name"])
            column_clauses.append(column_clause)
        return column_clauses

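    # Editorial note (example added for clarity; not in the original source):
    # a nested column named row_col.nested becomes the clause
    #   "row_col"."nested" AS "row_col.nested"
    # so the field is addressable in SQL while keeping its dotted label.
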
    @classmethod
    def select_star(  # pylint: disable=too-many-arguments
        cls,
        database: "Database",
        table_name: str,
        engine: Engine,
        schema: Optional[str] = None,
        limit: int = 100,
        show_cols: bool = False,
        indent: bool = True,
        latest_partition: bool = True,
        cols: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """
        Include selecting properties of row objects. We cannot easily break
        arrays into rows, so render the whole array in its own row and skip
        columns that correspond to an array's contents.
        """
        cols = cols or []
        presto_cols = cols
        if is_feature_enabled("PRESTO_EXPAND_DATA") and show_cols:
            dot_regex = r"\.(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
            presto_cols = [
                col for col in presto_cols if not re.search(dot_regex, col["name"])
            ]
        return super().select_star(
            database,
            table_name,
            engine,
            schema,
            limit,
            show_cols,
            indent,
            latest_partition,
            presto_cols,
        )

    @classmethod
    def estimate_statement_cost(  # pylint: disable=too-many-locals
        cls, statement: str, database: "Database", cursor: Any, user_name: str
    ) -> Dict[str, Any]:
        """
        Run a SQL query that estimates the cost of a given statement.
        :param statement: A single SQL statement
        :param database: Database instance
        :param cursor: Cursor instance
        :param user_name: Effective username
        :return: JSON response from Presto
        """
        parsed_query = ParsedQuery(statement)
        sql = parsed_query.stripped()

        sql_query_mutator = config["SQL_QUERY_MUTATOR"]
        if sql_query_mutator:
            sql = sql_query_mutator(sql, user_name, security_manager, database)

        sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {sql}"
        cursor.execute(sql)

        # the output from Presto is a single column and a single row containing
        # JSON:
        #
        #   {
        #     ...
        #     "estimate" : {
        #       "outputRowCount" : 8.73265878E8,
        #       "outputSizeInBytes" : 3.41425774958E11,
        #       "cpuCost" : 3.41425774958E11,
        #       "maxMemory" : 0.0,
        #       "networkCost" : 3.41425774958E11
        #     }
        #   }
        result = json.loads(cursor.fetchone()[0])
        return result

    @classmethod
    def query_cost_formatter(
        cls, raw_cost: List[Dict[str, Any]]
    ) -> List[Dict[str, str]]:
        """
        Format cost estimate.
        :param raw_cost: JSON estimate from Presto
        :return: Human readable cost estimate
        """

        def humanize(value: Any, suffix: str) -> str:
            try:
                value = int(value)
            except ValueError:
                return str(value)

            prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"]
            prefix = ""
            to_next_prefix = 1000
            while value > to_next_prefix and prefixes:
                prefix = prefixes.pop(0)
                value //= to_next_prefix
            return f"{value} {prefix}{suffix}"

        cost = []
        columns = [
            ("outputRowCount", "Output count", " rows"),
            ("outputSizeInBytes", "Output size", "B"),
            ("cpuCost", "CPU cost", ""),
            ("maxMemory", "Max memory", "B"),
            ("networkCost", "Network cost", ""),
        ]
        for row in raw_cost:
            estimate: Dict[str, float] = row.get("estimate", {})
            statement_cost = {}
            for key, label, suffix in columns:
                if key in estimate:
                    statement_cost[label] = humanize(estimate[key], suffix).strip()
            cost.append(statement_cost)
        return cost

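    # Editorial note (example added for clarity; not in the original source):
    # humanize() repeatedly divides by 1000 while walking the prefix list, so
    # humanize(873265878, " rows") returns "873 M rows"; a non-numeric value
    # is returned unchanged as a string.
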
    @classmethod
    def adjust_database_uri(
        cls, uri: URL, selected_schema: Optional[str] = None
    ) -> None:
        database = uri.database
        if selected_schema and database:
            selected_schema = parse.quote(selected_schema, safe="")
            if "/" in database:
                database = database.split("/")[0] + "/" + selected_schema
            else:
                database += "/" + selected_schema
            uri.database = database

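    # Editorial note (example added for clarity; not in the original source):
    # a Presto SQLAlchemy URI stores catalog and schema as "catalog/schema",
    # so with uri.database == "hive" and selected_schema == "marts" the
    # database becomes "hive/marts", and "hive/default" becomes "hive/marts".
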
    @classmethod
    def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
        tt = target_type.upper()
        if tt == "DATE":
            return f"""from_iso8601_date('{dttm.date().isoformat()}')"""
        if tt == "TIMESTAMP":
            return f"""from_iso8601_timestamp('{dttm.isoformat(timespec="microseconds")}')"""  # pylint: disable=line-too-long
        return None

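    # Editorial note (example added for clarity; not in the original source):
    # convert_dttm("TIMESTAMP", datetime(2020, 1, 1, 12, 0)) returns
    # "from_iso8601_timestamp('2020-01-01T12:00:00.000000')", i.e. a Presto
    # expression rather than a bare string literal.
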
    @classmethod
    def epoch_to_dttm(cls) -> str:
        return "from_unixtime({col})"

    @classmethod
    def get_all_datasource_names(
        cls, database: "Database", datasource_type: str
    ) -> List[utils.DatasourceName]:
        datasource_df = database.get_df(
            "SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S "
            "ORDER BY concat(table_schema, '.', table_name)".format(
                datasource_type.upper()
            ),
            None,
        )
        datasource_names: List[utils.DatasourceName] = []
        for _unused, row in datasource_df.iterrows():
            datasource_names.append(
                utils.DatasourceName(
                    schema=row["table_schema"], table=row["table_name"]
                )
            )
        return datasource_names

    @classmethod
    def expand_data(  # pylint: disable=too-many-locals
        cls, columns: List[dict], data: List[dict]
    ) -> Tuple[List[dict], List[dict], List[dict]]:
        """
        We do not immediately display rows and arrays clearly in the data grid.
        This method separates out nested fields and data values to help clearly
        display structural columns.

        Example: ColumnA is a row(nested_obj varchar) and ColumnB is an array(int)

        Original data set = [
            {'ColumnA': ['a1'], 'ColumnB': [1, 2]},
            {'ColumnA': ['a2'], 'ColumnB': [3, 4]},
        ]

        Expanded data set = [
            {'ColumnA': ['a1'], 'ColumnA.nested_obj': 'a1', 'ColumnB': 1},
            {'ColumnA': '', 'ColumnA.nested_obj': '', 'ColumnB': 2},
            {'ColumnA': ['a2'], 'ColumnA.nested_obj': 'a2', 'ColumnB': 3},
            {'ColumnA': '', 'ColumnA.nested_obj': '', 'ColumnB': 4},
        ]

        :param columns: columns selected in the query
        :param data: original data set
        :return: list of all columns (selected columns and their nested fields),
            expanded data set, list of nested fields
        """
        if not is_feature_enabled("PRESTO_EXPAND_DATA"):
            return columns, data, []

        # process each column, unnesting ARRAY types and
        # expanding ROW types into new columns
        to_process = deque((column, 0) for column in columns)
        all_columns: List[dict] = []
        expanded_columns = []
        current_array_level = None
        while to_process:
            column, level = to_process.popleft()
            if column["name"] not in [column["name"] for column in all_columns]:
                all_columns.append(column)

            # When unnesting arrays we need to keep track of how many extra rows
            # were added, for each original row. This is necessary when we expand
            # multiple arrays, so that the arrays after the first reuse the rows
            # added by the first. Every time we change a level in the nested
            # arrays we reinitialize this.
            if level != current_array_level:
                unnested_rows: Dict[int, int] = defaultdict(int)
                current_array_level = level

            name = column["name"]

            if column["type"].startswith("ARRAY("):
                # keep processing array children; we append to the right so that
                # multiple nested arrays are processed breadth-first
                to_process.append((get_children(column)[0], level + 1))

                # unnest array objects data into new rows
                i = 0
                while i < len(data):
                    row = data[i]
                    values = row.get(name)
                    if values:
                        # how many extra rows we need to unnest the data?
                        extra_rows = len(values) - 1
                        # how many rows were already added for this row?
                        current_unnested_rows = unnested_rows[i]
                        # add any necessary rows
                        missing = extra_rows - current_unnested_rows
                        for _ in range(missing):
                            data.insert(i + current_unnested_rows + 1, {})
                            unnested_rows[i] += 1
                        # unnest array into rows
                        for j, value in enumerate(values):
                            data[i + j][name] = value
                        # skip newly unnested rows
                        i += unnested_rows[i]
                    i += 1

            if column["type"].startswith("ROW("):
                # expand columns; we append them to the left so they are added
                # immediately after the parent
                expanded = get_children(column)
                to_process.extendleft((column, level) for column in expanded)
                expanded_columns.extend(expanded)

                # expand row objects into new columns
                for row in data:
                    for value, col in zip(row.get(name) or [], expanded):
                        row[col["name"]] = value

        data = [
            {k["name"]: row.get(k["name"], "") for k in all_columns} for row in data
        ]

        return all_columns, data, expanded_columns

    @classmethod
    def extra_table_metadata(
        cls, database: "Database", table_name: str, schema_name: str
    ) -> Dict[str, Any]:
        metadata = {}

        indexes = database.get_indexes(table_name, schema_name)
        if indexes:
            cols = indexes[0].get("column_names", [])
            full_table_name = table_name
            if schema_name and "." not in table_name:
                full_table_name = "{}.{}".format(schema_name, table_name)
            pql = cls._partition_query(full_table_name, database)
            col_names, latest_parts = cls.latest_partition(
                table_name, schema_name, database, show_first=True
            )

            if not latest_parts:
                latest_parts = tuple([None] * len(col_names))  # type: ignore
            metadata["partitions"] = {
                "cols": cols,
                "latest": dict(zip(col_names, latest_parts)),  # type: ignore
                "partitionQuery": pql,
            }

        # flake8 is not matching `Optional[str]` to `Any` for some reason...
        metadata["view"] = cast(
            Any, cls.get_create_view(database, schema_name, table_name)
        )

        return metadata

    @classmethod
    def get_create_view(
        cls, database: "Database", schema: str, table: str
    ) -> Optional[str]:
        """
        Return a CREATE VIEW statement, or `None` if not a view.

        :param database: Database instance
        :param schema: Schema name
        :param table: Table (view) name
        """
        from pyhive.exc import DatabaseError

        engine = cls.get_engine(database, schema)
        with closing(engine.raw_connection()) as conn:
            with closing(conn.cursor()) as cursor:
                sql = f"SHOW CREATE VIEW {schema}.{table}"
                try:
                    cls.execute(cursor, sql)
                    polled = cursor.poll()
                    while polled:
                        time.sleep(0.2)
                        polled = cursor.poll()
                except DatabaseError:  # not a VIEW
                    return None
                rows = cls.fetch_data(cursor, 1)
                return rows[0][0]

    @classmethod
    def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
        """Updates progress information"""
        query_id = query.id
        logger.info(f"Query {query_id}: Polling the cursor for progress")
        polled = cursor.poll()
        # poll returns dict -- JSON status information or ``None``
        # if the query is done
        # https://github.com/dropbox/PyHive/blob/
        # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
        while polled:
            # Update the object and wait for the kill signal.
            stats = polled.get("stats", {})

            query = session.query(type(query)).filter_by(id=query_id).one()
            if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]:
                cursor.cancel()
                break

            if stats:
                state = stats.get("state")

                # if already finished, then stop polling
                if state == "FINISHED":
                    break

                completed_splits = float(stats.get("completedSplits"))
                total_splits = float(stats.get("totalSplits"))
                if total_splits and completed_splits:
                    progress = 100 * (completed_splits / total_splits)
                    logger.info(
                        "Query {} progress: {} / {} "  # pylint: disable=logging-format-interpolation
                        "splits".format(query_id, completed_splits, total_splits)
                    )
                    if progress > query.progress:
                        query.progress = progress
                    session.commit()
            time.sleep(1)
            logger.info(f"Query {query_id}: Polling the cursor for progress")
            polled = cursor.poll()

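    # Editorial note (example added for clarity; not in the original source):
    # progress is derived from Presto's split counters, e.g. stats of
    # {"completedSplits": 50, "totalSplits": 200} yields
    # 100 * (50.0 / 200.0) == 25.0 percent.
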
    @classmethod
    def _extract_error_message(cls, e: Exception) -> Optional[str]:
        if (
            hasattr(e, "orig")
            and type(e.orig).__name__ == "DatabaseError"  # type: ignore
            and isinstance(e.orig[0], dict)  # type: ignore
        ):
            error_dict = e.orig[0]  # type: ignore
            return "{} at {}: {}".format(
                error_dict.get("errorName"),
                error_dict.get("errorLocation"),
                error_dict.get("message"),
            )
        if type(e).__name__ == "DatabaseError" and hasattr(e, "args") and e.args:
            error_dict = e.args[0]
            return error_dict.get("message")
        return utils.error_msg_from_exception(e)

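    # Editorial note (hypothetical example added for clarity; not in the
    # original source): for a PyHive DatabaseError whose first argument is a
    # dict such as
    #   {"errorName": "SYNTAX_ERROR", "errorLocation": {...}, "message": "..."}
    # the first branch formats it as "SYNTAX_ERROR at {...}: ...".
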
    @classmethod
    def _partition_query(  # pylint: disable=too-many-arguments,too-many-locals
        cls,
        table_name: str,
        database: "Database",
        limit: int = 0,
        order_by: Optional[List[Tuple[str, bool]]] = None,
        filters: Optional[Dict[Any, Any]] = None,
    ) -> str:
        """Returns a partition query

        :param table_name: the name of the table to get partitions from
        :type table_name: str
        :param limit: the number of partitions to be returned
        :type limit: int
        :param order_by: a list of tuples of field name and a boolean
            that determines if that field should be sorted in descending
            order
        :type order_by: list of (str, bool) tuples
        :param filters: dict of field name and filter value combinations
        """
        limit_clause = "LIMIT {}".format(limit) if limit else ""
        order_by_clause = ""
        if order_by:
            l = []
            for field, desc in order_by:
                # parenthesize the conditional so the field name is kept even
                # when the sort is ascending
                l.append(field + (" DESC" if desc else ""))
            order_by_clause = "ORDER BY " + ", ".join(l)

        where_clause = ""
        if filters:
            l = []
            for field, value in filters.items():
                l.append(f"{field} = '{value}'")
            where_clause = "WHERE " + " AND ".join(l)

        presto_version = database.get_extra().get("version")

        # Partition select syntax changed in v0.199, so check here.
        # Default to the new syntax if version is unset.
        partition_select_clause = (
            f'SELECT * FROM "{table_name}$partitions"'
            if not presto_version
            or StrictVersion(presto_version) >= StrictVersion("0.199")
            else f"SHOW PARTITIONS FROM {table_name}"
        )

        sql = textwrap.dedent(
            f"""\
            {partition_select_clause}
            {where_clause}
            {order_by_clause}
            {limit_clause}
            """
        )
        return sql

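    # Editorial note (example added for clarity; not in the original source):
    # _partition_query("logs", database, limit=1, order_by=[("ds", True)])
    # renders roughly as
    #   SELECT * FROM "logs$partitions"
    #   ORDER BY ds DESC
    #   LIMIT 1
    # (assuming a Presto version >= 0.199 and no filters; empty clauses are
    # left as blank lines).
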
    @classmethod
    def where_latest_partition(  # pylint: disable=too-many-arguments
        cls,
        table_name: str,
        schema: Optional[str],
        database: "Database",
        query: Select,
        columns: Optional[List] = None,
    ) -> Optional[Select]:
        try:
            col_names, values = cls.latest_partition(
                table_name, schema, database, show_first=True
            )
        except Exception:  # pylint: disable=broad-except
            # table is not partitioned
            return None

        if values is None:
            return None

        column_names = {column.get("name") for column in columns or []}
        for col_name, value in zip(col_names, values):
            if col_name in column_names:
                query = query.where(Column(col_name) == value)
        return query

    @classmethod
    def _latest_partition_from_df(  # pylint: disable=invalid-name
        cls, df: pd.DataFrame
    ) -> Optional[List[str]]:
        if not df.empty:
            return df.to_records(index=False)[0].item()
        return None

    @classmethod
    def latest_partition(
        cls,
        table_name: str,
        schema: Optional[str],
        database: "Database",
        show_first: bool = False,
    ) -> Tuple[List[str], Optional[List[str]]]:
        """Returns col name and the latest (max) partition value for a table

        :param table_name: the name of the table
        :param schema: schema / database / namespace
        :param database: database query will be run against
        :type database: models.Database
        :param show_first: displays the value for the first partitioning key
            if there are many partitioning keys
        :type show_first: bool

        >>> latest_partition('foo_table')
        (['ds'], ('2018-01-01',))
        """
        indexes = database.get_indexes(table_name, schema)
        if not indexes:
            raise SupersetTemplateException(
                f"Error getting partition for {schema}.{table_name}. "
                "Verify that this table has a partition."
            )

        if len(indexes[0]["column_names"]) < 1:
            raise SupersetTemplateException(
                "The table should have one partitioned field"
            )
        elif not show_first and len(indexes[0]["column_names"]) > 1:
            raise SupersetTemplateException(
                "The table should have a single partitioned field "
                "to use this function. You may want to use "
                "`presto.latest_sub_partition`"
            )

        column_names = indexes[0]["column_names"]
        part_fields = [(column_name, True) for column_name in column_names]
        sql = cls._partition_query(table_name, database, 1, part_fields)
        df = database.get_df(sql, schema)
        return column_names, cls._latest_partition_from_df(df)

    @classmethod
    def latest_sub_partition(
        cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
    ) -> Any:
        """Returns the latest (max) partition value for a table

        A filtering criteria should be passed for all fields that are
        partitioned except for the field to be returned. For example,
        if a table is partitioned by (``ds``, ``event_type`` and
        ``event_category``) and you want the latest ``ds``, you'll want
        to provide a filter as keyword arguments for both
        ``event_type`` and ``event_category`` as in
        ``latest_sub_partition('my_table',
        event_category='page', event_type='click')``

        :param table_name: the name of the table, can be just the table
            name or a fully qualified table name as ``schema_name.table_name``
        :type table_name: str
        :param schema: schema / database / namespace
        :type schema: str
        :param database: database query will be run against
        :type database: models.Database
        :param kwargs: keyword arguments define the filtering criteria
            on the partition list. There can be many of these.
        :type kwargs: str

        >>> latest_sub_partition('sub_partition_table', event_type='click')
        '2018-01-01'
        """
        indexes = database.get_indexes(table_name, schema)
        part_fields = indexes[0]["column_names"]
        for k in kwargs.keys():  # pylint: disable=consider-iterating-dictionary
            if k not in part_fields:
                msg = f"Field [{k}] is not part of the partition key"
                raise SupersetTemplateException(msg)
        if len(kwargs.keys()) != len(part_fields) - 1:
            msg = "A filter needs to be specified for {} out of the {} fields.".format(
                len(part_fields) - 1, len(part_fields)
            )
            raise SupersetTemplateException(msg)

        for field in part_fields:
            if field not in kwargs.keys():
                field_to_return = field

        sql = cls._partition_query(
            table_name, database, 1, [(field_to_return, True)], kwargs
        )
        df = database.get_df(sql, schema)
        if df.empty:
            return ""
        return df.to_dict()[field_to_return][0]

    @classmethod
    @cache.memoize()
    def get_function_names(cls, database: "Database") -> List[str]:
        """
        Get a list of function names that are able to be called on the database.
        Used for SQL Lab autocomplete.

        :param database: The database to get functions for
        :return: A list of function names useable in the database
        """
        return database.get_df("SHOW FUNCTIONS")["Function"].tolist()