db_engine_specs.py

# pylint: disable=C,R,W
"""Compatibility layer for different database engines

This module stores logic specific to different database engines. Things
like time-related functions that are similar but not identical, or
information about whether or not to expose certain features and how to
expose them. For instance, Hive/Presto support partitions and have a
specific API to list them. Other databases like Vertica also support
partitions but have a different API to get to them. Other databases don't
support partitions at all. The classes here use a common interface to
specify all of this.

The general idea is to use static classes and an inheritance scheme.
"""
from collections import defaultdict, namedtuple
import inspect
import logging
import os
import re
import textwrap
import time

import boto3
from flask import g
from flask_babel import lazy_gettext as _
import pandas
from past.builtins import basestring
import sqlalchemy as sqla
from sqlalchemy import Column, select
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import TextAsFrom
import sqlparse
from tableschema import Table
from werkzeug.utils import secure_filename

from superset import app, cache_util, conf, db, sql_parse, utils
from superset.exceptions import SupersetTemplateException
from superset.utils import QueryStatus

config = app.config

tracking_url_trans = conf.get('TRACKING_URL_TRANSFORMER')
hive_poll_interval = conf.get('HIVE_POLL_INTERVAL')

Grain = namedtuple('Grain', 'name label function duration')

builtin_time_grains = {
    None: 'Time Column',
    'PT1S': 'second',
    'PT1M': 'minute',
    'PT5M': '5 minute',
    'PT10M': '10 minute',
    'PT15M': '15 minute',
    'PT0.5H': 'half hour',
    'PT1H': 'hour',
    'P1D': 'day',
    'P1W': 'week',
    'P1M': 'month',
    'P0.25Y': 'quarter',
    'P1Y': 'year',
    '1969-12-28T00:00:00Z/P1W': 'week_start_sunday',
    '1969-12-29T00:00:00Z/P1W': 'week_start_monday',
    'P1W/1970-01-03T00:00:00Z': 'week_ending_saturday',
    'P1W/1970-01-04T00:00:00Z': 'week_ending_sunday',
}


def _create_time_grains_tuple(time_grains, time_grain_functions, blacklist):
    ret_list = []
    blacklist = blacklist if blacklist else []
    for duration, func in time_grain_functions.items():
        if duration not in blacklist:
            name = time_grains.get(duration)
            ret_list.append(Grain(name, _(name), func, duration))
    return tuple(ret_list)
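
# Illustrative note (not part of the original module): with the builtin grains
# above, an engine that maps 'P1D' to "DATE_TRUNC('day', {col})" would yield a
# tuple entry roughly like
#
#   Grain(name='day', label='day', function="DATE_TRUNC('day', {col})",
#         duration='P1D')
#
# The {col} placeholder is substituted with the time column when the query is
# generated.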


class LimitMethod(object):
    """Enum the ways that limits can be applied"""
    FETCH_MANY = 'fetch_many'
    WRAP_SQL = 'wrap_sql'
    FORCE_LIMIT = 'force_limit'


class BaseEngineSpec(object):
    """Abstract class for database engine specific configurations"""

    engine = 'base'  # str as defined in sqlalchemy.engine.engine
    time_grain_functions = {}
    time_groupby_inline = False
    limit_method = LimitMethod.FORCE_LIMIT
    time_secondary_columns = False
    inner_joins = True
    allows_subquery = True
    force_column_alias_quotes = False
    arraysize = None

    @classmethod
    def get_time_grains(cls):
        blacklist = config.get('TIME_GRAIN_BLACKLIST', [])
        grains = builtin_time_grains.copy()
        grains.update(config.get('TIME_GRAIN_ADDONS', {}))
        grain_functions = cls.time_grain_functions.copy()
        grain_addon_functions = config.get('TIME_GRAIN_ADDON_FUNCTIONS', {})
        grain_functions.update(grain_addon_functions.get(cls.engine, {}))
        return _create_time_grains_tuple(grains, grain_functions, blacklist)

    @classmethod
    def fetch_data(cls, cursor, limit):
        if cls.arraysize:
            cursor.arraysize = cls.arraysize
        if cls.limit_method == LimitMethod.FETCH_MANY:
            return cursor.fetchmany(limit)
        return cursor.fetchall()

    @classmethod
    def epoch_to_dttm(cls):
        raise NotImplementedError()

    @classmethod
    def epoch_ms_to_dttm(cls):
        return cls.epoch_to_dttm().replace('{col}', '({col}/1000.000)')
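
    # Illustrative example (not in the original source): epoch_to_dttm() and
    # epoch_ms_to_dttm() return SQL templates with a {col} placeholder, e.g.
    # for MySQLEngineSpec defined further below:
    #
    #   >>> MySQLEngineSpec.epoch_to_dttm().format(col='ts')
    #   'from_unixtime(ts)'
    #   >>> MySQLEngineSpec.epoch_ms_to_dttm().format(col='ts')
    #   'from_unixtime((ts/1000.000))'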

    @classmethod
    def get_datatype(cls, type_code):
        if isinstance(type_code, basestring) and len(type_code):
            return type_code.upper()

    @classmethod
    def extra_table_metadata(cls, database, table_name, schema_name):
        """Returns engine-specific table metadata"""
        return {}

    @classmethod
    def apply_limit_to_sql(cls, sql, limit, database):
        """Alters the SQL statement to apply a LIMIT clause"""
        if cls.limit_method == LimitMethod.WRAP_SQL:
            sql = sql.strip('\t\n ;')
            qry = (
                select('*')
                .select_from(
                    TextAsFrom(text(sql), ['*']).alias('inner_qry'),
                )
                .limit(limit)
            )
            return database.compile_sqla_query(qry)
        elif cls.limit_method == LimitMethod.FORCE_LIMIT:
            parsed_query = sql_parse.SupersetQuery(sql)
            sql = parsed_query.get_query_with_new_limit(limit)
        return sql
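
    # Illustrative example (not in the original source): under FORCE_LIMIT the
    # trailing LIMIT of the statement is rewritten through SupersetQuery, so a
    # call roughly like
    #
    #   >>> BaseEngineSpec.apply_limit_to_sql(
    #   ...     'SELECT * FROM some_table LIMIT 1000', 100, database)
    #
    # is expected to return 'SELECT * FROM some_table LIMIT 100', where
    # ``database`` is a models.Database instance.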

    @classmethod
    def get_limit_from_sql(cls, sql):
        parsed_query = sql_parse.SupersetQuery(sql)
        return parsed_query.limit

    @classmethod
    def get_query_with_new_limit(cls, sql, limit):
        parsed_query = sql_parse.SupersetQuery(sql)
        return parsed_query.get_query_with_new_limit(limit)

    @staticmethod
    def csv_to_df(**kwargs):
        kwargs['filepath_or_buffer'] = \
            config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer']
        kwargs['encoding'] = 'utf-8'
        kwargs['iterator'] = True
        chunks = pandas.read_csv(**kwargs)
        df = pandas.DataFrame()
        df = pandas.concat(chunk for chunk in chunks)
        return df

    @staticmethod
    def df_to_db(df, table, **kwargs):
        df.to_sql(**kwargs)
        table.user_id = g.user.id
        table.schema = kwargs['schema']
        table.fetch_metadata()
        db.session.add(table)
        db.session.commit()

    @staticmethod
    def create_table_from_csv(form, table):
        def _allowed_file(filename):
            # Only allow specific file extensions as specified in the config
            extension = os.path.splitext(filename)[1]
            return extension and extension[1:] in config['ALLOWED_EXTENSIONS']

        filename = secure_filename(form.csv_file.data.filename)
        if not _allowed_file(filename):
            raise Exception('Invalid file type selected')
        kwargs = {
            'filepath_or_buffer': filename,
            'sep': form.sep.data,
            'header': form.header.data if form.header.data else 0,
            'index_col': form.index_col.data,
            'mangle_dupe_cols': form.mangle_dupe_cols.data,
            'skipinitialspace': form.skipinitialspace.data,
            'skiprows': form.skiprows.data,
            'nrows': form.nrows.data,
            'skip_blank_lines': form.skip_blank_lines.data,
            'parse_dates': form.parse_dates.data,
            'infer_datetime_format': form.infer_datetime_format.data,
            'chunksize': 10000,
        }
        df = BaseEngineSpec.csv_to_df(**kwargs)

        df_to_db_kwargs = {
            'table': table,
            'df': df,
            'name': form.name.data,
            'con': create_engine(form.con.data.sqlalchemy_uri_decrypted, echo=False),
            'schema': form.schema.data,
            'if_exists': form.if_exists.data,
            'index': form.index.data,
            'index_label': form.index_label.data,
            'chunksize': 10000,
        }

        BaseEngineSpec.df_to_db(**df_to_db_kwargs)

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    @cache_util.memoized_func(
        timeout=600,
        key=lambda *args, **kwargs: 'db:{}:{}'.format(args[0].id, args[1]),
        use_tables_cache=True)
    def fetch_result_sets(cls, db, datasource_type, force=False):
        """Returns a dictionary of {schema: [result_set_name]}.

        Datasource_type can be 'table' or 'view'.
        The empty schema key corresponds to the list of full names of all
        tables or views: <schema>.<result_set_name>.
        """
        schemas = db.inspector.get_schema_names()
        result_sets = {}
        all_result_sets = []
        for schema in schemas:
            if datasource_type == 'table':
                result_sets[schema] = sorted(
                    db.inspector.get_table_names(schema))
            elif datasource_type == 'view':
                result_sets[schema] = sorted(
                    db.inspector.get_view_names(schema))
            all_result_sets += [
                '{}.{}'.format(schema, t) for t in result_sets[schema]]
        if all_result_sets:
            result_sets[''] = all_result_sets
        return result_sets
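
    # Illustrative example (not in the original source): for a database with a
    # single schema ``main`` containing tables ``a`` and ``b``, the returned
    # mapping looks like
    #
    #   {'main': ['a', 'b'], '': ['main.a', 'main.b']}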

    @classmethod
    def handle_cursor(cls, cursor, query, session):
        """Handle a live cursor between the execute and fetchall calls

        The flow works without this method doing anything, but it allows
        for handling the cursor and updating progress information in the
        query object"""
        pass

    @classmethod
    def extract_error_message(cls, e):
        """Extract error message for queries"""
        return utils.error_msg_from_exception(e)

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema):
        """Based on a URI and selected schema, return a new URI

        The URI here represents the URI as entered when saving the database,
        ``selected_schema`` is the schema currently active, presumably in
        the SQL Lab dropdown. Based on that, for some database engines,
        we can return a new altered URI that connects straight to the
        active schema, meaning the users won't have to prefix the object
        names by the schema name.

        Some database engines have two levels of namespacing: database and
        schema (postgres, oracle, mssql, ...). For those it's probably
        better not to alter the database component of the URI with the
        schema name, as it won't work.

        Some database drivers like presto accept '{catalog}/{schema}' in
        the database component of the URL, which can be handled here.
        """
        return uri
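
    # Illustrative example (not in the original source): PrestoEngineSpec
    # below overrides this so that the selected schema is appended to the
    # catalog, while the base implementation returns the URI unchanged:
    #
    #   >>> from sqlalchemy.engine.url import make_url
    #   >>> url = make_url('presto://host:8080/hive/default')
    #   >>> PrestoEngineSpec.adjust_database_uri(url, 'other_schema').database
    #   'hive/other_schema'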

    @classmethod
    def patch(cls):
        pass

    @classmethod
    @cache_util.memoized_func(
        enable_cache=lambda *args, **kwargs: kwargs.get('enable_cache', False),
        timeout=lambda *args, **kwargs: kwargs.get('cache_timeout'),
        key=lambda *args, **kwargs: 'db:{}:schema_list'.format(kwargs.get('db_id')))
    def get_schema_names(cls, inspector, db_id,
                         enable_cache, cache_timeout, force=False):
        """A function to get all schema names in this db.

        :param inspector: SQLAlchemy inspector for the database
        :param db_id: database id
        :param enable_cache: whether to enable cache for the function
        :param cache_timeout: timeout settings for cache in seconds.
        :param force: force to refresh
        :return: a list of schema names
        """
        return inspector.get_schema_names()

    @classmethod
    def get_table_names(cls, schema, inspector):
        return sorted(inspector.get_table_names(schema))

    @classmethod
    def where_latest_partition(
            cls, table_name, schema, database, qry, columns=None):
        return False

    @classmethod
    def _get_fields(cls, cols):
        return [sqla.column(c.get('name')) for c in cols]

    @classmethod
    def select_star(cls, my_db, table_name, engine, schema=None, limit=100,
                    show_cols=False, indent=True, latest_partition=True,
                    cols=None):
        fields = '*'
        cols = cols or []
        if (show_cols or latest_partition) and not cols:
            cols = my_db.get_columns(table_name, schema)

        if show_cols:
            fields = cls._get_fields(cols)
        quote = engine.dialect.identifier_preparer.quote
        if schema:
            full_table_name = quote(schema) + '.' + quote(table_name)
        else:
            full_table_name = quote(table_name)

        qry = select(fields).select_from(text(full_table_name))

        if limit:
            qry = qry.limit(limit)
        if latest_partition:
            partition_query = cls.where_latest_partition(
                table_name, schema, my_db, qry, columns=cols)
            if partition_query != False:  # noqa
                qry = partition_query
        sql = my_db.compile_sqla_query(qry)
        if indent:
            sql = sqlparse.format(sql, reindent=True)
        return sql

    @classmethod
    def modify_url_for_impersonation(cls, url, impersonate_user, username):
        """
        Modify the SQL Alchemy URL object with the user to impersonate if applicable.
        :param url: SQLAlchemy URL object
        :param impersonate_user: Bool indicating if impersonation is enabled
        :param username: Effective username
        """
        if impersonate_user is not None and username is not None:
            url.username = username

    @classmethod
    def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
        """
        Return a configuration dictionary that can be merged with other configs
        to set the correct properties for impersonating users
        :param uri: URI string
        :param impersonate_user: Bool indicating if impersonation is enabled
        :param username: Effective username
        :return: Dictionary with configs required for impersonation
        """
        return {}

    @classmethod
    def execute(cls, cursor, query, **kwargs):
        if cls.arraysize:
            cursor.arraysize = cls.arraysize
        cursor.execute(query)

    @classmethod
    def make_label_compatible(cls, label):
        """
        Return a sqlalchemy.sql.elements.quoted_name if the engine requires
        quoting of aliases to ensure that the select query and query results
        have the same case.
        """
        if cls.force_column_alias_quotes is True:
            return quoted_name(label, True)
        return label
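
    # Illustrative example (not in the original source): engines that fold
    # unquoted identifiers (e.g. Snowflake, Oracle below) set
    # force_column_alias_quotes = True, so an alias such as 'sum__num' comes
    # back as a quoted_name and keeps its case in the result set:
    #
    #   >>> label = SnowflakeEngineSpec.make_label_compatible('sum__num')
    #   >>> isinstance(label, quoted_name)
    #   True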

    @staticmethod
    def mutate_expression_label(label):
        return label


class PostgresBaseEngineSpec(BaseEngineSpec):
    """ Abstract class for Postgres 'like' databases """
    engine = ''

    time_grain_functions = {
        None: '{col}',
        'PT1S': "DATE_TRUNC('second', {col}) AT TIME ZONE 'UTC'",
        'PT1M': "DATE_TRUNC('minute', {col}) AT TIME ZONE 'UTC'",
        'PT1H': "DATE_TRUNC('hour', {col}) AT TIME ZONE 'UTC'",
        'P1D': "DATE_TRUNC('day', {col}) AT TIME ZONE 'UTC'",
        'P1W': "DATE_TRUNC('week', {col}) AT TIME ZONE 'UTC'",
        'P1M': "DATE_TRUNC('month', {col}) AT TIME ZONE 'UTC'",
        'P0.25Y': "DATE_TRUNC('quarter', {col}) AT TIME ZONE 'UTC'",
        'P1Y': "DATE_TRUNC('year', {col}) AT TIME ZONE 'UTC'",
    }

    @classmethod
    def fetch_data(cls, cursor, limit):
        if not cursor.description:
            return []
        if cls.limit_method == LimitMethod.FETCH_MANY:
            return cursor.fetchmany(limit)
        return cursor.fetchall()

    @classmethod
    def epoch_to_dttm(cls):
        return "(timestamp 'epoch' + {col} * interval '1 second')"

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))


class PostgresEngineSpec(PostgresBaseEngineSpec):
    engine = 'postgresql'

    @classmethod
    def get_table_names(cls, schema, inspector):
        """Need to consider foreign tables for PostgreSQL"""
        tables = inspector.get_table_names(schema)
        tables.extend(inspector.get_foreign_table_names(schema))
        return sorted(tables)


class SnowflakeEngineSpec(PostgresBaseEngineSpec):
    engine = 'snowflake'
    force_column_alias_quotes = True

    time_grain_functions = {
        None: '{col}',
        'PT1S': "DATE_TRUNC('SECOND', {col})",
        'PT1M': "DATE_TRUNC('MINUTE', {col})",
        'PT5M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 5) * 5, \
                DATE_TRUNC('HOUR', {col}))",
        'PT10M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 10) * 10, \
                 DATE_TRUNC('HOUR', {col}))",
        'PT15M': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 15) * 15, \
                 DATE_TRUNC('HOUR', {col}))",
        'PT0.5H': "DATEADD(MINUTE, FLOOR(DATE_PART(MINUTE, {col}) / 30) * 30, \
                  DATE_TRUNC('HOUR', {col}))",
        'PT1H': "DATE_TRUNC('HOUR', {col})",
        'P1D': "DATE_TRUNC('DAY', {col})",
        'P1W': "DATE_TRUNC('WEEK', {col})",
        'P1M': "DATE_TRUNC('MONTH', {col})",
        'P0.25Y': "DATE_TRUNC('QUARTER', {col})",
        'P1Y': "DATE_TRUNC('YEAR', {col})",
    }

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema=None):
        database = uri.database
        if '/' in uri.database:
            database = uri.database.split('/')[0]
        if selected_schema:
            uri.database = database + '/' + selected_schema
        return uri


class VerticaEngineSpec(PostgresBaseEngineSpec):
    engine = 'vertica'


class RedshiftEngineSpec(PostgresBaseEngineSpec):
    engine = 'redshift'
    force_column_alias_quotes = True


class OracleEngineSpec(PostgresBaseEngineSpec):
    engine = 'oracle'
    limit_method = LimitMethod.WRAP_SQL
    force_column_alias_quotes = True

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'CAST({col} as DATE)',
        'PT1M': "TRUNC(TO_DATE({col}), 'MI')",
        'PT1H': "TRUNC(TO_DATE({col}), 'HH')",
        'P1D': "TRUNC(TO_DATE({col}), 'DDD')",
        'P1W': "TRUNC(TO_DATE({col}), 'WW')",
        'P1M': "TRUNC(TO_DATE({col}), 'MONTH')",
        'P0.25Y': "TRUNC(TO_DATE({col}), 'Q')",
        'P1Y': "TRUNC(TO_DATE({col}), 'YEAR')",
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        return (
            """TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
        ).format(dttm.isoformat())


class Db2EngineSpec(BaseEngineSpec):
    engine = 'ibm_db_sa'
    limit_method = LimitMethod.WRAP_SQL
    force_column_alias_quotes = True

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'CAST({col} as TIMESTAMP)'
                ' - MICROSECOND({col}) MICROSECONDS',
        'PT1M': 'CAST({col} as TIMESTAMP)'
                ' - SECOND({col}) SECONDS'
                ' - MICROSECOND({col}) MICROSECONDS',
        'PT1H': 'CAST({col} as TIMESTAMP)'
                ' - MINUTE({col}) MINUTES'
                ' - SECOND({col}) SECONDS'
                ' - MICROSECOND({col}) MICROSECONDS ',
        'P1D': 'CAST({col} as TIMESTAMP)'
               ' - HOUR({col}) HOURS'
               ' - MINUTE({col}) MINUTES'
               ' - SECOND({col}) SECONDS'
               ' - MICROSECOND({col}) MICROSECONDS',
        'P1W': '{col} - (DAYOFWEEK({col})) DAYS',
        'P1M': '{col} - (DAY({col})-1) DAYS',
        'P0.25Y': '{col} - (DAY({col})-1) DAYS'
                  ' - (MONTH({col})-1) MONTHS'
                  ' + ((QUARTER({col})-1) * 3) MONTHS',
        'P1Y': '{col} - (DAY({col})-1) DAYS'
               ' - (MONTH({col})-1) MONTHS',
    }

    @classmethod
    def epoch_to_dttm(cls):
        return "(TIMESTAMP('1970-01-01', '00:00:00') + {col} SECONDS)"

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        return "'{}'".format(dttm.strftime('%Y-%m-%d-%H.%M.%S'))


class SqliteEngineSpec(BaseEngineSpec):
    engine = 'sqlite'

    time_grain_functions = {
        None: '{col}',
        'PT1H': "DATETIME(STRFTIME('%Y-%m-%dT%H:00:00', {col}))",
        'P1D': 'DATE({col})',
        'P1W': "DATE({col}, -strftime('%W', {col}) || ' days')",
        'P1M': "DATE({col}, -strftime('%d', {col}) || ' days', '+1 day')",
        'P1Y': "DATETIME(STRFTIME('%Y-01-01T00:00:00', {col}))",
        'P1W/1970-01-03T00:00:00Z': "DATE({col}, 'weekday 6')",
        '1969-12-28T00:00:00Z/P1W': "DATE({col}, 'weekday 0', '-7 days')",
    }

    @classmethod
    def epoch_to_dttm(cls):
        return "datetime({col}, 'unixepoch')"

    @classmethod
    @cache_util.memoized_func(
        timeout=600,
        key=lambda *args, **kwargs: 'db:{}:{}'.format(args[0].id, args[1]),
        use_tables_cache=True)
    def fetch_result_sets(cls, db, datasource_type, force=False):
        schemas = db.inspector.get_schema_names()
        result_sets = {}
        all_result_sets = []
        schema = schemas[0]
        if datasource_type == 'table':
            result_sets[schema] = sorted(db.inspector.get_table_names())
        elif datasource_type == 'view':
            result_sets[schema] = sorted(db.inspector.get_view_names())
        all_result_sets += [
            '{}.{}'.format(schema, t) for t in result_sets[schema]]
        if all_result_sets:
            result_sets[''] = all_result_sets
        return result_sets

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        iso = dttm.isoformat().replace('T', ' ')
        if '.' not in iso:
            iso += '.000000'
        return "'{}'".format(iso)

    @classmethod
    def get_table_names(cls, schema, inspector):
        """Need to disregard the schema for Sqlite"""
        return sorted(inspector.get_table_names())


class MySQLEngineSpec(BaseEngineSpec):
    engine = 'mysql'

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'DATE_ADD(DATE({col}), '
                'INTERVAL (HOUR({col})*60*60 + MINUTE({col})*60'
                ' + SECOND({col})) SECOND)',
        'PT1M': 'DATE_ADD(DATE({col}), '
                'INTERVAL (HOUR({col})*60 + MINUTE({col})) MINUTE)',
        'PT1H': 'DATE_ADD(DATE({col}), '
                'INTERVAL HOUR({col}) HOUR)',
        'P1D': 'DATE({col})',
        'P1W': 'DATE(DATE_SUB({col}, '
               'INTERVAL DAYOFWEEK({col}) - 1 DAY))',
        'P1M': 'DATE(DATE_SUB({col}, '
               'INTERVAL DAYOFMONTH({col}) - 1 DAY))',
        'P0.25Y': 'MAKEDATE(YEAR({col}), 1) '
                  '+ INTERVAL QUARTER({col}) QUARTER - INTERVAL 1 QUARTER',
        'P1Y': 'DATE(DATE_SUB({col}, '
               'INTERVAL DAYOFYEAR({col}) - 1 DAY))',
        '1969-12-29T00:00:00Z/P1W': 'DATE(DATE_SUB({col}, '
            'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
    }

    type_code_map = {}  # loaded from get_datatype only if needed

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        if target_type.upper() in ('DATETIME', 'DATE'):
            return "STR_TO_DATE('{}', '%Y-%m-%d %H:%i:%s')".format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema=None):
        if selected_schema:
            uri.database = selected_schema
        return uri

    @classmethod
    def get_datatype(cls, type_code):
        if not cls.type_code_map:
            # only import and store if needed at least once
            import MySQLdb
            ft = MySQLdb.constants.FIELD_TYPE
            cls.type_code_map = {
                getattr(ft, k): k
                for k in dir(ft)
                if not k.startswith('_')
            }
        datatype = type_code
        if isinstance(type_code, int):
            datatype = cls.type_code_map.get(type_code)
        if datatype and isinstance(datatype, basestring) and len(datatype):
            return datatype

    @classmethod
    def epoch_to_dttm(cls):
        return 'from_unixtime({col})'

    @classmethod
    def extract_error_message(cls, e):
        """Extract error message for queries"""
        message = str(e)
        try:
            if isinstance(e.args, tuple) and len(e.args) > 1:
                message = e.args[1]
        except Exception:
            pass
        return message


class PrestoEngineSpec(BaseEngineSpec):
    engine = 'presto'

    time_grain_functions = {
        None: '{col}',
        'PT1S': "date_trunc('second', CAST({col} AS TIMESTAMP))",
        'PT1M': "date_trunc('minute', CAST({col} AS TIMESTAMP))",
        'PT1H': "date_trunc('hour', CAST({col} AS TIMESTAMP))",
        'P1D': "date_trunc('day', CAST({col} AS TIMESTAMP))",
        'P1W': "date_trunc('week', CAST({col} AS TIMESTAMP))",
        'P1M': "date_trunc('month', CAST({col} AS TIMESTAMP))",
        'P0.25Y': "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
        'P1Y': "date_trunc('year', CAST({col} AS TIMESTAMP))",
        'P1W/1970-01-03T00:00:00Z':
            "date_add('day', 5, date_trunc('week', date_add('day', 1, \
            CAST({col} AS TIMESTAMP))))",
        '1969-12-28T00:00:00Z/P1W':
            "date_add('day', -1, date_trunc('week', \
            date_add('day', 1, CAST({col} AS TIMESTAMP))))",
    }

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema=None):
        database = uri.database
        if selected_schema and database:
            if '/' in database:
                database = database.split('/')[0] + '/' + selected_schema
            else:
                database += '/' + selected_schema
            uri.database = database
        return uri

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "from_iso8601_date('{}')".format(dttm.isoformat()[:10])
        if tt == 'TIMESTAMP':
            return "from_iso8601_timestamp('{}')".format(dttm.isoformat())
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def epoch_to_dttm(cls):
        return 'from_unixtime({col})'

    @classmethod
    @cache_util.memoized_func(
        timeout=600,
        key=lambda *args, **kwargs: 'db:{}:{}'.format(args[0].id, args[1]),
        use_tables_cache=True)
    def fetch_result_sets(cls, db, datasource_type, force=False):
        """Returns a dictionary of {schema: [result_set_name]}.

        Datasource_type can be 'table' or 'view'.
        The empty schema key corresponds to the list of full names of all
        tables or views: <schema>.<result_set_name>.
        """
        result_set_df = db.get_df(
            """SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S
               ORDER BY concat(table_schema, '.', table_name)""".format(
                datasource_type.upper(),
            ),
            None)
        result_sets = defaultdict(list)
        for unused, row in result_set_df.iterrows():
            result_sets[row['table_schema']].append(row['table_name'])
            result_sets[''].append('{}.{}'.format(
                row['table_schema'], row['table_name']))
        return result_sets

    @classmethod
    def extra_table_metadata(cls, database, table_name, schema_name):
        indexes = database.get_indexes(table_name, schema_name)
        if not indexes:
            return {}
        cols = indexes[0].get('column_names', [])
        full_table_name = table_name
        if schema_name and '.' not in table_name:
            full_table_name = '{}.{}'.format(schema_name, table_name)
        pql = cls._partition_query(full_table_name)
        col_name, latest_part = cls.latest_partition(
            table_name, schema_name, database, show_first=True)
        return {
            'partitions': {
                'cols': cols,
                'latest': {col_name: latest_part},
                'partitionQuery': pql,
            },
        }

    @classmethod
    def handle_cursor(cls, cursor, query, session):
        """Updates progress information"""
        logging.info('Polling the cursor for progress')
        polled = cursor.poll()
        # poll returns dict -- JSON status information or ``None``
        # if the query is done
        # https://github.com/dropbox/PyHive/blob/
        # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
        while polled:
            # Update the object and wait for the kill signal.
            stats = polled.get('stats', {})
            query = session.query(type(query)).filter_by(id=query.id).one()
            if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]:
                cursor.cancel()
                break
            if stats:
                state = stats.get('state')
                # if already finished, then stop polling
                if state == 'FINISHED':
                    break
                completed_splits = float(stats.get('completedSplits'))
                total_splits = float(stats.get('totalSplits'))
                if total_splits and completed_splits:
                    progress = 100 * (completed_splits / total_splits)
                    logging.info(
                        'Query progress: {} / {} '
                        'splits'.format(completed_splits, total_splits))
                    if progress > query.progress:
                        query.progress = progress
                    session.commit()
            time.sleep(1)
            logging.info('Polling the cursor for progress')
            polled = cursor.poll()

    @classmethod
    def extract_error_message(cls, e):
        if (
                hasattr(e, 'orig') and
                type(e.orig).__name__ == 'DatabaseError' and
                isinstance(e.orig[0], dict)):
            error_dict = e.orig[0]
            return '{} at {}: {}'.format(
                error_dict.get('errorName'),
                error_dict.get('errorLocation'),
                error_dict.get('message'),
            )
        if (
                type(e).__name__ == 'DatabaseError' and
                hasattr(e, 'args') and
                len(e.args) > 0
        ):
            error_dict = e.args[0]
            return error_dict.get('message')
        return utils.error_msg_from_exception(e)

    @classmethod
    def _partition_query(
            cls, table_name, limit=0, order_by=None, filters=None):
        """Returns a partition query

        :param table_name: the name of the table to get partitions from
        :type table_name: str
        :param limit: the number of partitions to be returned
        :type limit: int
        :param order_by: a list of tuples of field name and a boolean
            that determines if that field should be sorted in descending
            order
        :type order_by: list of (str, bool) tuples
        :param filters: a list of filters to apply
        :type filters: dict of field name and filter value combinations
        """
        limit_clause = 'LIMIT {}'.format(limit) if limit else ''
        order_by_clause = ''
        if order_by:
            l = []  # noqa: E741
            for field, desc in order_by:
                l.append(field + ' DESC' if desc else '')
            order_by_clause = 'ORDER BY ' + ', '.join(l)
        where_clause = ''
        if filters:
            l = []  # noqa: E741
            for field, value in filters.items():
                l.append("{field} = '{value}'".format(**locals()))
            where_clause = 'WHERE ' + ' AND '.join(l)
        sql = textwrap.dedent("""\
            SHOW PARTITIONS FROM {table_name}
            {where_clause}
            {order_by_clause}
            {limit_clause}
        """).format(**locals())
        return sql
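
    # Illustrative example (not in the original source): with one partition
    # column, ordered and limited, the generated statement looks roughly like
    # the following (empty clauses leave blank lines):
    #
    #   >>> print(PrestoEngineSpec._partition_query(
    #   ...     'logs', limit=1, order_by=[('ds', True)]))
    #   SHOW PARTITIONS FROM logs
    #
    #   ORDER BY ds DESC
    #   LIMIT 1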

    @classmethod
    def where_latest_partition(
            cls, table_name, schema, database, qry, columns=None):
        try:
            col_name, value = cls.latest_partition(
                table_name, schema, database, show_first=True)
        except Exception:
            # table is not partitioned
            return False
        for c in columns:
            if c.get('name') == col_name:
                return qry.where(Column(col_name) == value)
        return False

    @classmethod
    def _latest_partition_from_df(cls, df):
        recs = df.to_records(index=False)
        if recs:
            return recs[0][0]

    @classmethod
    def latest_partition(cls, table_name, schema, database, show_first=False):
        """Returns col name and the latest (max) partition value for a table

        :param table_name: the name of the table
        :type table_name: str
        :param schema: schema / database / namespace
        :type schema: str
        :param database: database query will be run against
        :type database: models.Database
        :param show_first: displays the value for the first partitioning key
            if there are many partitioning keys
        :type show_first: bool

        >>> latest_partition('foo_table')
        '2018-01-01'
        """
        indexes = database.get_indexes(table_name, schema)
        if len(indexes[0]['column_names']) < 1:
            raise SupersetTemplateException(
                'The table should have one partitioned field')
        elif not show_first and len(indexes[0]['column_names']) > 1:
            raise SupersetTemplateException(
                'The table should have a single partitioned field '
                'to use this function. You may want to use '
                '`presto.latest_sub_partition`')
        part_field = indexes[0]['column_names'][0]
        sql = cls._partition_query(table_name, 1, [(part_field, True)])
        df = database.get_df(sql, schema)
        return part_field, cls._latest_partition_from_df(df)

    @classmethod
    def latest_sub_partition(cls, table_name, schema, database, **kwargs):
        """Returns the latest (max) partition value for a table

        A filtering criteria should be passed for all fields that are
        partitioned except for the field to be returned. For example,
        if a table is partitioned by (``ds``, ``event_type`` and
        ``event_category``) and you want the latest ``ds``, you'll want
        to provide a filter as keyword arguments for both
        ``event_type`` and ``event_category`` as in
        ``latest_sub_partition('my_table',
            event_category='page', event_type='click')``

        :param table_name: the name of the table, can be just the table
            name or a fully qualified table name as ``schema_name.table_name``
        :type table_name: str
        :param schema: schema / database / namespace
        :type schema: str
        :param database: database query will be run against
        :type database: models.Database
        :param kwargs: keyword arguments define the filtering criteria
            on the partition list. There can be many of these.
        :type kwargs: str

        >>> latest_sub_partition('sub_partition_table', event_type='click')
        '2018-01-01'
        """
        indexes = database.get_indexes(table_name, schema)
        part_fields = indexes[0]['column_names']
        for k in kwargs.keys():
            if k not in part_fields:
                msg = 'Field [{k}] is not part of the partitioning key'.format(k=k)
                raise SupersetTemplateException(msg)
        if len(kwargs.keys()) != len(part_fields) - 1:
            msg = (
                'A filter needs to be specified for {} out of the '
                '{} fields.'
            ).format(len(part_fields) - 1, len(part_fields))
            raise SupersetTemplateException(msg)

        for field in part_fields:
            if field not in kwargs.keys():
                field_to_return = field

        sql = cls._partition_query(
            table_name, 1, [(field_to_return, True)], kwargs)
        df = database.get_df(sql, schema)
        if df.empty:
            return ''
        return df.to_dict()[field_to_return][0]


class HiveEngineSpec(PrestoEngineSpec):
    """Reuses PrestoEngineSpec functionality."""

    engine = 'hive'

    # Scoping regex at class level to avoid recompiling
    # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
    jobs_stats_r = re.compile(
        r'.*INFO.*Total jobs = (?P<max_jobs>[0-9]+)')
    # 17/02/07 19:37:08 INFO ql.Driver: Launching Job 2 out of 5
    launching_job_r = re.compile(
        '.*INFO.*Launching Job (?P<job_number>[0-9]+) out of '
        '(?P<max_jobs>[0-9]+)')
    # 17/02/07 19:36:58 INFO exec.Task: 2017-02-07 19:36:58,152 Stage-18
    # map = 0%, reduce = 0%
    stage_progress_r = re.compile(
        r'.*INFO.*Stage-(?P<stage_number>[0-9]+).*'
        r'map = (?P<map_progress>[0-9]+)%.*'
        r'reduce = (?P<reduce_progress>[0-9]+)%.*')

    @classmethod
    def patch(cls):
        from pyhive import hive
        from superset.db_engines import hive as patched_hive
        from TCLIService import (
            constants as patched_constants,
            ttypes as patched_ttypes,
            TCLIService as patched_TCLIService)

        hive.TCLIService = patched_TCLIService
        hive.constants = patched_constants
        hive.ttypes = patched_ttypes
        hive.Cursor.fetch_logs = patched_hive.fetch_logs

    @classmethod
    @cache_util.memoized_func(
        timeout=600,
        key=lambda *args, **kwargs: 'db:{}:{}'.format(args[0].id, args[1]),
        use_tables_cache=True)
    def fetch_result_sets(cls, db, datasource_type, force=False):
        return BaseEngineSpec.fetch_result_sets(
            db, datasource_type, force=force)

    @classmethod
    def fetch_data(cls, cursor, limit):
        from TCLIService import ttypes
        state = cursor.poll()
        if state.operationState == ttypes.TOperationState.ERROR_STATE:
            raise Exception('Query error', state.errorMessage)
        return super(HiveEngineSpec, cls).fetch_data(cursor, limit)

    @staticmethod
    def create_table_from_csv(form, table):
        """Uploads a csv file and creates a superset datasource in Hive."""
        def convert_to_hive_type(col_type):
            """maps tableschema's types to hive types"""
            tableschema_to_hive_types = {
                'boolean': 'BOOLEAN',
                'integer': 'INT',
                'number': 'DOUBLE',
                'string': 'STRING',
            }
            return tableschema_to_hive_types.get(col_type, 'STRING')

        bucket_path = config['CSV_TO_HIVE_UPLOAD_S3_BUCKET']

        if not bucket_path:
            logging.info('No upload bucket specified')
            raise Exception(
                'No upload bucket specified. You can specify one in the config file.')

        table_name = form.name.data
        schema_name = form.schema.data

        if config.get('UPLOADED_CSV_HIVE_NAMESPACE'):
            if '.' in table_name or schema_name:
                raise Exception(
                    "You can't specify a namespace. "
                    'All tables will be uploaded to the `{}` namespace'.format(
                        config.get('UPLOADED_CSV_HIVE_NAMESPACE')))
            full_table_name = '{}.{}'.format(
                config.get('UPLOADED_CSV_HIVE_NAMESPACE'), table_name)
        else:
            if '.' in table_name and schema_name:
                raise Exception(
                    "You can't specify a namespace both in the name of the table "
                    'and in the schema field. Please remove one')
            full_table_name = '{}.{}'.format(
                schema_name, table_name) if schema_name else table_name

        filename = form.csv_file.data.filename
        upload_prefix = config['CSV_TO_HIVE_UPLOAD_DIRECTORY']
        upload_path = config['UPLOAD_FOLDER'] + \
            secure_filename(filename)

        hive_table_schema = Table(upload_path).infer()
        column_name_and_type = []
        for column_info in hive_table_schema['fields']:
            column_name_and_type.append(
                '`{}` {}'.format(
                    column_info['name'],
                    convert_to_hive_type(column_info['type'])))
        schema_definition = ', '.join(column_name_and_type)

        s3 = boto3.client('s3')
        location = os.path.join('s3a://', bucket_path, upload_prefix, table_name)
        s3.upload_file(
            upload_path, bucket_path,
            os.path.join(upload_prefix, table_name, filename))
        sql = """CREATE TABLE {full_table_name} ( {schema_definition} )
            ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
            TEXTFILE LOCATION '{location}'
            tblproperties ('skip.header.line.count'='1')""".format(**locals())
        logging.info(form.con.data)
        engine = create_engine(form.con.data.sqlalchemy_uri_decrypted)
        engine.execute(sql)

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
        elif tt == 'TIMESTAMP':
            return "CAST('{}' AS TIMESTAMP)".format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def adjust_database_uri(cls, uri, selected_schema=None):
        if selected_schema:
            uri.database = selected_schema
        return uri

    @classmethod
    def extract_error_message(cls, e):
        msg = str(e)
        match = re.search(r'errorMessage="(.*?)(?<!\\)"', msg)
        if match:
            msg = match.group(1)
        return msg

    @classmethod
    def progress(cls, log_lines):
        total_jobs = 1  # assuming there's at least 1 job
        current_job = 1
        stages = {}
        for line in log_lines:
            match = cls.jobs_stats_r.match(line)
            if match:
                total_jobs = int(match.groupdict()['max_jobs']) or 1
            match = cls.launching_job_r.match(line)
            if match:
                current_job = int(match.groupdict()['job_number'])
                total_jobs = int(match.groupdict()['max_jobs']) or 1
                stages = {}
            match = cls.stage_progress_r.match(line)
            if match:
                stage_number = int(match.groupdict()['stage_number'])
                map_progress = int(match.groupdict()['map_progress'])
                reduce_progress = int(match.groupdict()['reduce_progress'])
                stages[stage_number] = (map_progress + reduce_progress) / 2
        logging.info(
            'Progress detail: {}, '
            'current job {}, '
            'total jobs: {}'.format(stages, current_job, total_jobs))

        stage_progress = sum(
            stages.values()) / len(stages.values()) if stages else 0

        progress = (
            100 * (current_job - 1) / total_jobs + stage_progress / total_jobs
        )
        return int(progress)
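
    # Illustrative example (not in the original source): for hypothetical log
    # lines reporting job 2 of 4 with one stage at map = 100%, reduce = 40%,
    # the overall progress combines the job position with stage completion:
    #
    #   100 * (2 - 1) / 4 + ((100 + 40) / 2) / 4  ->  25 + 17.5  ->  int(42.5)
    #
    # i.e. progress() would return 42.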

    @classmethod
    def get_tracking_url(cls, log_lines):
        lkp = 'Tracking URL = '
        for line in log_lines:
            if lkp in line:
                return line.split(lkp)[1]

    @classmethod
    def handle_cursor(cls, cursor, query, session):
        """Updates progress information"""
        from pyhive import hive
        unfinished_states = (
            hive.ttypes.TOperationState.INITIALIZED_STATE,
            hive.ttypes.TOperationState.RUNNING_STATE,
        )
        polled = cursor.poll()
        last_log_line = 0
        tracking_url = None
        job_id = None
        while polled.operationState in unfinished_states:
            query = session.query(type(query)).filter_by(id=query.id).one()
            if query.status == QueryStatus.STOPPED:
                cursor.cancel()
                break

            log = cursor.fetch_logs() or ''
            if log:
                log_lines = log.splitlines()
                progress = cls.progress(log_lines)
                logging.info('Progress total: {}'.format(progress))
                needs_commit = False
                if progress > query.progress:
                    query.progress = progress
                    needs_commit = True
                if not tracking_url:
                    tracking_url = cls.get_tracking_url(log_lines)
                    if tracking_url:
                        job_id = tracking_url.split('/')[-2]
                        logging.info(
                            'Found the tracking url: {}'.format(tracking_url))
                        tracking_url = tracking_url_trans(tracking_url)
                        logging.info(
                            'Transformation applied: {}'.format(tracking_url))
                        query.tracking_url = tracking_url
                        logging.info('Job id: {}'.format(job_id))
                        needs_commit = True
                if job_id and len(log_lines) > last_log_line:
                    # Wait for job id before logging things out
                    # this allows for prefixing all log lines and becoming
                    # searchable in something like Kibana
                    for l in log_lines[last_log_line:]:
                        logging.info('[{}] {}'.format(job_id, l))
                    last_log_line = len(log_lines)
                if needs_commit:
                    session.commit()
            time.sleep(hive_poll_interval)
            polled = cursor.poll()

    @classmethod
    def where_latest_partition(
            cls, table_name, schema, database, qry, columns=None):
        try:
            col_name, value = cls.latest_partition(
                table_name, schema, database, show_first=True)
        except Exception:
            # table is not partitioned
            return False
        for c in columns:
            if str(c.name) == str(col_name):
                return qry.where(c == str(value))
        return False

    @classmethod
    def latest_sub_partition(cls, table_name, schema, database, **kwargs):
        # TODO(bogdan): implement
        pass

    @classmethod
    def _latest_partition_from_df(cls, df):
        """Hive partitions look like ds={partition name}"""
        return df.ix[:, 0].max().split('=')[1]

    @classmethod
    def _partition_query(
            cls, table_name, limit=0, order_by=None, filters=None):
        return 'SHOW PARTITIONS {table_name}'.format(**locals())

    @classmethod
    def modify_url_for_impersonation(cls, url, impersonate_user, username):
        """
        Modify the SQL Alchemy URL object with the user to impersonate if applicable.
        :param url: SQLAlchemy URL object
        :param impersonate_user: Bool indicating if impersonation is enabled
        :param username: Effective username
        """
        # Do nothing to the URL object since this should instead modify
        # the configuration dictionary. See get_configuration_for_impersonation
        pass

    @classmethod
    def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
        """
        Return a configuration dictionary that can be merged with other configs
        to set the correct properties for impersonating users
        :param uri: URI string
        :param impersonate_user: Bool indicating if impersonation is enabled
        :param username: Effective username
        :return: Dictionary with configs required for impersonation
        """
        configuration = {}
        url = make_url(uri)
        backend_name = url.get_backend_name()

        # Must be a Hive connection, with impersonation enabled and the
        # param auth=LDAP|KERBEROS set on the URL
        if (backend_name == 'hive' and 'auth' in url.query.keys() and
                impersonate_user is True and username is not None):
            configuration['hive.server2.proxy.user'] = username
        return configuration
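
    # Illustrative example (not in the original source): for a URI such as
    # hive://host:10000/default?auth=KERBEROS with impersonation enabled and
    # an effective user 'alice', the returned dict is roughly
    #
    #   {'hive.server2.proxy.user': 'alice'}
    #
    # and is merged into the rest of the connection configuration by the
    # caller.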

    @staticmethod
    def execute(cursor, query, async_=False):
        kwargs = {'async': async_}
        cursor.execute(query, **kwargs)


class MssqlEngineSpec(BaseEngineSpec):
    engine = 'mssql'
    epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')"
    limit_method = LimitMethod.WRAP_SQL

    time_grain_functions = {
        None: '{col}',
        'PT1S': "DATEADD(second, DATEDIFF(second, '2000-01-01', {col}), '2000-01-01')",
        'PT1M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}), 0)',
        'PT5M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 5 * 5, 0)',
        'PT10M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 10 * 10, 0)',
        'PT15M': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 15 * 15, 0)',
        'PT0.5H': 'DATEADD(minute, DATEDIFF(minute, 0, {col}) / 30 * 30, 0)',
        'PT1H': 'DATEADD(hour, DATEDIFF(hour, 0, {col}), 0)',
        'P1D': 'DATEADD(day, DATEDIFF(day, 0, {col}), 0)',
        'P1W': 'DATEADD(week, DATEDIFF(week, 0, {col}), 0)',
        'P1M': 'DATEADD(month, DATEDIFF(month, 0, {col}), 0)',
        'P0.25Y': 'DATEADD(quarter, DATEDIFF(quarter, 0, {col}), 0)',
        'P1Y': 'DATEADD(year, DATEDIFF(year, 0, {col}), 0)',
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        return "CONVERT(DATETIME, '{}', 126)".format(dttm.isoformat())


class AthenaEngineSpec(BaseEngineSpec):
    engine = 'awsathena'

    time_grain_functions = {
        None: '{col}',
        'PT1S': "date_trunc('second', CAST({col} AS TIMESTAMP))",
        'PT1M': "date_trunc('minute', CAST({col} AS TIMESTAMP))",
        'PT1H': "date_trunc('hour', CAST({col} AS TIMESTAMP))",
        'P1D': "date_trunc('day', CAST({col} AS TIMESTAMP))",
        'P1W': "date_trunc('week', CAST({col} AS TIMESTAMP))",
        'P1M': "date_trunc('month', CAST({col} AS TIMESTAMP))",
        'P0.25Y': "date_trunc('quarter', CAST({col} AS TIMESTAMP))",
        'P1Y': "date_trunc('year', CAST({col} AS TIMESTAMP))",
        'P1W/1970-01-03T00:00:00Z': "date_add('day', 5, date_trunc('week', \
            date_add('day', 1, CAST({col} AS TIMESTAMP))))",
        '1969-12-28T00:00:00Z/P1W': "date_add('day', -1, date_trunc('week', \
            date_add('day', 1, CAST({col} AS TIMESTAMP))))",
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "from_iso8601_date('{}')".format(dttm.isoformat()[:10])
        if tt == 'TIMESTAMP':
            return "from_iso8601_timestamp('{}')".format(dttm.isoformat())
        return ("CAST ('{}' AS TIMESTAMP)"
                .format(dttm.strftime('%Y-%m-%d %H:%M:%S')))

    @classmethod
    def epoch_to_dttm(cls):
        return 'from_unixtime({col})'


class ClickHouseEngineSpec(BaseEngineSpec):
    """Dialect for ClickHouse analytical DB."""

    engine = 'clickhouse'

    time_secondary_columns = True
    time_groupby_inline = True
    time_grain_functions = {
        None: '{col}',
        'PT1M': 'toStartOfMinute(toDateTime({col}))',
        'PT5M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 300)*300)',
        'PT10M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 600)*600)',
        'PT15M': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 900)*900)',
        'PT0.5H': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 1800)*1800)',
        'PT1H': 'toStartOfHour(toDateTime({col}))',
        'P1D': 'toStartOfDay(toDateTime({col}))',
        'P1W': 'toMonday(toDateTime({col}))',
        'P1M': 'toStartOfMonth(toDateTime({col}))',
        'P0.25Y': 'toStartOfQuarter(toDateTime({col}))',
        'P1Y': 'toStartOfYear(toDateTime({col}))',
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "toDate('{}')".format(dttm.strftime('%Y-%m-%d'))
        if tt == 'DATETIME':
            return "toDateTime('{}')".format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))


class BQEngineSpec(BaseEngineSpec):
    """Engine spec for Google's BigQuery

    As contributed by @mxmzdlv on issue #945"""
    engine = 'bigquery'

    """
    https://www.python.org/dev/peps/pep-0249/#arraysize
    raw_connections bypass the pybigquery query execution context and deal with
    raw dbapi connections directly.
    If this value is not set, the default value is set to 1, as described here,
    https://googlecloudplatform.github.io/google-cloud-python/latest/_modules/google/cloud/bigquery/dbapi/cursor.html#Cursor
    The default value of 5000 is derived from the pybigquery.
    https://github.com/mxmzdlv/pybigquery/blob/d214bb089ca0807ca9aaa6ce4d5a01172d40264e/pybigquery/sqlalchemy_bigquery.py#L102
    """
    arraysize = 5000

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'TIMESTAMP_TRUNC({col}, SECOND)',
        'PT1M': 'TIMESTAMP_TRUNC({col}, MINUTE)',
        'PT1H': 'TIMESTAMP_TRUNC({col}, HOUR)',
        'P1D': 'TIMESTAMP_TRUNC({col}, DAY)',
        'P1W': 'TIMESTAMP_TRUNC({col}, WEEK)',
        'P1M': 'TIMESTAMP_TRUNC({col}, MONTH)',
        'P0.25Y': 'TIMESTAMP_TRUNC({col}, QUARTER)',
        'P1Y': 'TIMESTAMP_TRUNC({col}, YEAR)',
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "'{}'".format(dttm.strftime('%Y-%m-%d'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    def fetch_data(cls, cursor, limit):
        data = super(BQEngineSpec, cls).fetch_data(cursor, limit)
        if len(data) != 0 and type(data[0]).__name__ == 'Row':
            data = [r.values() for r in data]
        return data

    @staticmethod
    def mutate_expression_label(label):
        mutated_label = re.sub(r'[^\w]+', '_', label)
        if not re.match('^[a-zA-Z_]+.*', mutated_label):
            raise SupersetTemplateException(
                'BigQuery field_name used is invalid {}, '
                'should start with a letter or '
                'underscore'.format(mutated_label))
        if len(mutated_label) > 128:
            raise SupersetTemplateException(
                'BigQuery field_name {}, should be at most '
                '128 characters'.format(mutated_label))
        return mutated_label
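
    # Illustrative example (not in the original source): BigQuery rejects
    # non-word characters in field names, so expression labels are rewritten:
    #
    #   >>> BQEngineSpec.mutate_expression_label('sum(metric) / count(*)')
    #   'sum_metric_count_'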

    @classmethod
    def _get_fields(cls, cols):
        """
        The BigQuery dialect requires us to not use backticks in nested
        field names. Using literal_column handles that issue.
        http://docs.sqlalchemy.org/en/latest/core/tutorial.html#using-more-specific-text-with-table-literal-column-and-column
        Also explicitly specifying column names so we don't encounter
        duplicate column names in the result.
        """
        return [sqla.literal_column(c.get('name')).label(c.get('name').replace('.', '__'))
                for c in cols]


class ImpalaEngineSpec(BaseEngineSpec):
    """Engine spec for Cloudera's Impala"""

    engine = 'impala'

    time_grain_functions = {
        None: '{col}',
        'PT1M': "TRUNC({col}, 'MI')",
        'PT1H': "TRUNC({col}, 'HH')",
        'P1D': "TRUNC({col}, 'DD')",
        'P1W': "TRUNC({col}, 'WW')",
        'P1M': "TRUNC({col}, 'MONTH')",
        'P0.25Y': "TRUNC({col}, 'Q')",
        'P1Y': "TRUNC({col}, 'YYYY')",
    }

    @classmethod
    def epoch_to_dttm(cls):
        return 'from_unixtime({col})'

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "'{}'".format(dttm.strftime('%Y-%m-%d'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

    @classmethod
    @cache_util.memoized_func(
        enable_cache=lambda *args, **kwargs: kwargs.get('enable_cache', False),
        timeout=lambda *args, **kwargs: kwargs.get('cache_timeout'),
        key=lambda *args, **kwargs: 'db:{}:schema_list'.format(kwargs.get('db_id')))
    def get_schema_names(cls, inspector, db_id,
                         enable_cache, cache_timeout, force=False):
        schemas = [row[0] for row in inspector.engine.execute('SHOW SCHEMAS')
                   if not row[0].startswith('_')]
        return schemas


class DruidEngineSpec(BaseEngineSpec):
    """Engine spec for Druid.io"""
    engine = 'druid'
    inner_joins = False
    allows_subquery = False

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'FLOOR({col} TO SECOND)',
        'PT1M': 'FLOOR({col} TO MINUTE)',
        'PT1H': 'FLOOR({col} TO HOUR)',
        'P1D': 'FLOOR({col} TO DAY)',
        'P1W': 'FLOOR({col} TO WEEK)',
        'P1M': 'FLOOR({col} TO MONTH)',
        'P0.25Y': 'FLOOR({col} TO QUARTER)',
        'P1Y': 'FLOOR({col} TO YEAR)',
    }


class KylinEngineSpec(BaseEngineSpec):
    """Dialect for Apache Kylin"""

    engine = 'kylin'

    time_grain_functions = {
        None: '{col}',
        'PT1S': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO SECOND) AS TIMESTAMP)',
        'PT1M': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MINUTE) AS TIMESTAMP)',
        'PT1H': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO HOUR) AS TIMESTAMP)',
        'P1D': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO DAY) AS DATE)',
        'P1W': 'CAST(TIMESTAMPADD(WEEK, WEEK(CAST({col} AS DATE)) - 1, \
               FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)',
        'P1M': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO MONTH) AS DATE)',
        'P0.25Y': 'CAST(TIMESTAMPADD(QUARTER, QUARTER(CAST({col} AS DATE)) - 1, \
                  FLOOR(CAST({col} AS TIMESTAMP) TO YEAR)) AS DATE)',
        'P1Y': 'CAST(FLOOR(CAST({col} AS TIMESTAMP) TO YEAR) AS DATE)',
    }

    @classmethod
    def convert_dttm(cls, target_type, dttm):
        tt = target_type.upper()
        if tt == 'DATE':
            return "CAST('{}' AS DATE)".format(dttm.isoformat()[:10])
        if tt == 'TIMESTAMP':
            return "CAST('{}' AS TIMESTAMP)".format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))


class TeradataEngineSpec(BaseEngineSpec):
    """Dialect for Teradata DB."""
    engine = 'teradata'
    limit_method = LimitMethod.WRAP_SQL

    time_grains = (
        Grain('Time Column', _('Time Column'), '{col}', None),
        Grain('minute', _('minute'), "TRUNC(CAST({col} as DATE), 'MI')", 'PT1M'),
        Grain('hour', _('hour'), "TRUNC(CAST({col} as DATE), 'HH')", 'PT1H'),
        Grain('day', _('day'), "TRUNC(CAST({col} as DATE), 'DDD')", 'P1D'),
        Grain('week', _('week'), "TRUNC(CAST({col} as DATE), 'WW')", 'P1W'),
        Grain('month', _('month'), "TRUNC(CAST({col} as DATE), 'MONTH')", 'P1M'),
        Grain('quarter', _('quarter'), "TRUNC(CAST({col} as DATE), 'Q')", 'P0.25Y'),
        Grain('year', _('year'), "TRUNC(CAST({col} as DATE), 'YEAR')", 'P1Y'),
    )


engines = {
    o.engine: o for o in globals().values()
    if inspect.isclass(o) and issubclass(o, BaseEngineSpec)}