hive.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from urllib import parse

import pandas as pd
from sqlalchemy import Column
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from wtforms.form import Form

from superset import app, cache, conf
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.sql_lab import Query
from superset.utils import core as utils

if TYPE_CHECKING:
    # prevent circular imports
    from superset.models.core import Database  # pylint: disable=unused-import

QueryStatus = utils.QueryStatus
config = app.config
logger = logging.getLogger(__name__)

tracking_url_trans = conf.get("TRACKING_URL_TRANSFORMER")
hive_poll_interval = conf.get("HIVE_POLL_INTERVAL")


class HiveEngineSpec(PrestoEngineSpec):
    """Reuses PrestoEngineSpec functionality."""

    engine = "hive"
    max_column_name_length = 767

    # Scoping regex at class level to avoid recompiling
    # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
    jobs_stats_r = re.compile(r".*INFO.*Total jobs = (?P<max_jobs>[0-9]+)")
    # 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5
    launching_job_r = re.compile(
        r".*INFO.*Launching Job (?P<job_number>[0-9]+) out of (?P<max_jobs>[0-9]+)"
    )
    # 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18
    # map = 0%, reduce = 0%
    stage_progress_r = re.compile(
        r".*INFO.*Stage-(?P<stage_number>[0-9]+).*"
        r"map = (?P<map_progress>[0-9]+)%.*"
        r"reduce = (?P<reduce_progress>[0-9]+)%.*"
    )

    @classmethod
    def patch(cls) -> None:
        from pyhive import hive  # pylint: disable=no-name-in-module
        from superset.db_engines import hive as patched_hive
        from TCLIService import (
            constants as patched_constants,
            ttypes as patched_ttypes,
            TCLIService as patched_TCLIService,
        )

        hive.TCLIService = patched_TCLIService
        hive.constants = patched_constants
        hive.ttypes = patched_ttypes
        hive.Cursor.fetch_logs = patched_hive.fetch_logs

    @classmethod
    def get_all_datasource_names(
        cls, database: "Database", datasource_type: str
    ) -> List[utils.DatasourceName]:
        return BaseEngineSpec.get_all_datasource_names(database, datasource_type)

    @classmethod
    def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
        import pyhive
        from TCLIService import ttypes

        state = cursor.poll()
        if state.operationState == ttypes.TOperationState.ERROR_STATE:
            raise Exception("Query error", state.errorMessage)
        try:
            return super(HiveEngineSpec, cls).fetch_data(cursor, limit)
        except pyhive.exc.ProgrammingError:
            return []

    @classmethod
    def create_table_from_csv(  # pylint: disable=too-many-locals
        cls, form: Form, database: "Database"
    ) -> None:
        """Uploads a csv file and creates a superset datasource in Hive."""

        def convert_to_hive_type(col_type: str) -> str:
            """maps tableschema's types to hive types"""
            tableschema_to_hive_types = {
                "boolean": "BOOLEAN",
                "integer": "INT",
                "number": "DOUBLE",
                "string": "STRING",
            }
            return tableschema_to_hive_types.get(col_type, "STRING")

        bucket_path = config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]

        if not bucket_path:
            logger.info("No upload bucket specified")
            raise Exception(
                "No upload bucket specified. You can specify one in the config file."
            )

        table_name = form.name.data
        schema_name = form.schema.data

        if config["UPLOADED_CSV_HIVE_NAMESPACE"]:
            if "." in table_name or schema_name:
                raise Exception(
                    "You can't specify a namespace. "
                    "All tables will be uploaded to the `{}` namespace".format(
                        config["UPLOADED_CSV_HIVE_NAMESPACE"]
                    )
                )
            full_table_name = "{}.{}".format(
                config["UPLOADED_CSV_HIVE_NAMESPACE"], table_name
            )
        else:
            if "." in table_name and schema_name:
                raise Exception(
                    "You can't specify a namespace both in the name of the table "
                    "and in the schema field. Please remove one"
                )

            full_table_name = (
                "{}.{}".format(schema_name, table_name) if schema_name else table_name
            )

        filename = form.csv_file.data.filename
        upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY"]

        # Optional dependency
        from tableschema import Table  # pylint: disable=import-error

        hive_table_schema = Table(filename).infer()
        column_name_and_type = []
        for column_info in hive_table_schema["fields"]:
            column_name_and_type.append(
                "`{}` {}".format(
                    column_info["name"], convert_to_hive_type(column_info["type"])
                )
            )
        schema_definition = ", ".join(column_name_and_type)

        # Optional dependency
        import boto3  # pylint: disable=import-error

        s3 = boto3.client("s3")
        location = os.path.join("s3a://", bucket_path, upload_prefix, table_name)
        s3.upload_file(
            filename,
            bucket_path,
            os.path.join(upload_prefix, table_name, os.path.basename(filename)),
        )
        sql = f"""CREATE TABLE {full_table_name} ( {schema_definition} )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
            TEXTFILE LOCATION '{location}'
            tblproperties ('skip.header.line.count'='1')"""
        engine = cls.get_engine(database)
        engine.execute(sql)
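
    # Illustrative sketch only (the table, columns, bucket and prefix below are
    # made up): for a CSV with columns "name" (string) and "age" (integer)
    # uploaded as table "people", the generated DDL would look roughly like
    #
    #   CREATE TABLE people ( `name` STRING, `age` INT )
    #       ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
    #       TEXTFILE LOCATION 's3a://<bucket>/<upload-prefix>/people'
    #       tblproperties ('skip.header.line.count'='1')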

    @classmethod
    def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
        tt = target_type.upper()
        if tt == "DATE":
            return f"CAST('{dttm.date().isoformat()}' AS DATE)"
        elif tt == "TIMESTAMP":
            return f"""CAST('{dttm.isoformat(sep=" ", timespec="microseconds")}' AS TIMESTAMP)"""  # pylint: disable=line-too-long
        return None
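
    # Illustrative usage (evaluated once the class is defined, not here):
    #
    #   HiveEngineSpec.convert_dttm("DATE", datetime(2021, 1, 1))
    #   -> "CAST('2021-01-01' AS DATE)"
    #   HiveEngineSpec.convert_dttm("TIMESTAMP", datetime(2021, 1, 1, 12, 30))
    #   -> "CAST('2021-01-01 12:30:00.000000' AS TIMESTAMP)"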

    @classmethod
    def adjust_database_uri(
        cls, uri: URL, selected_schema: Optional[str] = None
    ) -> None:
        if selected_schema:
            uri.database = parse.quote(selected_schema, safe="")

    @classmethod
    def _extract_error_message(cls, e: Exception) -> str:
        msg = str(e)
        match = re.search(r'errorMessage="(.*?)(?<!\\)"', msg)
        if match:
            msg = match.group(1)
        return msg
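
    # Illustrative example: for a Thrift-style error string containing
    #   errorMessage="Error while compiling statement: FAILED: ..."
    # the regex above returns just the quoted message, stopping at the first
    # unescaped double quote.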

    @classmethod
    def progress(cls, log_lines: List[str]) -> int:
        total_jobs = 1  # assuming there's at least 1 job
        current_job = 1
        stages: Dict[int, float] = {}
        for line in log_lines:
            match = cls.jobs_stats_r.match(line)
            if match:
                total_jobs = int(match.groupdict()["max_jobs"]) or 1
            match = cls.launching_job_r.match(line)
            if match:
                current_job = int(match.groupdict()["job_number"])
                total_jobs = int(match.groupdict()["max_jobs"]) or 1
                stages = {}
            match = cls.stage_progress_r.match(line)
            if match:
                stage_number = int(match.groupdict()["stage_number"])
                map_progress = int(match.groupdict()["map_progress"])
                reduce_progress = int(match.groupdict()["reduce_progress"])
                stages[stage_number] = (map_progress + reduce_progress) / 2
        logger.info(
            "Progress detail: {}, "  # pylint: disable=logging-format-interpolation
            "current job {}, "
            "total jobs: {}".format(stages, current_job, total_jobs)
        )
        stage_progress = sum(stages.values()) / len(stages.values()) if stages else 0
        progress = 100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
        return int(progress)
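
    # Worked example of the arithmetic above, using made-up log lines:
    #   "17/02/07 19:36:38 INFO ql.Driver: Total jobs = 2"
    #   "17/02/07 19:37:08 INFO ql.Driver: Launching Job 1 out of 2"
    #   "17/02/07 19:37:30 INFO exec.Task: ... Stage-1 map = 100%, reduce = 40%"
    # give stages == {1: 70.0}, current_job == 1, total_jobs == 2, so
    # progress == 100 * (1 - 1) / 2 + 70.0 / 2 == 35.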

    @classmethod
    def get_tracking_url(cls, log_lines: List[str]) -> Optional[str]:
        lkp = "Tracking URL = "
        for line in log_lines:
            if lkp in line:
                return line.split(lkp)[1]
        return None
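
    # Illustrative example: a log line such as
    #   "... INFO : Tracking URL = http://rm-host:8088/proxy/application_1_0001/"
    # returns "http://rm-host:8088/proxy/application_1_0001/"; handle_cursor
    # below derives the job id from the second-to-last path segment.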

    @classmethod
    def handle_cursor(  # pylint: disable=too-many-locals
        cls, cursor: Any, query: Query, session: Session
    ) -> None:
        """Updates progress information"""
        from pyhive import hive  # pylint: disable=no-name-in-module

        unfinished_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        polled = cursor.poll()
        last_log_line = 0
        tracking_url = None
        job_id = None
        query_id = query.id
        while polled.operationState in unfinished_states:
            query = session.query(type(query)).filter_by(id=query_id).one()
            if query.status == QueryStatus.STOPPED:
                cursor.cancel()
                break

            log = cursor.fetch_logs() or ""
            if log:
                log_lines = log.splitlines()
                progress = cls.progress(log_lines)
                logger.info(f"Query {query_id}: Progress total: {progress}")
                needs_commit = False
                if progress > query.progress:
                    query.progress = progress
                    needs_commit = True
                if not tracking_url:
                    tracking_url = cls.get_tracking_url(log_lines)
                    if tracking_url:
                        job_id = tracking_url.split("/")[-2]
                        logger.info(
                            f"Query {query_id}: Found the tracking url: {tracking_url}"
                        )
                        tracking_url = tracking_url_trans(tracking_url)
                        logger.info(
                            f"Query {query_id}: Transformation applied: {tracking_url}"
                        )
                        query.tracking_url = tracking_url
                        logger.info(f"Query {query_id}: Job id: {job_id}")
                        needs_commit = True
                if job_id and len(log_lines) > last_log_line:
                    # Wait for job id before logging things out
                    # this allows for prefixing all log lines and becoming
                    # searchable in something like Kibana
                    for l in log_lines[last_log_line:]:
                        logger.info(f"Query {query_id}: [{job_id}] {l}")
                    last_log_line = len(log_lines)
                if needs_commit:
                    session.commit()
            time.sleep(hive_poll_interval)
            polled = cursor.poll()

    @classmethod
    def get_columns(
        cls, inspector: Inspector, table_name: str, schema: Optional[str]
    ) -> List[Dict[str, Any]]:
        return inspector.get_columns(table_name, schema)

    @classmethod
    def where_latest_partition(  # pylint: disable=too-many-arguments
        cls,
        table_name: str,
        schema: Optional[str],
        database: "Database",
        query: Select,
        columns: Optional[List] = None,
    ) -> Optional[Select]:
        try:
            col_names, values = cls.latest_partition(
                table_name, schema, database, show_first=True
            )
        except Exception:  # pylint: disable=broad-except
            # table is not partitioned
            return None
        if values is not None and columns is not None:
            for col_name, value in zip(col_names, values):
                for clm in columns:
                    if clm.get("name") == col_name:
                        query = query.where(Column(col_name) == value)
            return query
        return None

    @classmethod
    def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
        return BaseEngineSpec._get_fields(cols)  # pylint: disable=protected-access

    @classmethod
    def latest_sub_partition(
        cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
    ) -> str:
        # TODO(bogdan): implement
        pass

    @classmethod
    def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
        """Hive partitions look like ds={partition name}"""
        if not df.empty:
            # the first (and only) column of the SHOW PARTITIONS result
            return [df.iloc[:, 0].max().split("=")[1]]
        return None
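
    # Illustrative example: if the first column of the SHOW PARTITIONS result
    # holds ["ds=2021-01-01", "ds=2021-01-02"], max() picks "ds=2021-01-02"
    # and the method returns ["2021-01-02"].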

    @classmethod
    def _partition_query(  # pylint: disable=too-many-arguments
        cls,
        table_name: str,
        database: "Database",
        limit: int = 0,
        order_by: Optional[List[Tuple[str, bool]]] = None,
        filters: Optional[Dict[Any, Any]] = None,
    ) -> str:
        return f"SHOW PARTITIONS {table_name}"

    @classmethod
    def select_star(  # pylint: disable=too-many-arguments
        cls,
        database: "Database",
        table_name: str,
        engine: Engine,
        schema: Optional[str] = None,
        limit: int = 100,
        show_cols: bool = False,
        indent: bool = True,
        latest_partition: bool = True,
        cols: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        return super(  # pylint: disable=bad-super-call
            PrestoEngineSpec, cls
        ).select_star(
            database,
            table_name,
            engine,
            schema,
            limit,
            show_cols,
            indent,
            latest_partition,
            cols,
        )

    @classmethod
    def modify_url_for_impersonation(
        cls, url: URL, impersonate_user: bool, username: Optional[str]
    ) -> None:
        """
        Modify the SQL Alchemy URL object with the user to impersonate if applicable.

        :param url: SQLAlchemy URL object
        :param impersonate_user: Flag indicating if impersonation is enabled
        :param username: Effective username
        """
        # Do nothing to the URL object; for Hive, impersonation is instead set
        # through the configuration dictionary. See get_configuration_for_impersonation
        pass

    @classmethod
    def get_configuration_for_impersonation(
        cls, uri: str, impersonate_user: bool, username: Optional[str]
    ) -> Dict[str, str]:
        """
        Return a configuration dictionary that can be merged with other configs
        that can set the correct properties for impersonating users

        :param uri: URI string
        :param impersonate_user: Flag indicating if impersonation is enabled
        :param username: Effective username
        :return: Configs required for impersonation
        """
        configuration = {}
        url = make_url(uri)
        backend_name = url.get_backend_name()

        # Must be Hive connection, enable impersonation, and set param
        # auth=LDAP|KERBEROS
        if (
            backend_name == "hive"
            and "auth" in url.query.keys()
            and impersonate_user is True
            and username is not None
        ):
            configuration["hive.server2.proxy.user"] = username
        return configuration
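
    # Illustrative usage (the connection URI and username are made up):
    #
    #   HiveEngineSpec.get_configuration_for_impersonation(
    #       "hive://hs2-host:10000/default?auth=KERBEROS",
    #       impersonate_user=True,
    #       username="alice",
    #   )
    #   -> {"hive.server2.proxy.user": "alice"}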

    @staticmethod
    def execute(  # type: ignore
        cursor, query: str, async_: bool = False
    ):  # pylint: disable=arguments-differ
        # `async` is a reserved word in Python 3.7+, so it cannot be written as
        # a literal keyword argument; pass it to the cursor via a kwargs dict.
        kwargs = {"async": async_}
        cursor.execute(query, **kwargs)

    @classmethod
    @cache.memoize()
    def get_function_names(cls, database: "Database") -> List[str]:
        """
        Get a list of function names that are able to be called on the database.
        Used for SQL Lab autocomplete.

        :param database: The database to get functions for
        :return: A list of function names useable in the database
        """
        return database.get_df("SHOW FUNCTIONS")["tab_name"].tolist()
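
# Illustrative usage at module level (the `my_hive_database` object is an
# assumption, not part of this file): given a Database model pointing at a
# Hive connection,
#
#   HiveEngineSpec.get_function_names(my_hive_database)
#
# runs "SHOW FUNCTIONS" and returns its single "tab_name" column as a plain
# list, e.g. ["abs", "acos", "add_months", ...], which SQL Lab uses for
# autocomplete.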