12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907 |
- # pylint: disable=C,R,W
- from datetime import datetime
- import logging
- from flask import escape, Markup
- from flask_appbuilder import Model
- from flask_babel import lazy_gettext as _
- import pandas as pd
- import sqlalchemy as sa
- from sqlalchemy import (
- and_, asc, Boolean, Column, DateTime, desc, ForeignKey, Integer, or_,
- select, String, Text,
- )
- from sqlalchemy.orm import backref, relationship
- from sqlalchemy.schema import UniqueConstraint
- from sqlalchemy.sql import column, literal_column, table, text
- from sqlalchemy.sql.expression import TextAsFrom
- import sqlparse
- from superset import app, db, import_util, security_manager, utils
- from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
- from superset.jinja_context import get_template_processor
- from superset.models.annotations import Annotation
- from superset.models.core import Database
- from superset.models.helpers import QueryResult
- from superset.utils import DTTM_ALIAS, QueryStatus
- config = app.config
class AnnotationDatasource(BaseDatasource):

    """Dummy object so we can query annotations using 'Viz' objects just like
    regular datasources.
    """

    cache_timeout = 0

    def query(self, query_obj):
        """Fetch annotations for the requested layer/time range as a QueryResult."""
        result_df = None
        error_message = None
        # The layer id is smuggled in as the first filter's value.
        annotations = (
            db.session.query(Annotation)
            .filter(Annotation.layer_id == query_obj['filter'][0]['val'])
        )
        if query_obj['from_dttm']:
            annotations = annotations.filter(
                Annotation.start_dttm >= query_obj['from_dttm'])
        if query_obj['to_dttm']:
            annotations = annotations.filter(
                Annotation.end_dttm <= query_obj['to_dttm'])
        status = QueryStatus.SUCCESS
        try:
            result_df = pd.read_sql_query(annotations.statement, db.engine)
        except Exception as e:
            status = QueryStatus.FAILED
            logging.exception(e)
            error_message = utils.error_msg_from_exception(e)
        return QueryResult(
            status=status,
            df=result_df,
            duration=0,
            query='',
            error_message=error_message)

    def get_query_str(self, query_obj):
        # Annotations are queried through the ORM, not raw SQL.
        raise NotImplementedError()

    def values_for_column(self, column_name, limit=10000):
        # Column value sampling is meaningless for annotation layers.
        raise NotImplementedError()
