123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455 |
- # Licensed to the Apache Software Foundation (ASF) under one
- # or more contributor license agreements. See the NOTICE file
- # distributed with this work for additional information
- # regarding copyright ownership. The ASF licenses this file
- # to you under the Apache License, Version 2.0 (the
- # "License"); you may not use this file except in compliance
- # with the License. You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- # KIND, either express or implied. See the License for the
- # specific language governing permissions and limitations
- # under the License.
- import logging
- import os
- import re
- import time
- from datetime import datetime
- from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
- from urllib import parse
- import pandas as pd
- from sqlalchemy import Column
- from sqlalchemy.engine.base import Engine
- from sqlalchemy.engine.reflection import Inspector
- from sqlalchemy.engine.url import make_url, URL
- from sqlalchemy.orm import Session
- from sqlalchemy.sql.expression import ColumnClause, Select
- from wtforms.form import Form
- from superset import app, cache, conf
- from superset.db_engine_specs.base import BaseEngineSpec
- from superset.db_engine_specs.presto import PrestoEngineSpec
- from superset.models.sql_lab import Query
- from superset.utils import core as utils
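- # prevent circular imports -- NOTE(review): this import is meant to sit under an `if TYPE_CHECKING:` guard (TYPE_CHECKING is imported above but otherwise unused); the guard appears to be missing here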
- from superset.models.core import Database # pylint: disable=unused-import
# Module-wide handles shared by the engine spec below.
logger = logging.getLogger(__name__)
config = app.config
QueryStatus = utils.QueryStatus

# Hive-specific settings, read once from the Superset config at import time.
# Callable that rewrites tracking URLs found in the Hive query logs.
tracking_url_trans = conf.get("TRACKING_URL_TRANSFORMER")
# Delay passed to time.sleep() between cursor polls.
hive_poll_interval = conf.get("HIVE_POLL_INTERVAL")
class HiveEngineSpec(PrestoEngineSpec):
    """Engine spec for Apache Hive.

    Reuses PrestoEngineSpec functionality and overrides the parts where Hive
    differs: CSV upload, progress tracking from query logs, partition lookup
    via ``SHOW PARTITIONS`` and user impersonation.
    """

    engine = "hive"
    max_column_name_length = 767

    # Scoping regex at class level to avoid recompiling
    # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
    jobs_stats_r = re.compile(r".*INFO.*Total jobs = (?P<max_jobs>[0-9]+)")
    # 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5
    launching_job_r = re.compile(
        r".*INFO.*Launching Job (?P<job_number>[0-9]+) out of "
        r"(?P<max_jobs>[0-9]+)"
    )
    # 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18
    # map = 0%, reduce = 0%
    stage_progress_r = re.compile(
        r".*INFO.*Stage-(?P<stage_number>[0-9]+).*"
        r"map = (?P<map_progress>[0-9]+)%.*"
        r"reduce = (?P<reduce_progress>[0-9]+)%.*"
    )

    @classmethod
    def patch(cls) -> None:
        """Monkey-patch pyhive's Hive module.

        Replaces ``pyhive.hive``'s ``TCLIService``, ``constants`` and ``ttypes``
        with the ones from the ``TCLIService`` package, and installs
        ``superset.db_engines.hive.fetch_logs`` as ``hive.Cursor.fetch_logs``.
        """
        from pyhive import hive  # pylint: disable=no-name-in-module
        from superset.db_engines import hive as patched_hive
        from TCLIService import (
            constants as patched_constants,
            ttypes as patched_ttypes,
            TCLIService as patched_TCLIService,
        )

        hive.TCLIService = patched_TCLIService
        hive.constants = patched_constants
        hive.ttypes = patched_ttypes
        hive.Cursor.fetch_logs = patched_hive.fetch_logs

    @classmethod
    def get_all_datasource_names(
        cls, database: "Database", datasource_type: str
    ) -> List[utils.DatasourceName]:
        """List datasource names using BaseEngineSpec's implementation.

        Deliberately calls BaseEngineSpec directly rather than whatever is
        inherited through PrestoEngineSpec.
        """
        return BaseEngineSpec.get_all_datasource_names(database, datasource_type)

    @classmethod
    def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
        """Fetch up to ``limit`` rows from ``cursor``.

        :raises Exception: ``("Query error", <message>)`` when the polled Hive
            operation is in ERROR_STATE
        :return: the fetched rows, or ``[]`` when pyhive raises
            ``ProgrammingError`` while fetching (presumably statements that
            produce no result set -- confirm)
        """
        # Import the submodule explicitly; a bare ``import pyhive`` does not
        # guarantee that ``pyhive.exc`` is loaded.
        from pyhive.exc import ProgrammingError
        from TCLIService import ttypes

        state = cursor.poll()
        if state.operationState == ttypes.TOperationState.ERROR_STATE:
            raise Exception("Query error", state.errorMessage)
        try:
            return super().fetch_data(cursor, limit)
        except ProgrammingError:
            return []

    @classmethod
    def create_table_from_csv(  # pylint: disable=too-many-locals
        cls, form: Form, database: "Database"
    ) -> None:
        """Uploads a csv file and creates a superset datasource in Hive.

        Column types are inferred with ``tableschema``. The file is uploaded to
        the S3 bucket ``CSV_TO_HIVE_UPLOAD_S3_BUCKET`` under
        ``CSV_TO_HIVE_UPLOAD_DIRECTORY/<table_name>/``. A comma-delimited
        TEXTFILE table whose LOCATION is that S3 prefix is then created,
        skipping the header line.

        :param form: upload form; ``name``, ``schema`` and ``csv_file`` are read
        :param database: database whose engine executes the CREATE TABLE
        :raises Exception: if no upload bucket is configured, or if the
            table name / schema combination is not allowed
        """

        def convert_to_hive_type(col_type: str) -> str:
            """Map a tableschema type to a Hive type (default STRING)."""
            tableschema_to_hive_types = {
                "boolean": "BOOLEAN",
                "integer": "INT",
                "number": "DOUBLE",
                "string": "STRING",
            }
            return tableschema_to_hive_types.get(col_type, "STRING")

        def quote_identifier(name: str) -> str:
            """Backtick-quote a column name coming from the uploaded CSV.

            Embedded backticks are doubled so that a crafted header cannot
            break out of the quoted identifier.
            """
            return "`{}`".format(name.replace("`", "``"))

        bucket_path = config["CSV_TO_HIVE_UPLOAD_S3_BUCKET"]
        if not bucket_path:
            logger.info("No upload bucket specified")
            raise Exception(
                "No upload bucket specified. You can specify one in the config file."
            )

        table_name = form.name.data
        schema_name = form.schema.data

        upload_namespace = config.get("UPLOADED_CSV_HIVE_NAMESPACE")
        if upload_namespace:
            # A fixed namespace is configured: the user may not choose one,
            # neither via a dotted table name nor via the schema field.
            if "." in table_name or schema_name:
                raise Exception(
                    "You can't specify a namespace. "
                    "All tables will be uploaded to the `{}` namespace".format(
                        upload_namespace
                    )
                )
            full_table_name = "{}.{}".format(upload_namespace, table_name)
        else:
            if "." in table_name and schema_name:
                raise Exception(
                    "You can't specify a namespace both in the name of the table "
                    "and in the schema field. Please remove one"
                )
            full_table_name = (
                "{}.{}".format(schema_name, table_name) if schema_name else table_name
            )

        filename = form.csv_file.data.filename
        upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY"]

        # Optional dependency
        from tableschema import Table  # pylint: disable=import-error

        hive_table_schema = Table(filename).infer()
        schema_definition = ", ".join(
            "{} {}".format(
                quote_identifier(field["name"]), convert_to_hive_type(field["type"])
            )
            for field in hive_table_schema["fields"]
        )

        # Optional dependency
        import boto3  # pylint: disable=import-error

        s3 = boto3.client("s3")
        location = os.path.join("s3a://", bucket_path, upload_prefix, table_name)
        s3.upload_file(
            filename,
            bucket_path,
            os.path.join(upload_prefix, table_name, os.path.basename(filename)),
        )
        # Hive's default field delimiter is not ',', and LOCATION must follow a
        # complete ``STORED AS <format>`` clause, so both are spelled out here.
        sql = f"""CREATE TABLE {full_table_name} ( {schema_definition} )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
            TEXTFILE LOCATION '{location}'
            tblproperties ('skip.header.line.count'='1')"""
        engine = cls.get_engine(database)
        engine.execute(sql)

    @classmethod
    def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
        """Render ``dttm`` as a Hive CAST literal for DATE/TIMESTAMP, else None."""
        tt = target_type.upper()
        if tt == "DATE":
            return f"CAST('{dttm.date().isoformat()}' AS DATE)"
        if tt == "TIMESTAMP":
            return f"""CAST('{dttm.isoformat(sep=" ", timespec="microseconds")}' AS TIMESTAMP)"""  # pylint: disable=line-too-long
        return None

    @classmethod
    def adjust_database_uri(
        cls, uri: URL, selected_schema: Optional[str] = None
    ) -> None:
        """Point ``uri`` at ``selected_schema`` (URL-quoted), if one is given."""
        if selected_schema:
            uri.database = parse.quote(selected_schema, safe="")

    @classmethod
    def _extract_error_message(cls, e: Exception) -> str:
        """Return the ``errorMessage="..."`` payload of ``e``, else ``str(e)``."""
        msg = str(e)
        # Non-greedy up to the first double quote not preceded by a backslash.
        match = re.search(r'errorMessage="(.*?)(?<!\\)"', msg)
        if match:
            msg = match.group(1)
        return msg

    @classmethod
    def progress(cls, log_lines: List[str]) -> int:
        """Estimate overall query progress as an integer percentage.

        Each job counts for ``100 / total_jobs``. Within the current job, the
        progress is the mean over stages of ``(map% + reduce%) / 2``.

        :param log_lines: Hive query log lines, oldest first
        :return: estimated progress, truncated to int
        """
        total_jobs = 1  # assuming there's at least 1 job
        current_job = 1
        stages: Dict[int, float] = {}
        for line in log_lines:
            match = cls.jobs_stats_r.match(line)
            if match:
                total_jobs = int(match.group("max_jobs")) or 1
            match = cls.launching_job_r.match(line)
            if match:
                current_job = int(match.group("job_number"))
                total_jobs = int(match.group("max_jobs")) or 1
                # A new job started: earlier stage figures belong to the
                # previous job and must not be averaged in.
                stages = {}
            match = cls.stage_progress_r.match(line)
            if match:
                stage_number = int(match.group("stage_number"))
                map_progress = int(match.group("map_progress"))
                reduce_progress = int(match.group("reduce_progress"))
                stages[stage_number] = (map_progress + reduce_progress) / 2
        logger.info(
            "Progress detail: %s, current job %s, total jobs: %s",
            stages,
            current_job,
            total_jobs,
        )
        stage_progress = sum(stages.values()) / len(stages) if stages else 0
        progress = 100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
        return int(progress)

    @classmethod
    def get_tracking_url(cls, log_lines: List[str]) -> Optional[str]:
        """Return the text after ``"Tracking URL = "`` in the first matching line."""
        lkp = "Tracking URL = "
        for line in log_lines:
            if lkp in line:
                return line.split(lkp)[1]
        return None

    @classmethod
    def handle_cursor(  # pylint: disable=too-many-locals
        cls, cursor: Any, query: Query, session: Session
    ) -> None:
        """Updates progress information while a Hive query runs.

        Polls ``cursor`` every ``hive_poll_interval`` while the operation is in
        INITIALIZED or RUNNING state. On each round it:

        - re-reads ``query`` and cancels the cursor if its status is STOPPED;
        - recomputes progress from the fetched logs;
        - records the (transformed) tracking URL once it appears;
        - logs new log lines, prefixed with the job id;
        - commits ``session`` when progress or the tracking URL changed.
        """
        from pyhive import hive  # pylint: disable=no-name-in-module

        unfinished_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        polled = cursor.poll()
        last_log_line = 0
        tracking_url = None
        job_id = None
        query_id = query.id
        while polled.operationState in unfinished_states:
            # Re-fetch the row so that a stop requested elsewhere is observed.
            query = session.query(type(query)).filter_by(id=query_id).one()
            if query.status == QueryStatus.STOPPED:
                cursor.cancel()
                break

            log = cursor.fetch_logs() or ""
            if log:
                log_lines = log.splitlines()
                progress = cls.progress(log_lines)
                logger.info("Query %s: Progress total: %s", query_id, progress)
                needs_commit = False
                if progress > query.progress:
                    query.progress = progress
                    needs_commit = True
                if not tracking_url:
                    tracking_url = cls.get_tracking_url(log_lines)
                    if tracking_url:
                        # Job id is taken as the second-to-last path segment.
                        job_id = tracking_url.split("/")[-2]
                        logger.info(
                            "Query %s: Found the tracking url: %s",
                            query_id,
                            tracking_url,
                        )
                        tracking_url = tracking_url_trans(tracking_url)
                        logger.info(
                            "Query %s: Transformation applied: %s",
                            query_id,
                            tracking_url,
                        )
                        query.tracking_url = tracking_url
                        logger.info("Query %s: Job id: %s", query_id, job_id)
                        needs_commit = True
                if job_id and len(log_lines) > last_log_line:
                    # Wait for job id before logging things out
                    # this allows for prefixing all log lines and becoming
                    # searchable in something like Kibana
                    for log_line in log_lines[last_log_line:]:
                        logger.info("Query %s: [%s] %s", query_id, job_id, log_line)
                    last_log_line = len(log_lines)
                if needs_commit:
                    session.commit()
            time.sleep(hive_poll_interval)
            polled = cursor.poll()

    @classmethod
    def get_columns(
        cls, inspector: Inspector, table_name: str, schema: Optional[str]
    ) -> List[Dict[str, Any]]:
        """Return the inspector's column metadata for ``schema.table_name``."""
        return inspector.get_columns(table_name, schema)

    @classmethod
    def where_latest_partition(  # pylint: disable=too-many-arguments
        cls,
        table_name: str,
        schema: Optional[str],
        database: "Database",
        query: Select,
        columns: Optional[List] = None,
    ) -> Optional[Select]:
        """Restrict ``query`` to the latest partition of the table.

        :return: ``query`` with one ``col == value`` filter per partition column
            that also appears in ``columns``; None when the partition lookup
            fails (treated as "not partitioned") or when ``values``/``columns``
            are None
        """
        try:
            col_names, values = cls.latest_partition(
                table_name, schema, database, show_first=True
            )
        except Exception:  # pylint: disable=broad-except
            # table is not partitioned
            return None
        if values is not None and columns is not None:
            for col_name, value in zip(col_names, values):
                if any(clm.get("name") == col_name for clm in columns):
                    query = query.where(Column(col_name) == value)
            return query
        return None

    @classmethod
    def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
        """Use BaseEngineSpec's field rendering, not the Presto one."""
        return BaseEngineSpec._get_fields(cols)  # pylint: disable=protected-access

    @classmethod
    def latest_sub_partition(
        cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
    ) -> str:
        """Not implemented for Hive; currently returns None."""
        # TODO(bogdan): implement
        pass

    @classmethod
    def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
        """Hive partitions look like ds={partition name}

        :return: ``[value]`` for the max entry of the first column, taking the
            text after the first ``=``; None for an empty frame
        """
        if not df.empty:
            # ``DataFrame.ix`` was removed in pandas 1.0; ``iloc`` is the
            # positional equivalent for ``[:, 0]``.
            return [df.iloc[:, 0].max().split("=")[1]]
        return None

    @classmethod
    def _partition_query(  # pylint: disable=too-many-arguments
        cls,
        table_name: str,
        database: "Database",
        limit: int = 0,
        order_by: Optional[List[Tuple[str, bool]]] = None,
        filters: Optional[Dict[Any, Any]] = None,
    ) -> str:
        """Return ``SHOW PARTITIONS``; limit/order_by/filters are ignored."""
        return f"SHOW PARTITIONS {table_name}"

    @classmethod
    def select_star(  # pylint: disable=too-many-arguments
        cls,
        database: "Database",
        table_name: str,
        engine: Engine,
        schema: Optional[str] = None,
        limit: int = 100,
        show_cols: bool = False,
        indent: bool = True,
        latest_partition: bool = True,
        cols: Optional[List[Dict[str, Any]]] = None,
    ) -> str:
        """Build a SELECT * using the implementation above PrestoEngineSpec."""
        # super(PrestoEngineSpec, ...) intentionally skips Presto's override.
        return super(  # pylint: disable=bad-super-call
            PrestoEngineSpec, cls
        ).select_star(
            database,
            table_name,
            engine,
            schema,
            limit,
            show_cols,
            indent,
            latest_partition,
            cols,
        )

    @classmethod
    def modify_url_for_impersonation(
        cls, url: URL, impersonate_user: bool, username: Optional[str]
    ) -> None:
        """
        Modify the SQL Alchemy URL object with the user to impersonate if applicable.
        :param url: SQLAlchemy URL object
        :param impersonate_user: Flag indicating if impersonation is enabled
        :param username: Effective username
        """
        # Do nothing in the URL object since instead this should modify
        # the configuration dictionary. See get_configuration_for_impersonation
        pass

    @classmethod
    def get_configuration_for_impersonation(
        cls, uri: str, impersonate_user: bool, username: Optional[str]
    ) -> Dict[str, str]:
        """
        Return a configuration dictionary that can be merged with other configs
        that can set the correct properties for impersonating users
        :param uri: URI string
        :param impersonate_user: Flag indicating if impersonation is enabled
        :param username: Effective username
        :return: Configs required for impersonation; empty unless the backend
            is ``hive``, the URI has an ``auth`` query parameter, impersonation
            is enabled and a username is given
        """
        configuration: Dict[str, str] = {}
        url = make_url(uri)
        backend_name = url.get_backend_name()

        # Must be Hive connection, enable impersonation, and set param
        if (
            backend_name == "hive"
            and "auth" in url.query
            and impersonate_user is True
            and username is not None
        ):
            configuration["hive.server2.proxy.user"] = username
        return configuration

    @staticmethod
    def execute(  # type: ignore
        cursor, query: str, async_: bool = False
    ):  # pylint: disable=arguments-differ
        """Execute ``query`` on ``cursor``, forwarding ``async_`` as ``async``."""
        # ``async`` is a reserved word, so it can only be passed via **kwargs.
        kwargs = {"async": async_}
        cursor.execute(query, **kwargs)

    @classmethod
    @cache.memoize()
    def get_function_names(cls, database: "Database") -> List[str]:
        """
        Get a list of function names that are able to be called on the database.
        Used for SQL Lab autocomplete.
        :param database: The database to get functions for
        :return: A list of function names useable in the database
        """
        # NOTE(review): assumes SHOW FUNCTIONS yields a single "tab_name" column.
        return database.get_df("SHOW FUNCTIONS")["tab_name"].tolist()