# models.py
# pylint: disable=C,R,W
from datetime import datetime
import logging

from flask import escape, Markup
from flask_appbuilder import Model
from flask_babel import lazy_gettext as _
import pandas as pd
import sqlalchemy as sa
from sqlalchemy import (
    and_, asc, Boolean, Column, DateTime, desc, ForeignKey, Integer, or_,
    select, String, Text,
)
from sqlalchemy.orm import backref, relationship
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, literal_column, table, text
from sqlalchemy.sql.expression import TextAsFrom
import sqlparse

from superset import app, db, import_util, security_manager, utils
from superset.connectors.base.models import (
    BaseColumn, BaseDatasource, BaseMetric,
)
from superset.jinja_context import get_template_processor
from superset.models.annotations import Annotation
from superset.models.core import Database
from superset.models.helpers import QueryResult
from superset.utils import DTTM_ALIAS, QueryStatus

config = app.config


class AnnotationDatasource(BaseDatasource):

    """Dummy object so we can query annotations using 'Viz' objects just like
    regular datasources.
    """

    cache_timeout = 0

    def query(self, query_obj):
        df = None
        error_message = None
        qry = db.session.query(Annotation)
        qry = qry.filter(Annotation.layer_id == query_obj['filter'][0]['val'])
        if query_obj['from_dttm']:
            qry = qry.filter(Annotation.start_dttm >= query_obj['from_dttm'])
        if query_obj['to_dttm']:
            qry = qry.filter(Annotation.end_dttm <= query_obj['to_dttm'])
        status = QueryStatus.SUCCESS
        try:
            df = pd.read_sql_query(qry.statement, db.engine)
        except Exception as e:
            status = QueryStatus.FAILED
            logging.exception(e)
            error_message = (
                utils.error_msg_from_exception(e))
        return QueryResult(
            status=status,
            df=df,
            duration=0,
            query='',
            error_message=error_message)

    def get_query_str(self, query_obj):
        raise NotImplementedError()

    def values_for_column(self, column_name, limit=10000):
        raise NotImplementedError()
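
# Illustrative only (not part of the original module): the minimal `query_obj`
# shape AnnotationDatasource.query() expects. The keys below mirror the
# lookups performed above; the values are made up.
#
#     query_obj = {
#         'filter': [{'col': 'layer_id', 'op': '==', 'val': 1}],
#         'from_dttm': datetime(2018, 1, 1),
#         'to_dttm': datetime(2018, 2, 1),
#     }
#     result = AnnotationDatasource().query(query_obj)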


class TableColumn(Model, BaseColumn):

    """ORM object for table columns, each table can have multiple columns"""

    __tablename__ = 'table_columns'
    __table_args__ = (UniqueConstraint('table_id', 'column_name'),)
    table_id = Column(Integer, ForeignKey('tables.id'))
    table = relationship(
        'SqlaTable',
        backref=backref('columns', cascade='all, delete-orphan'),
        foreign_keys=[table_id])
    is_dttm = Column(Boolean, default=False)
    expression = Column(Text, default='')
    python_date_format = Column(String(255))
    database_expression = Column(String(255))

    export_fields = (
        'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active',
        'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min',
        'filterable', 'expression', 'description', 'python_date_format',
        'database_expression',
    )

    update_from_object_fields = [
        s for s in export_fields if s not in ('table_id',)]
    export_parent = 'table'

    def get_sqla_col(self, label=None):
        db_engine_spec = self.table.database.db_engine_spec
        label = db_engine_spec.make_label_compatible(
            label if label else self.column_name)
        if not self.expression:
            col = column(self.column_name).label(label)
        else:
            col = literal_column(self.expression).label(label)
        return col

    @property
    def datasource(self):
        return self.table

    def get_time_filter(self, start_dttm, end_dttm):
        col = self.get_sqla_col(label='__time')
        l = []  # noqa: E741
        if start_dttm:
            l.append(col >= text(self.dttm_sql_literal(start_dttm)))
        if end_dttm:
            l.append(col <= text(self.dttm_sql_literal(end_dttm)))
        return and_(*l)

    def get_timestamp_expression(self, time_grain):
        """Getting the time component of the query"""
        pdf = self.python_date_format
        is_epoch = pdf in ('epoch_s', 'epoch_ms')
        if not self.expression and not time_grain and not is_epoch:
            return column(self.column_name, type_=DateTime).label(DTTM_ALIAS)
        expr = self.expression or self.column_name
        if is_epoch:
            # if epoch, translate to DATE using db specific conf
            db_spec = self.table.database.db_engine_spec
            if pdf == 'epoch_s':
                expr = db_spec.epoch_to_dttm().format(col=expr)
            elif pdf == 'epoch_ms':
                expr = db_spec.epoch_ms_to_dttm().format(col=expr)
        if time_grain:
            grain = self.table.database.grains_dict().get(time_grain)
            if grain:
                expr = grain.function.format(col=expr)
        return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)
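
    # Illustrative only: for an 'epoch_s' column named `ts` on Postgres, the
    # engine-spec template (an assumption about the configured database) turns
    # the expression into something like
    #     (timestamp 'epoch' + ts * interval '1 second')
    # and a daily time grain would then wrap that in DATE_TRUNC('day', ...),
    # labeled as DTTM_ALIAS.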

    @classmethod
    def import_obj(cls, i_column):
        def lookup_obj(lookup_column):
            return db.session.query(TableColumn).filter(
                TableColumn.table_id == lookup_column.table_id,
                TableColumn.column_name == lookup_column.column_name).first()
        return import_util.import_simple_obj(db.session, i_column, lookup_obj)

    def dttm_sql_literal(self, dttm):
        """Convert a datetime object to a SQL expression string

        If database_expression is empty, the internal dttm
        will be rendered as a string using the pattern the
        user entered (python_date_format).
        If database_expression is not empty, the internal dttm
        will be rendered through that SQL expression, for the
        database to convert.
        """
        tf = self.python_date_format
        if self.database_expression:
            return self.database_expression.format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        elif tf:
            if tf == 'epoch_s':
                return str((dttm - datetime(1970, 1, 1)).total_seconds())
            elif tf == 'epoch_ms':
                return str(
                    (dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0)
            return "'{}'".format(dttm.strftime(tf))
        else:
            s = self.table.database.db_engine_spec.convert_dttm(
                self.type or '', dttm)
            return s or "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S.%f'))
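
    # Illustrative examples of the three branches above (values are made up):
    #   database_expression "TO_DATE('{}', 'YYYY-MM-DD HH24:MI:SS')"
    #       -> "TO_DATE('2018-01-01 00:00:00', 'YYYY-MM-DD HH24:MI:SS')"
    #   python_date_format 'epoch_s' -> '1514764800.0'
    #   neither set (engine-spec default) -> "'2018-01-01 00:00:00.000000'"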

    def get_metrics(self):
        # TODO deprecate, this is not needed since MetricsControl
        metrics = []
        M = SqlMetric  # noqa
        quoted = self.column_name
        if self.sum:
            metrics.append(M(
                metric_name='sum__' + self.column_name,
                metric_type='sum',
                expression='SUM({})'.format(quoted),
            ))
        if self.avg:
            metrics.append(M(
                metric_name='avg__' + self.column_name,
                metric_type='avg',
                expression='AVG({})'.format(quoted),
            ))
        if self.max:
            metrics.append(M(
                metric_name='max__' + self.column_name,
                metric_type='max',
                expression='MAX({})'.format(quoted),
            ))
        if self.min:
            metrics.append(M(
                metric_name='min__' + self.column_name,
                metric_type='min',
                expression='MIN({})'.format(quoted),
            ))
        if self.count_distinct:
            metrics.append(M(
                metric_name='count_distinct__' + self.column_name,
                metric_type='count_distinct',
                expression='COUNT(DISTINCT {})'.format(quoted),
            ))
        return {m.metric_name: m for m in metrics}


class SqlMetric(Model, BaseMetric):

    """ORM object for metrics, each table can have multiple metrics"""

    __tablename__ = 'sql_metrics'
    __table_args__ = (UniqueConstraint('table_id', 'metric_name'),)
    table_id = Column(Integer, ForeignKey('tables.id'))
    table = relationship(
        'SqlaTable',
        backref=backref('metrics', cascade='all, delete-orphan'),
        foreign_keys=[table_id])
    expression = Column(Text)

    export_fields = (
        'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression',
        'description', 'is_restricted', 'd3format', 'warning_text')
    update_from_object_fields = [
        s for s in export_fields if s not in ('table_id',)]
    export_parent = 'table'

    def get_sqla_col(self, label=None):
        db_engine_spec = self.table.database.db_engine_spec
        label = db_engine_spec.make_label_compatible(
            label if label else self.metric_name)
        return literal_column(self.expression).label(label)

    @property
    def perm(self):
        return (
            '{parent_name}.[{obj.metric_name}](id:{obj.id})'
        ).format(obj=self,
                 parent_name=self.table.full_name) if self.table else None

    @classmethod
    def import_obj(cls, i_metric):
        def lookup_obj(lookup_metric):
            return db.session.query(SqlMetric).filter(
                SqlMetric.table_id == lookup_metric.table_id,
                SqlMetric.metric_name == lookup_metric.metric_name).first()
        return import_util.import_simple_obj(db.session, i_metric, lookup_obj)


class SqlaTable(Model, BaseDatasource):

    """An ORM object for SqlAlchemy table references"""

    type = 'table'
    query_language = 'sql'
    metric_class = SqlMetric
    column_class = TableColumn

    __tablename__ = 'tables'
    __table_args__ = (UniqueConstraint('database_id', 'table_name'),)

    table_name = Column(String(250))
    main_dttm_col = Column(String(250))
    database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
    fetch_values_predicate = Column(String(1000))
    user_id = Column(Integer, ForeignKey('ab_user.id'))
    owner = relationship(
        security_manager.user_model,
        backref='tables',
        foreign_keys=[user_id])
    database = relationship(
        'Database',
        backref=backref('tables', cascade='all, delete-orphan'),
        foreign_keys=[database_id])
    schema = Column(String(255))
    sql = Column(Text)
    is_sqllab_view = Column(Boolean, default=False)
    template_params = Column(Text)

    baselink = 'tablemodelview'

    export_fields = (
        'table_name', 'main_dttm_col', 'description', 'default_endpoint',
        'database_id', 'offset', 'cache_timeout', 'schema',
        'sql', 'params', 'template_params', 'filter_select_enabled')
    update_from_object_fields = [
        f for f in export_fields if f not in ('table_name', 'database_id')]
    export_parent = 'database'
    export_children = ['metrics', 'columns']

    sqla_aggregations = {
        'COUNT_DISTINCT': lambda column_name: sa.func.COUNT(sa.distinct(column_name)),
        'COUNT': sa.func.COUNT,
        'SUM': sa.func.SUM,
        'AVG': sa.func.AVG,
        'MIN': sa.func.MIN,
        'MAX': sa.func.MAX,
    }
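
    # Illustrative only: these aggregations are applied to sqlalchemy column
    # objects when building adhoc metrics, e.g.
    #     SqlaTable.sqla_aggregations['COUNT_DISTINCT'](column('user_id'))
    # compiles to COUNT(DISTINCT user_id).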

    def __repr__(self):
        return self.name

    @property
    def connection(self):
        return str(self.database)

    @property
    def description_markeddown(self):
        return utils.markdown(self.description)

    @property
    def datasource_name(self):
        return self.table_name

    @property
    def database_name(self):
        return self.database.name

    @property
    def link(self):
        name = escape(self.name)
        anchor = '<a target="_blank" href="{self.explore_url}">{name}</a>'
        return Markup(anchor.format(**locals()))

    @property
    def schema_perm(self):
        """Returns schema permission if present, database one otherwise."""
        return security_manager.get_schema_perm(self.database, self.schema)

    def get_perm(self):
        return (
            '[{obj.database}].[{obj.table_name}]'
            '(id:{obj.id})').format(obj=self)

    @property
    def name(self):
        if not self.schema:
            return self.table_name
        return '{}.{}'.format(self.schema, self.table_name)

    @property
    def full_name(self):
        return utils.get_datasource_full_name(
            self.database, self.table_name, schema=self.schema)

    @property
    def dttm_cols(self):
        l = [c.column_name for c in self.columns if c.is_dttm]  # noqa: E741
        if self.main_dttm_col and self.main_dttm_col not in l:
            l.append(self.main_dttm_col)
        return l

    @property
    def num_cols(self):
        return [c.column_name for c in self.columns if c.is_num]

    @property
    def any_dttm_col(self):
        cols = self.dttm_cols
        if cols:
            return cols[0]

    @property
    def html(self):
        t = ((c.column_name, c.type) for c in self.columns)
        df = pd.DataFrame(t)
        df.columns = ['field', 'type']
        return df.to_html(
            index=False,
            classes=(
                'dataframe table table-striped table-bordered '
                'table-condensed'))

    @property
    def sql_url(self):
        return self.database.sql_url + '?table_name=' + str(self.table_name)

    def external_metadata(self):
        cols = self.database.get_columns(self.table_name, schema=self.schema)
        for col in cols:
            col['type'] = '{}'.format(col['type'])
        return cols

    @property
    def time_column_grains(self):
        return {
            'time_columns': self.dttm_cols,
            'time_grains': [grain.name for grain in self.database.grains()],
        }

    @property
    def select_star(self):
        # show_cols and latest_partition set to false to avoid
        # the expensive cost of inspecting the DB
        return self.database.select_star(
            self.name, show_cols=False, latest_partition=False)

    def get_col(self, col_name):
        columns = self.columns
        for col in columns:
            if col_name == col.column_name:
                return col

    @property
    def data(self):
        d = super(SqlaTable, self).data
        if self.type == 'table':
            grains = self.database.grains() or []
            if grains:
                grains = [(g.duration, g.name) for g in grains]
            d['granularity_sqla'] = utils.choicify(self.dttm_cols)
            d['time_grain_sqla'] = grains
            d['main_dttm_col'] = self.main_dttm_col
        return d

    def values_for_column(self, column_name, limit=10000):
        """Runs query against sqla to retrieve some
        sample values for the given column.
        """
        cols = {col.column_name: col for col in self.columns}
        target_col = cols[column_name]
        tp = self.get_template_processor()

        qry = (
            select([target_col.get_sqla_col()])
            .select_from(self.get_from_clause(tp))
            .distinct()
        )
        if limit:
            qry = qry.limit(limit)

        if self.fetch_values_predicate:
            qry = qry.where(tp.process_template(self.fetch_values_predicate))

        engine = self.database.get_sqla_engine()
        sql = '{}'.format(
            qry.compile(engine, compile_kwargs={'literal_binds': True}),
        )
        sql = self.mutate_query_from_config(sql)

        df = pd.read_sql_query(sql=sql, con=engine)
        return [row[0] for row in df.to_records(index=False)]
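
    # Illustrative only: this backs the "filter select" dropdowns, e.g.
    #     tbl.values_for_column('gender', limit=100)
    # emits roughly SELECT DISTINCT gender FROM <table> LIMIT 100, plus the
    # optional fetch_values_predicate. The column name here is made up.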

    def mutate_query_from_config(self, sql):
        """Apply config's SQL_QUERY_MUTATOR

        Typically adds comments to the query with context"""
        SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
        if SQL_QUERY_MUTATOR:
            username = utils.get_username()
            sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database)
        return sql
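
    # Illustrative only: a minimal SQL_QUERY_MUTATOR sketch for
    # superset_config.py, matching the (sql, username, security_manager,
    # database) call signature used above:
    #
    #     def SQL_QUERY_MUTATOR(sql, username, security_manager, database):
    #         return '-- [superset] run by {}\n{}'.format(username, sql)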

    def get_template_processor(self, **kwargs):
        return get_template_processor(
            table=self, database=self.database, **kwargs)

    def get_query_str(self, query_obj):
        qry = self.get_sqla_query(**query_obj)
        sql = self.database.compile_sqla_query(qry)
        logging.info(sql)
        sql = sqlparse.format(sql, reindent=True)
        if query_obj['is_prequery']:
            query_obj['prequeries'].append(sql)
        sql = self.mutate_query_from_config(sql)
        return sql

    def get_sqla_table(self):
        tbl = table(self.table_name)
        if self.schema:
            tbl.schema = self.schema
        return tbl

    def get_from_clause(self, template_processor=None):
        # Supporting arbitrary SQL statements in place of tables
        if self.sql:
            from_sql = self.sql
            if template_processor:
                from_sql = template_processor.process_template(from_sql)
            from_sql = sqlparse.format(from_sql, strip_comments=True)
            return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
        return self.get_sqla_table()
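
    # Illustrative only: when `sql` is set (a "virtual" table), generated
    # queries select from the statement as a subquery, roughly
    #     SELECT ... FROM (<self.sql>) AS expr_qry ...
    # instead of referencing a physical table.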

    def adhoc_metric_to_sqla(self, metric, cols):
        """
        Turn an adhoc metric into a sqlalchemy column.

        :param dict metric: Adhoc metric definition
        :param dict cols: Columns for the current table
        :returns: The metric defined as a sqlalchemy column
        :rtype: sqlalchemy.sql.column
        """
        expression_type = metric.get('expressionType')
        db_engine_spec = self.database.db_engine_spec
        label = db_engine_spec.make_label_compatible(metric.get('label'))

        if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
            column_name = metric.get('column').get('column_name')
            sqla_column = column(column_name)
            table_column = cols.get(column_name)
            if table_column:
                sqla_column = table_column.get_sqla_col()
            sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
            sqla_metric = sqla_metric.label(label)
            return sqla_metric
        elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
            sqla_metric = literal_column(metric.get('sqlExpression'))
            sqla_metric = sqla_metric.label(label)
            return sqla_metric
        else:
            return None
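
    # Illustrative only: the shape of a SIMPLE adhoc metric dict as consumed
    # above (field values are made up):
    #
    #     metric = {
    #         'expressionType': 'SIMPLE',
    #         'column': {'column_name': 'num'},
    #         'aggregate': 'SUM',
    #         'label': 'Sum of num',
    #     }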

    def get_sqla_query(  # sqla
            self,
            groupby, metrics,
            granularity,
            from_dttm, to_dttm,
            filter=None,  # noqa
            is_timeseries=True,
            timeseries_limit=15,
            timeseries_limit_metric=None,
            row_limit=None,
            inner_from_dttm=None,
            inner_to_dttm=None,
            orderby=None,
            extras=None,
            columns=None,
            order_desc=True,
            prequeries=None,
            is_prequery=False,
        ):
        """Querying any sqla table from this common interface"""
        template_kwargs = {
            'from_dttm': from_dttm,
            'groupby': groupby,
            'metrics': metrics,
            'row_limit': row_limit,
            'to_dttm': to_dttm,
            'filter': filter,
            'columns': {col.column_name: col for col in self.columns},
        }
        template_kwargs.update(self.template_params_dict)
        template_processor = self.get_template_processor(**template_kwargs)
        db_engine_spec = self.database.db_engine_spec

        orderby = orderby or []

        # For backward compatibility
        if granularity not in self.dttm_cols:
            granularity = self.main_dttm_col

        # Database spec supports join-free timeslot grouping
        time_groupby_inline = db_engine_spec.time_groupby_inline

        cols = {col.column_name: col for col in self.columns}
        metrics_dict = {m.metric_name: m for m in self.metrics}

        if not granularity and is_timeseries:
            raise Exception(_(
                'Datetime column not provided as part of the table '
                'configuration and is required by this type of chart'))
        if not groupby and not metrics and not columns:
            raise Exception(_('Empty query?'))
        metrics_exprs = []
        for m in metrics:
            if utils.is_adhoc_metric(m):
                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
            elif m in metrics_dict:
                metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
            else:
                raise Exception(_("Metric '{}' is not valid").format(m))
        if metrics_exprs:
            main_metric_expr = metrics_exprs[0]
        else:
            main_metric_expr = literal_column('COUNT(*)').label(
                db_engine_spec.make_label_compatible('count'))

        select_exprs = []
        groupby_exprs = []

        if groupby:
            select_exprs = []
            inner_select_exprs = []
            inner_groupby_exprs = []
            for s in groupby:
                col = cols[s]
                outer = col.get_sqla_col()
                inner = col.get_sqla_col(col.column_name + '__')

                groupby_exprs.append(outer)
                select_exprs.append(outer)
                inner_groupby_exprs.append(inner)
                inner_select_exprs.append(inner)
        elif columns:
            for s in columns:
                select_exprs.append(cols[s].get_sqla_col())
            metrics_exprs = []

        if granularity:
            dttm_col = cols[granularity]
            time_grain = extras.get('time_grain_sqla')
            time_filters = []

            if is_timeseries:
                timestamp = dttm_col.get_timestamp_expression(time_grain)
                select_exprs += [timestamp]
                groupby_exprs += [timestamp]

            # Use main dttm column to support index with secondary dttm columns
            if db_engine_spec.time_secondary_columns and \
                    self.main_dttm_col in self.dttm_cols and \
                    self.main_dttm_col != dttm_col.column_name:
                time_filters.append(cols[self.main_dttm_col].
                                    get_time_filter(from_dttm, to_dttm))
            time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))

        select_exprs += metrics_exprs
        qry = sa.select(select_exprs)

        tbl = self.get_from_clause(template_processor)

        if not columns:
            qry = qry.group_by(*groupby_exprs)

        where_clause_and = []
        having_clause_and = []
        for flt in filter:
            if not all([flt.get(s) for s in ['col', 'op']]):
                continue
            col = flt['col']
            op = flt['op']
            col_obj = cols.get(col)
            if col_obj:
                is_list_target = op in ('in', 'not in')
                eq = self.filter_values_handler(
                    flt.get('val'),
                    target_column_is_numeric=col_obj.is_num,
                    is_list_target=is_list_target)
                if op in ('in', 'not in'):
                    cond = col_obj.get_sqla_col().in_(eq)
                    if '<NULL>' in eq:
                        cond = or_(cond, col_obj.get_sqla_col() == None)  # noqa
                    if op == 'not in':
                        cond = ~cond
                    where_clause_and.append(cond)
                else:
                    if col_obj.is_num:
                        eq = utils.string_to_num(flt['val'])
                    if op == '==':
                        where_clause_and.append(col_obj.get_sqla_col() == eq)
                    elif op == '!=':
                        where_clause_and.append(col_obj.get_sqla_col() != eq)
                    elif op == '>':
                        where_clause_and.append(col_obj.get_sqla_col() > eq)
                    elif op == '<':
                        where_clause_and.append(col_obj.get_sqla_col() < eq)
                    elif op == '>=':
                        where_clause_and.append(col_obj.get_sqla_col() >= eq)
                    elif op == '<=':
                        where_clause_and.append(col_obj.get_sqla_col() <= eq)
                    elif op == 'LIKE':
                        where_clause_and.append(
                            col_obj.get_sqla_col().like(eq))
                    elif op == 'IS NULL':
                        where_clause_and.append(
                            col_obj.get_sqla_col() == None)  # noqa
                    elif op == 'IS NOT NULL':
                        where_clause_and.append(
                            col_obj.get_sqla_col() != None)  # noqa
        if extras:
            where = extras.get('where')
            if where:
                where = template_processor.process_template(where)
                where_clause_and += [sa.text('({})'.format(where))]
            having = extras.get('having')
            if having:
                having = template_processor.process_template(having)
                having_clause_and += [sa.text('({})'.format(having))]
        if granularity:
            qry = qry.where(and_(*(time_filters + where_clause_and)))
        else:
            qry = qry.where(and_(*where_clause_and))
        qry = qry.having(and_(*having_clause_and))

        if not orderby and not columns:
            orderby = [(main_metric_expr, not order_desc)]

        for col, ascending in orderby:
            direction = asc if ascending else desc
            if utils.is_adhoc_metric(col):
                col = self.adhoc_metric_to_sqla(col, cols)
            qry = qry.order_by(direction(col))

        if row_limit:
            qry = qry.limit(row_limit)

        if is_timeseries and \
                timeseries_limit and groupby and not time_groupby_inline:
            if self.database.db_engine_spec.inner_joins:
                # some sql dialects require the order-by expressions
                # to also appear in the select clause -- others, e.g. vertica,
                # require a unique inner alias
                inner_main_metric_expr = main_metric_expr.label('mme_inner__')
                inner_select_exprs += [inner_main_metric_expr]
                subq = select(inner_select_exprs)
                subq = subq.select_from(tbl)
                inner_time_filter = dttm_col.get_time_filter(
                    inner_from_dttm or from_dttm,
                    inner_to_dttm or to_dttm,
                )
                subq = subq.where(and_(*(where_clause_and + [inner_time_filter])))
                subq = subq.group_by(*inner_groupby_exprs)

                ob = inner_main_metric_expr
                if timeseries_limit_metric:
                    if utils.is_adhoc_metric(timeseries_limit_metric):
                        ob = self.adhoc_metric_to_sqla(
                            timeseries_limit_metric, cols)
                    elif timeseries_limit_metric in metrics_dict:
                        timeseries_limit_metric = metrics_dict.get(
                            timeseries_limit_metric,
                        )
                        ob = timeseries_limit_metric.get_sqla_col()
                    else:
                        raise Exception(_(
                            "Metric '{}' is not valid").format(
                                timeseries_limit_metric))
                direction = desc if order_desc else asc
                subq = subq.order_by(direction(ob))
                subq = subq.limit(timeseries_limit)

                on_clause = []
                for i, gb in enumerate(groupby):
                    on_clause.append(
                        groupby_exprs[i] == column(gb + '__'))

                tbl = tbl.join(subq.alias(), and_(*on_clause))
            else:
                # run subquery to get top groups
                subquery_obj = {
                    'prequeries': prequeries,
                    'is_prequery': True,
                    'is_timeseries': False,
                    'row_limit': timeseries_limit,
                    'groupby': groupby,
                    'metrics': metrics,
                    'granularity': granularity,
                    'from_dttm': inner_from_dttm or from_dttm,
                    'to_dttm': inner_to_dttm or to_dttm,
                    'filter': filter,
                    'orderby': orderby,
                    'extras': extras,
                    'columns': columns,
                    'order_desc': True,
                }
                result = self.query(subquery_obj)
                cols = {col.column_name: col for col in self.columns}
                dimensions = [
                    c for c in result.df.columns
                    if c not in metrics and c in cols
                ]
                top_groups = self._get_top_groups(result.df, dimensions)
                qry = qry.where(top_groups)

        return qry.select_from(tbl)
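
    # Illustrative only: a minimal keyword set for get_sqla_query(); the
    # column and metric names are made up:
    #
    #     qry = tbl.get_sqla_query(
    #         groupby=['gender'],
    #         metrics=['count'],
    #         granularity='ds',
    #         from_dttm=datetime(2018, 1, 1),
    #         to_dttm=datetime(2018, 2, 1),
    #         filter=[],
    #         extras={},
    #         is_timeseries=False,
    #     )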

    def _get_top_groups(self, df, dimensions):
        cols = {col.column_name: col for col in self.columns}
        groups = []
        for unused, row in df.iterrows():
            group = []
            for dimension in dimensions:
                col_obj = cols.get(dimension)
                group.append(col_obj.get_sqla_col() == row[dimension])
            groups.append(and_(*group))

        return or_(*groups)
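
    # Illustrative only: for dimensions ['a', 'b'] and two prequery result
    # rows, the predicate above takes the shape
    #     (a = ... AND b = ...) OR (a = ... AND b = ...)
    # which the main query then uses to restrict itself to the top groups.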

    def query(self, query_obj):
        qry_start_dttm = datetime.now()
        sql = self.get_query_str(query_obj)
        status = QueryStatus.SUCCESS
        error_message = None
        df = None
        try:
            df = self.database.get_df(sql, self.schema)
        except Exception as e:
            status = QueryStatus.FAILED
            logging.exception(e)
            error_message = (
                self.database.db_engine_spec.extract_error_message(e))

        # if this is a main query with prequeries, combine them together
        if not query_obj['is_prequery']:
            query_obj['prequeries'].append(sql)
            sql = ';\n\n'.join(query_obj['prequeries'])
            sql += ';'

        return QueryResult(
            status=status,
            df=df,
            duration=datetime.now() - qry_start_dttm,
            query=sql,
            error_message=error_message)

    def get_sqla_table_object(self):
        return self.database.get_table(self.table_name, schema=self.schema)

    def fetch_metadata(self):
        """Fetches the metadata for the table and merges it in"""
        try:
            table = self.get_sqla_table_object()
        except Exception as e:
            logging.exception(e)
            raise Exception(_(
                "Table [{}] doesn't seem to exist in the specified database, "
                "couldn't fetch column information").format(self.table_name))

        M = SqlMetric  # noqa
        metrics = []
        any_date_col = None
        db_dialect = self.database.get_dialect()
        dbcols = (
            db.session.query(TableColumn)
            .filter(TableColumn.table == self)
            .filter(or_(TableColumn.column_name == col.name
                        for col in table.columns)))
        dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
        db_engine_spec = self.database.db_engine_spec

        for col in table.columns:
            try:
                datatype = col.type.compile(dialect=db_dialect).upper()
            except Exception as e:
                datatype = 'UNKNOWN'
                logging.error(
                    'Unrecognized data type in {}.{}'.format(table, col.name))
                logging.exception(e)
            dbcol = dbcols.get(col.name, None)
            if not dbcol:
                dbcol = TableColumn(column_name=col.name, type=datatype)
                dbcol.groupby = dbcol.is_string
                dbcol.filterable = dbcol.is_string
                dbcol.sum = dbcol.is_num
                dbcol.avg = dbcol.is_num
                dbcol.is_dttm = dbcol.is_time
            else:
                dbcol.type = datatype
            self.columns.append(dbcol)
            if not any_date_col and dbcol.is_time:
                any_date_col = col.name
            metrics += dbcol.get_metrics().values()

        metrics.append(M(
            metric_name='count',
            verbose_name='COUNT(*)',
            metric_type='count',
            expression='COUNT(*)',
        ))
        if not self.main_dttm_col:
            self.main_dttm_col = any_date_col
        for metric in metrics:
            metric.metric_name = db_engine_spec.mutate_expression_label(
                metric.metric_name)
        self.add_missing_metrics(metrics)
        db.session.merge(self)
        db.session.commit()

    @classmethod
    def import_obj(cls, i_datasource, import_time=None):
        """Imports the datasource from the object to the database.

        Metrics, columns and the datasource itself will be overridden
        if they already exist.
        This function can be used to import/export dashboards between multiple
        superset instances. Audit metadata isn't copied over.
        """
        def lookup_sqlatable(table):
            return db.session.query(SqlaTable).join(Database).filter(
                SqlaTable.table_name == table.table_name,
                SqlaTable.schema == table.schema,
                Database.id == table.database_id,
            ).first()

        def lookup_database(table):
            return db.session.query(Database).filter_by(
                database_name=table.params_dict['database_name']).one()
        return import_util.import_datasource(
            db.session, i_datasource, lookup_database, lookup_sqlatable,
            import_time)

    @classmethod
    def query_datasources_by_name(
            cls, session, database, datasource_name, schema=None):
        query = (
            session.query(cls)
            .filter_by(database_id=database.id)
            .filter_by(table_name=datasource_name)
        )
        if schema:
            query = query.filter_by(schema=schema)
        return query.all()

    @staticmethod
    def default_query(qry):
        return qry.filter_by(is_sqllab_view=False)


sa.event.listen(SqlaTable, 'after_insert', security_manager.set_perm)
sa.event.listen(SqlaTable, 'after_update', security_manager.set_perm)