class TableColumn(Model, BaseColumn):

    """ORM object for table columns, each table can have multiple columns"""

    __tablename__ = 'table_columns'
    __table_args__ = (UniqueConstraint('table_id', 'column_name'),)
    table_id = Column(Integer, ForeignKey('tables.id'))
    table = relationship(
        'SqlaTable',
        backref=backref('columns', cascade='all, delete-orphan'),
        foreign_keys=[table_id])
    is_dttm = Column(Boolean, default=False)
    expression = Column(Text, default='')
    python_date_format = Column(String(255))
    database_expression = Column(String(255))

    export_fields = (
        'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active',
        'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min',
        'filterable', 'expression', 'description', 'python_date_format',
        'database_expression',
    )

    update_from_object_fields = [
        s for s in export_fields if s not in ('table_id',)]
    export_parent = 'table'

    def get_sqla_col(self, label=None):
        """Return this column as a labeled SQLA expression.

        Uses ``expression`` (free-form SQL) when set, otherwise the plain
        column name; the label is made compatible with the target engine.
        """
        db_engine_spec = self.table.database.db_engine_spec
        label = db_engine_spec.make_label_compatible(
            label if label else self.column_name)
        if not self.expression:
            col = column(self.column_name).label(label)
        else:
            col = literal_column(self.expression).label(label)
        return col

    @property
    def datasource(self):
        return self.table

    def get_time_filter(self, start_dttm, end_dttm):
        """Build an AND-ed time range predicate over this datetime column.

        Either bound may be falsy, in which case that side is left open.
        """
        col = self.get_sqla_col(label='__time')
        # Renamed from single-letter `l` (E741 ambiguous name).
        conditions = []
        if start_dttm:
            conditions.append(col >= text(self.dttm_sql_literal(start_dttm)))
        if end_dttm:
            conditions.append(col <= text(self.dttm_sql_literal(end_dttm)))
        return and_(*conditions)

    def get_timestamp_expression(self, time_grain):
        """Getting the time component of the query"""
        pdf = self.python_date_format
        is_epoch = pdf in ('epoch_s', 'epoch_ms')
        if not self.expression and not time_grain and not is_epoch:
            return column(self.column_name, type_=DateTime).label(DTTM_ALIAS)
        expr = self.expression or self.column_name
        if is_epoch:
            # if epoch, translate to DATE using db specific conf
            db_spec = self.table.database.db_engine_spec
            if pdf == 'epoch_s':
                expr = db_spec.epoch_to_dttm().format(col=expr)
            elif pdf == 'epoch_ms':
                expr = db_spec.epoch_ms_to_dttm().format(col=expr)
        if time_grain:
            grain = self.table.database.grains_dict().get(time_grain)
            if grain:
                expr = grain.function.format(col=expr)
        return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)

    @classmethod
    def import_obj(cls, i_column):
        """Import a column, reusing an existing row when one matches."""
        def lookup_obj(lookup_column):
            return db.session.query(TableColumn).filter(
                TableColumn.table_id == lookup_column.table_id,
                TableColumn.column_name == lookup_column.column_name).first()
        return import_util.import_simple_obj(db.session, i_column, lookup_obj)

    def dttm_sql_literal(self, dttm):
        """Convert datetime object to a SQL expression string

        If database_expression is empty, the internal dttm
        will be parsed as the string with the pattern that
        the user inputted (python_date_format)
        If database_expression is not empty, the internal dttm
        will be parsed as the sql sentence for the database to convert
        """
        tf = self.python_date_format
        if self.database_expression:
            return self.database_expression.format(
                dttm.strftime('%Y-%m-%d %H:%M:%S'))
        elif tf:
            if tf == 'epoch_s':
                return str((dttm - datetime(1970, 1, 1)).total_seconds())
            elif tf == 'epoch_ms':
                return str(
                    (dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0)
            return "'{}'".format(dttm.strftime(tf))
        else:
            # Fall back to the engine-specific conversion, then to a plain
            # quoted ISO-ish literal.
            s = self.table.database.db_engine_spec.convert_dttm(
                self.type or '', dttm)
            return s or "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S.%f'))

    def get_metrics(self):
        """Derive SqlMetric objects from this column's aggregation flags.

        Returns a dict keyed by metric_name.
        """
        # TODO deprecate, this is not needed since MetricsControl
        quoted = self.column_name
        # (flag attribute, metric type, SQL template) — one metric is emitted
        # per flag set on this column; table-driven to avoid five copy-pasted
        # branches.
        aggregates = (
            ('sum', 'sum', 'SUM({})'),
            ('avg', 'avg', 'AVG({})'),
            ('max', 'max', 'MAX({})'),
            ('min', 'min', 'MIN({})'),
            ('count_distinct', 'count_distinct', 'COUNT(DISTINCT {})'),
        )
        metrics = [
            SqlMetric(
                metric_name=metric_type + '__' + self.column_name,
                metric_type=metric_type,
                expression=template.format(quoted),
            )
            for flag, metric_type, template in aggregates
            if getattr(self, flag)
        ]
        return {m.metric_name: m for m in metrics}
class SqlMetric(Model, BaseMetric):

    """ORM object for metrics, each table can have multiple metrics"""

    __tablename__ = 'sql_metrics'
    __table_args__ = (UniqueConstraint('table_id', 'metric_name'),)
    table_id = Column(Integer, ForeignKey('tables.id'))
    table = relationship(
        'SqlaTable',
        backref=backref('metrics', cascade='all, delete-orphan'),
        foreign_keys=[table_id])
    expression = Column(Text)

    export_fields = (
        'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression',
        'description', 'is_restricted', 'd3format', 'warning_text')
    # A comprehension is already a list; the redundant list() wrapper (C411)
    # was removed.
    update_from_object_fields = [
        s for s in export_fields if s not in ('table_id', )]
    export_parent = 'table'

    def get_sqla_col(self, label=None):
        """Return this metric's SQL expression as a labeled SQLA column."""
        db_engine_spec = self.table.database.db_engine_spec
        label = db_engine_spec.make_label_compatible(
            label if label else self.metric_name)
        return literal_column(self.expression).label(label)

    @property
    def perm(self):
        # Permission string, e.g. "db.table.[metric](id:42)"; None when the
        # metric is not attached to a table yet.
        return (
            '{parent_name}.[{obj.metric_name}](id:{obj.id})'
        ).format(obj=self,
                 parent_name=self.table.full_name) if self.table else None

    @classmethod
    def import_obj(cls, i_metric):
        """Import a metric, reusing an existing row when one matches."""
        def lookup_obj(lookup_metric):
            return db.session.query(SqlMetric).filter(
                SqlMetric.table_id == lookup_metric.table_id,
                SqlMetric.metric_name == lookup_metric.metric_name).first()
        return import_util.import_simple_obj(db.session, i_metric, lookup_obj)
class SqlaTable(Model, BaseDatasource):

    """An ORM object for SqlAlchemy table references"""

    type = 'table'
    query_language = 'sql'
    metric_class = SqlMetric
    column_class = TableColumn

    __tablename__ = 'tables'
    __table_args__ = (UniqueConstraint('database_id', 'table_name'),)

    table_name = Column(String(250))
    # Name of the main datetime column, used as the default time axis.
    main_dttm_col = Column(String(250))
    database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
    # Optional SQL predicate applied when fetching distinct filter values.
    fetch_values_predicate = Column(String(1000))
    user_id = Column(Integer, ForeignKey('ab_user.id'))
    owner = relationship(
        security_manager.user_model,
        backref='tables',
        foreign_keys=[user_id])
    database = relationship(
        'Database',
        backref=backref('tables', cascade='all, delete-orphan'),
        foreign_keys=[database_id])
    schema = Column(String(255))
    # Arbitrary SQL used in place of a physical table (virtual dataset).
    sql = Column(Text)
    is_sqllab_view = Column(Boolean, default=False)
    template_params = Column(Text)

    baselink = 'tablemodelview'

    export_fields = (
        'table_name', 'main_dttm_col', 'description', 'default_endpoint',
        'database_id', 'offset', 'cache_timeout', 'schema',
        'sql', 'params', 'template_params', 'filter_select_enabled')
    update_from_object_fields = [
        f for f in export_fields if f not in ('table_name', 'database_id')]
    export_parent = 'database'
    export_children = ['metrics', 'columns']

    # Aggregate name -> SQLA aggregate factory, used by adhoc SIMPLE metrics.
    sqla_aggregations = {
        'COUNT_DISTINCT': lambda column_name: sa.func.COUNT(sa.distinct(column_name)),
        'COUNT': sa.func.COUNT,
        'SUM': sa.func.SUM,
        'AVG': sa.func.AVG,
        'MIN': sa.func.MIN,
        'MAX': sa.func.MAX,
    }

    def __repr__(self):
        return self.name

    @property
    def connection(self):
        return str(self.database)

    @property
    def description_markeddown(self):
        return utils.markdown(self.description)

    @property
    def datasource_name(self):
        return self.table_name

    @property
    def database_name(self):
        return self.database.name

    @property
    def link(self):
        # HTML anchor to the explore view; name is escaped before templating.
        name = escape(self.name)
        anchor = '<a target="_blank" href="{self.explore_url}">{name}</a>'
        return Markup(anchor.format(**locals()))

    @property
    def schema_perm(self):
        """Returns schema permission if present, database one otherwise."""
        return security_manager.get_schema_perm(self.database, self.schema)

    def get_perm(self):
        """Permission string identifying this datasource."""
        return (
            '[{obj.database}].[{obj.table_name}]'
            '(id:{obj.id})').format(obj=self)

    @property
    def name(self):
        # "schema.table" when a schema is set, else just the table name.
        if not self.schema:
            return self.table_name
        return '{}.{}'.format(self.schema, self.table_name)

    @property
    def full_name(self):
        return utils.get_datasource_full_name(
            self.database, self.table_name, schema=self.schema)

    @property
    def dttm_cols(self):
        # All datetime column names; main_dttm_col is appended if missing.
        l = [c.column_name for c in self.columns if c.is_dttm]  # noqa: E741
        if self.main_dttm_col and self.main_dttm_col not in l:
            l.append(self.main_dttm_col)
        return l

    @property
    def num_cols(self):
        return [c.column_name for c in self.columns if c.is_num]

    @property
    def any_dttm_col(self):
        # First datetime column, or None (implicit) when there is none.
        cols = self.dttm_cols
        if cols:
            return cols[0]

    @property
    def html(self):
        # Render the column list as an HTML table (field, type).
        t = ((c.column_name, c.type) for c in self.columns)
        df = pd.DataFrame(t)
        df.columns = ['field', 'type']
        return df.to_html(
            index=False,
            classes=(
                'dataframe table table-striped table-bordered '
                'table-condensed'))

    @property
    def sql_url(self):
        return self.database.sql_url + '?table_name=' + str(self.table_name)

    def external_metadata(self):
        """Column metadata as reported by the database itself."""
        cols = self.database.get_columns(self.table_name, schema=self.schema)
        for col in cols:
            col['type'] = '{}'.format(col['type'])
        return cols

    @property
    def time_column_grains(self):
        return {
            'time_columns': self.dttm_cols,
            'time_grains': [grain.name for grain in self.database.grains()],
        }

    @property
    def select_star(self):
        # show_cols and latest_partition set to false to avoid
        # the expensive cost of inspecting the DB
        return self.database.select_star(
            self.name, show_cols=False, latest_partition=False)

    def get_col(self, col_name):
        """Return the TableColumn named col_name, or None (implicit)."""
        columns = self.columns
        for col in columns:
            if col_name == col.column_name:
                return col

    @property
    def data(self):
        """Datasource payload for the frontend, with time-grain extras."""
        d = super(SqlaTable, self).data
        if self.type == 'table':
            grains = self.database.grains() or []
            if grains:
                grains = [(g.duration, g.name) for g in grains]
            d['granularity_sqla'] = utils.choicify(self.dttm_cols)
            d['time_grain_sqla'] = grains
            d['main_dttm_col'] = self.main_dttm_col
        return d

    def values_for_column(self, column_name, limit=10000):
        """Runs query against sqla to retrieve some
        sample values for the given column.
        """
        cols = {col.column_name: col for col in self.columns}
        target_col = cols[column_name]
        tp = self.get_template_processor()

        qry = (
            select([target_col.get_sqla_col()])
            .select_from(self.get_from_clause(tp))
            .distinct()
        )
        if limit:
            qry = qry.limit(limit)

        if self.fetch_values_predicate:
            tp = self.get_template_processor()
            qry = qry.where(tp.process_template(self.fetch_values_predicate))

        engine = self.database.get_sqla_engine()
        # literal_binds inlines values so the SQL is self-contained.
        sql = '{}'.format(
            qry.compile(engine, compile_kwargs={'literal_binds': True}),
        )
        sql = self.mutate_query_from_config(sql)

        df = pd.read_sql_query(sql=sql, con=engine)
        return [row[0] for row in df.to_records(index=False)]

    def mutate_query_from_config(self, sql):
        """Apply config's SQL_QUERY_MUTATOR

        Typically adds comments to the query with context"""
        SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
        if SQL_QUERY_MUTATOR:
            username = utils.get_username()
            sql = SQL_QUERY_MUTATOR(sql, username, security_manager, self.database)
        return sql

    def get_template_processor(self, **kwargs):
        return get_template_processor(
            table=self, database=self.database, **kwargs)

    def get_query_str(self, query_obj):
        """Compile the query object into a SQL string (mutated per config)."""
        qry = self.get_sqla_query(**query_obj)
        sql = self.database.compile_sqla_query(qry)
        logging.info(sql)
        sql = sqlparse.format(sql, reindent=True)
        if query_obj['is_prequery']:
            # Side channel: prequery SQL is collected on the query object.
            query_obj['prequeries'].append(sql)
        sql = self.mutate_query_from_config(sql)
        return sql

    def get_sqla_table(self):
        """SQLA table construct for this physical table (schema-qualified)."""
        tbl = table(self.table_name)
        if self.schema:
            tbl.schema = self.schema
        return tbl

    def get_from_clause(self, template_processor=None):
        # Supporting arbitrary SQL statements in place of tables
        if self.sql:
            from_sql = self.sql
            if template_processor:
                from_sql = template_processor.process_template(from_sql)
            from_sql = sqlparse.format(from_sql, strip_comments=True)
            return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
        return self.get_sqla_table()

    def adhoc_metric_to_sqla(self, metric, cols):
        """
        Turn an adhoc metric into a sqlalchemy column.

        :param dict metric: Adhoc metric definition
        :param dict cols: Columns for the current table
        :returns: The metric defined as a sqlalchemy column
        :rtype: sqlalchemy.sql.column
        """
        expression_type = metric.get('expressionType')
        db_engine_spec = self.database.db_engine_spec
        label = db_engine_spec.make_label_compatible(metric.get('label'))

        if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
            column_name = metric.get('column').get('column_name')
            sqla_column = column(column_name)
            table_column = cols.get(column_name)

            if table_column:
                # Prefer the registered column's expression over the raw name.
                sqla_column = table_column.get_sqla_col()

            sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
            sqla_metric = sqla_metric.label(label)
            return sqla_metric
        elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
            sqla_metric = literal_column(metric.get('sqlExpression'))
            sqla_metric = sqla_metric.label(label)
            return sqla_metric
        else:
            return None

    def get_sqla_query(  # sqla
            self,
            groupby, metrics,
            granularity,
            from_dttm, to_dttm,
            filter=None,  # noqa
            is_timeseries=True,
            timeseries_limit=15,
            timeseries_limit_metric=None,
            row_limit=None,
            inner_from_dttm=None,
            inner_to_dttm=None,
            orderby=None,
            extras=None,
            columns=None,
            order_desc=True,
            prequeries=None,
            is_prequery=False,
    ):
        """Querying any sqla table from this common interface"""
        template_kwargs = {
            'from_dttm': from_dttm,
            'groupby': groupby,
            'metrics': metrics,
            'row_limit': row_limit,
            'to_dttm': to_dttm,
            'filter': filter,
            'columns': {col.column_name: col for col in self.columns},
        }
        template_kwargs.update(self.template_params_dict)
        template_processor = self.get_template_processor(**template_kwargs)
        db_engine_spec = self.database.db_engine_spec

        orderby = orderby or []

        # For backward compatibility
        if granularity not in self.dttm_cols:
            granularity = self.main_dttm_col

        # Database spec supports join-free timeslot grouping
        time_groupby_inline = db_engine_spec.time_groupby_inline

        cols = {col.column_name: col for col in self.columns}
        metrics_dict = {m.metric_name: m for m in self.metrics}

        if not granularity and is_timeseries:
            raise Exception(_(
                'Datetime column not provided as part table configuration '
                'and is required by this type of chart'))
        if not groupby and not metrics and not columns:
            raise Exception(_('Empty query?'))
        metrics_exprs = []
        for m in metrics:
            if utils.is_adhoc_metric(m):
                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
            elif m in metrics_dict:
                metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
            else:
                raise Exception(_("Metric '{}' is not valid".format(m)))
        if metrics_exprs:
            main_metric_expr = metrics_exprs[0]
        else:
            # No metrics requested: fall back to COUNT(*).
            main_metric_expr = literal_column('COUNT(*)').label(
                db_engine_spec.make_label_compatible('count'))

        select_exprs = []
        groupby_exprs = []

        if groupby:
            select_exprs = []
            inner_select_exprs = []
            inner_groupby_exprs = []
            for s in groupby:
                col = cols[s]
                outer = col.get_sqla_col()
                # Inner aliases get a "__" suffix so the outer query can join
                # back on them.
                inner = col.get_sqla_col(col.column_name + '__')

                groupby_exprs.append(outer)
                select_exprs.append(outer)
                inner_groupby_exprs.append(inner)
                inner_select_exprs.append(inner)
        elif columns:
            for s in columns:
                select_exprs.append(cols[s].get_sqla_col())
            metrics_exprs = []

        if granularity:
            dttm_col = cols[granularity]
            time_grain = extras.get('time_grain_sqla')
            time_filters = []

            if is_timeseries:
                timestamp = dttm_col.get_timestamp_expression(time_grain)
                select_exprs += [timestamp]
                groupby_exprs += [timestamp]

            # Use main dttm column to support index with secondary dttm columns
            if db_engine_spec.time_secondary_columns and \
                    self.main_dttm_col in self.dttm_cols and \
                    self.main_dttm_col != dttm_col.column_name:
                time_filters.append(cols[self.main_dttm_col].
                                    get_time_filter(from_dttm, to_dttm))
            time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))

        select_exprs += metrics_exprs
        qry = sa.select(select_exprs)

        tbl = self.get_from_clause(template_processor)

        if not columns:
            qry = qry.group_by(*groupby_exprs)

        where_clause_and = []
        having_clause_and = []
        for flt in filter:
            # Skip malformed filters missing a column or operator.
            if not all([flt.get(s) for s in ['col', 'op']]):
                continue
            col = flt['col']
            op = flt['op']
            col_obj = cols.get(col)
            if col_obj:
                is_list_target = op in ('in', 'not in')
                eq = self.filter_values_handler(
                    flt.get('val'),
                    target_column_is_numeric=col_obj.is_num,
                    is_list_target=is_list_target)
                if op in ('in', 'not in'):
                    cond = col_obj.get_sqla_col().in_(eq)
                    if '<NULL>' in eq:
                        # '<NULL>' is the UI's marker for SQL NULL membership.
                        cond = or_(cond, col_obj.get_sqla_col() == None)  # noqa
                    if op == 'not in':
                        cond = ~cond
                    where_clause_and.append(cond)
                else:
                    if col_obj.is_num:
                        eq = utils.string_to_num(flt['val'])
                    if op == '==':
                        where_clause_and.append(col_obj.get_sqla_col() == eq)
                    elif op == '!=':
                        where_clause_and.append(col_obj.get_sqla_col() != eq)
                    elif op == '>':
                        where_clause_and.append(col_obj.get_sqla_col() > eq)
                    elif op == '<':
                        where_clause_and.append(col_obj.get_sqla_col() < eq)
                    elif op == '>=':
                        where_clause_and.append(col_obj.get_sqla_col() >= eq)
                    elif op == '<=':
                        where_clause_and.append(col_obj.get_sqla_col() <= eq)
                    elif op == 'LIKE':
                        where_clause_and.append(col_obj.get_sqla_col().like(eq))
                    elif op == 'IS NULL':
                        where_clause_and.append(col_obj.get_sqla_col() == None)  # noqa
                    elif op == 'IS NOT NULL':
                        where_clause_and.append(
                            col_obj.get_sqla_col() != None)  # noqa
        if extras:
            # Free-form WHERE/HAVING snippets go through the template engine.
            where = extras.get('where')
            if where:
                where = template_processor.process_template(where)
                where_clause_and += [sa.text('({})'.format(where))]
            having = extras.get('having')
            if having:
                having = template_processor.process_template(having)
                having_clause_and += [sa.text('({})'.format(having))]
        if granularity:
            qry = qry.where(and_(*(time_filters + where_clause_and)))
        else:
            qry = qry.where(and_(*where_clause_and))
        qry = qry.having(and_(*having_clause_and))

        if not orderby and not columns:
            orderby = [(main_metric_expr, not order_desc)]

        for col, ascending in orderby:
            direction = asc if ascending else desc
            if utils.is_adhoc_metric(col):
                col = self.adhoc_metric_to_sqla(col, cols)
            qry = qry.order_by(direction(col))

        if row_limit:
            qry = qry.limit(row_limit)

        if is_timeseries and \
                timeseries_limit and groupby and not time_groupby_inline:
            if self.database.db_engine_spec.inner_joins:
                # some sql dialects require for order by expressions
                # to also be in the select clause -- others, e.g. vertica,
                # require a unique inner alias
                inner_main_metric_expr = main_metric_expr.label('mme_inner__')
                inner_select_exprs += [inner_main_metric_expr]
                subq = select(inner_select_exprs)
                subq = subq.select_from(tbl)
                inner_time_filter = dttm_col.get_time_filter(
                    inner_from_dttm or from_dttm,
                    inner_to_dttm or to_dttm,
                )
                subq = subq.where(and_(*(where_clause_and + [inner_time_filter])))
                subq = subq.group_by(*inner_groupby_exprs)

                ob = inner_main_metric_expr
                if timeseries_limit_metric:
                    if utils.is_adhoc_metric(timeseries_limit_metric):
                        ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
                    elif timeseries_limit_metric in metrics_dict:
                        timeseries_limit_metric = metrics_dict.get(
                            timeseries_limit_metric,
                        )
                        ob = timeseries_limit_metric.get_sqla_col()
                    else:
                        # NOTE(review): `m` here is the leftover loop variable
                        # from the metrics loop above, not
                        # timeseries_limit_metric — the error message likely
                        # names the wrong metric; confirm before relying on it.
                        raise Exception(_("Metric '{}' is not valid".format(m)))
                direction = desc if order_desc else asc
                subq = subq.order_by(direction(ob))
                subq = subq.limit(timeseries_limit)

                # Join back on the "__"-suffixed inner aliases to keep only
                # the top groups.
                on_clause = []
                for i, gb in enumerate(groupby):
                    on_clause.append(
                        groupby_exprs[i] == column(gb + '__'))

                tbl = tbl.join(subq.alias(), and_(*on_clause))
            else:
                # run subquery to get top groups
                subquery_obj = {
                    'prequeries': prequeries,
                    'is_prequery': True,
                    'is_timeseries': False,
                    'row_limit': timeseries_limit,
                    'groupby': groupby,
                    'metrics': metrics,
                    'granularity': granularity,
                    'from_dttm': inner_from_dttm or from_dttm,
                    'to_dttm': inner_to_dttm or to_dttm,
                    'filter': filter,
                    'orderby': orderby,
                    'extras': extras,
                    'columns': columns,
                    'order_desc': True,
                }
                result = self.query(subquery_obj)
                cols = {col.column_name: col for col in self.columns}
                dimensions = [
                    c for c in result.df.columns
                    if c not in metrics and c in cols
                ]
                top_groups = self._get_top_groups(result.df, dimensions)
                qry = qry.where(top_groups)

        return qry.select_from(tbl)

    def _get_top_groups(self, df, dimensions):
        """OR-of-ANDs predicate matching the dimension tuples found in df."""
        cols = {col.column_name: col for col in self.columns}
        groups = []
        for unused, row in df.iterrows():
            group = []
            for dimension in dimensions:
                col_obj = cols.get(dimension)
                group.append(col_obj.get_sqla_col() == row[dimension])
            groups.append(and_(*group))

        return or_(*groups)

    def query(self, query_obj):
        """Compile, run, and wrap the query's outcome in a QueryResult."""
        qry_start_dttm = datetime.now()
        sql = self.get_query_str(query_obj)
        status = QueryStatus.SUCCESS
        error_message = None
        df = None
        try:
            df = self.database.get_df(sql, self.schema)
        except Exception as e:
            status = QueryStatus.FAILED
            logging.exception(e)
            error_message = (
                self.database.db_engine_spec.extract_error_message(e))

        # if this is a main query with prequeries, combine them together
        if not query_obj['is_prequery']:
            query_obj['prequeries'].append(sql)
            sql = ';\n\n'.join(query_obj['prequeries'])
        sql += ';'

        return QueryResult(
            status=status,
            df=df,
            duration=datetime.now() - qry_start_dttm,
            query=sql,
            error_message=error_message)

    def get_sqla_table_object(self):
        return self.database.get_table(self.table_name, schema=self.schema)

    def fetch_metadata(self):
        """Fetches the metadata for the table and merges it in"""
        try:
            table = self.get_sqla_table_object()
        except Exception as e:
            logging.exception(e)
            raise Exception(_(
                "Table [{}] doesn't seem to exist in the specified database, "
                "couldn't fetch column information").format(self.table_name))

        M = SqlMetric  # noqa
        metrics = []
        any_date_col = None
        db_dialect = self.database.get_dialect()
        # NOTE(review): a generator is passed to or_() here rather than
        # unpacked arguments — verify this builds the intended OR clause on
        # the SQLAlchemy version in use.
        dbcols = (
            db.session.query(TableColumn)
            .filter(TableColumn.table == self)
            .filter(or_(TableColumn.column_name == col.name
                        for col in table.columns)))
        dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
        db_engine_spec = self.database.db_engine_spec

        for col in table.columns:
            try:
                datatype = col.type.compile(dialect=db_dialect).upper()
            except Exception as e:
                datatype = 'UNKNOWN'
                logging.error(
                    'Unrecognized data type in {}.{}'.format(table, col.name))
                logging.exception(e)
            dbcol = dbcols.get(col.name, None)
            if not dbcol:
                # New column: seed sensible defaults from its inferred kind.
                dbcol = TableColumn(column_name=col.name, type=datatype)
                dbcol.groupby = dbcol.is_string
                dbcol.filterable = dbcol.is_string
                dbcol.sum = dbcol.is_num
                dbcol.avg = dbcol.is_num
                dbcol.is_dttm = dbcol.is_time
            else:
                dbcol.type = datatype
            self.columns.append(dbcol)
            if not any_date_col and dbcol.is_time:
                any_date_col = col.name
            metrics += dbcol.get_metrics().values()

        metrics.append(M(
            metric_name='count',
            verbose_name='COUNT(*)',
            metric_type='count',
            expression='COUNT(*)',
        ))
        if not self.main_dttm_col:
            self.main_dttm_col = any_date_col
        for metric in metrics:
            metric.metric_name = db_engine_spec.mutate_expression_label(
                metric.metric_name)
        self.add_missing_metrics(metrics)
        db.session.merge(self)
        db.session.commit()

    @classmethod
    def import_obj(cls, i_datasource, import_time=None):
        """Imports the datasource from the object to the database.

        Metrics and columns and datasource will be overrided if exists.
        This function can be used to import/export dashboards between multiple
        superset instances. Audit metadata isn't copies over.
        """
        def lookup_sqlatable(table):
            return db.session.query(SqlaTable).join(Database).filter(
                SqlaTable.table_name == table.table_name,
                SqlaTable.schema == table.schema,
                Database.id == table.database_id,
            ).first()

        def lookup_database(table):
            return db.session.query(Database).filter_by(
                database_name=table.params_dict['database_name']).one()
        return import_util.import_datasource(
            db.session, i_datasource, lookup_database, lookup_sqlatable,
            import_time)

    @classmethod
    def query_datasources_by_name(
            cls, session, database, datasource_name, schema=None):
        """All datasources in `database` named `datasource_name` (any schema
        unless one is given)."""
        query = (
            session.query(cls)
            .filter_by(database_id=database.id)
            .filter_by(table_name=datasource_name)
        )
        if schema:
            query = query.filter_by(schema=schema)
        return query.all()

    @staticmethod
    def default_query(qry):
        # Hide SQL Lab views from default listings.
        return qry.filter_by(is_sqllab_view=False)
- sa.event.listen(SqlaTable, 'after_insert', security_manager.set_perm)
- sa.event.listen(SqlaTable, 'after_update', security_manager.set_perm)
|