# pylint: disable=C,R,W # pylint: disable=invalid-unary-operand-type from collections import OrderedDict from copy import deepcopy from datetime import datetime, timedelta from distutils.version import LooseVersion import json import logging from multiprocessing.pool import ThreadPool import re from dateutil.parser import parse as dparse from flask import escape, Markup from flask_appbuilder import Model from flask_appbuilder.models.decorators import renders from flask_babel import lazy_gettext as _ import pandas from pydruid.client import PyDruid from pydruid.utils.aggregators import count from pydruid.utils.dimensions import MapLookupExtraction, RegexExtraction from pydruid.utils.filters import Dimension, Filter from pydruid.utils.having import Aggregation from pydruid.utils.postaggregator import ( Const, Field, HyperUniqueCardinality, Postaggregator, Quantile, Quantiles, ) import requests import sqlalchemy as sa from sqlalchemy import ( Boolean, Column, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, ) from sqlalchemy.orm import backref, relationship from superset import conf, db, import_util, security_manager, utils from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.exceptions import MetricPermException, SupersetException from superset.models.helpers import ( AuditMixinNullable, ImportMixin, QueryResult, ) from superset.utils import ( DimSelector, DTTM_ALIAS, flasher, ) DRUID_TZ = conf.get('DRUID_TZ') POST_AGG_TYPE = 'postagg' # Function wrapper because bound methods cannot # be passed to processes def _fetch_metadata_for(datasource): return datasource.latest_metadata() class JavascriptPostAggregator(Postaggregator): def __init__(self, name, field_names, function): self.post_aggregator = { 'type': 'javascript', 'fieldNames': field_names, 'name': name, 'function': function, } self.name = name class CustomPostAggregator(Postaggregator): """A way to allow users to specify completely custom PostAggregators""" def __init__(self, name, post_aggregator): self.name = name self.post_aggregator = post_aggregator class DruidCluster(Model, AuditMixinNullable, ImportMixin): """ORM object referencing the Druid clusters""" __tablename__ = 'clusters' type = 'druid' id = Column(Integer, primary_key=True) verbose_name = Column(String(250), unique=True) # short unique name, used in permissions cluster_name = Column(String(250), unique=True) coordinator_host = Column(String(255)) coordinator_port = Column(Integer, default=8081) coordinator_endpoint = Column( String(255), default='druid/coordinator/v1/metadata') broker_host = Column(String(255)) broker_port = Column(Integer, default=8082) broker_endpoint = Column(String(255), default='druid/v2') metadata_last_refreshed = Column(DateTime) cache_timeout = Column(Integer) export_fields = ('cluster_name', 'coordinator_host', 'coordinator_port', 'coordinator_endpoint', 'broker_host', 'broker_port', 'broker_endpoint', 'cache_timeout') update_from_object_fields = export_fields export_children = ['datasources'] def __repr__(self): return self.verbose_name if self.verbose_name else self.cluster_name def __html__(self): return self.__repr__() @property def data(self): return { 'id': self.id, 'name': self.cluster_name, 'backend': 'druid', } @staticmethod def get_base_url(host, port): if not re.match('http(s)?://', host): host = 'http://' + host url = '{0}:{1}'.format(host, port) if port else host return url def get_base_broker_url(self): base_url = self.get_base_url( self.broker_host, self.broker_port) return '{base_url}/{self.broker_endpoint}'.format(**locals()) def get_pydruid_client(self): cli = PyDruid( self.get_base_url(self.broker_host, self.broker_port), self.broker_endpoint) return cli def get_datasources(self): endpoint = self.get_base_broker_url() + '/datasources' return json.loads(requests.get(endpoint).text) def get_druid_version(self): endpoint = self.get_base_url( self.coordinator_host, self.coordinator_port) + '/status' return json.loads(requests.get(endpoint).text)['version'] @property @utils.memoized def druid_version(self): return self.get_druid_version() def refresh_datasources( self, datasource_name=None, merge_flag=True, refreshAll=True): """Refresh metadata of all datasources in the cluster If ``datasource_name`` is specified, only that datasource is updated """ ds_list = self.get_datasources() blacklist = conf.get('DRUID_DATA_SOURCE_BLACKLIST', []) ds_refresh = [] if not datasource_name: ds_refresh = list(filter(lambda ds: ds not in blacklist, ds_list)) elif datasource_name not in blacklist and datasource_name in ds_list: ds_refresh.append(datasource_name) else: return self.refresh(ds_refresh, merge_flag, refreshAll) def refresh(self, datasource_names, merge_flag, refreshAll): """ Fetches metadata for the specified datasources and merges to the Superset database """ session = db.session ds_list = ( session.query(DruidDatasource) .filter(DruidDatasource.cluster_name == self.cluster_name) .filter(DruidDatasource.datasource_name.in_(datasource_names)) ) ds_map = {ds.name: ds for ds in ds_list} for ds_name in datasource_names: datasource = ds_map.get(ds_name, None) if not datasource: datasource = DruidDatasource(datasource_name=ds_name) with session.no_autoflush: session.add(datasource) flasher( _('Adding new datasource [{}]').format(ds_name), 'success') ds_map[ds_name] = datasource elif refreshAll: flasher( _('Refreshing datasource [{}]').format(ds_name), 'info') else: del ds_map[ds_name] continue datasource.cluster = self datasource.merge_flag = merge_flag session.flush() # Prepare multithreaded executation pool = ThreadPool() ds_refresh = list(ds_map.values()) metadata = pool.map(_fetch_metadata_for, ds_refresh) pool.close() pool.join() for i in range(0, len(ds_refresh)): datasource = ds_refresh[i] cols = metadata[i] if cols: col_objs_list = ( session.query(DruidColumn) .filter(DruidColumn.datasource_id == datasource.id) .filter(DruidColumn.column_name.in_(cols.keys())) ) col_objs = {col.column_name: col for col in col_objs_list} for col in cols: if col == '__time': # skip the time column continue col_obj = col_objs.get(col) if not col_obj: col_obj = DruidColumn( datasource_id=datasource.id, column_name=col) with session.no_autoflush: session.add(col_obj) col_obj.type = cols[col]['type'] col_obj.datasource = datasource if col_obj.type == 'STRING': col_obj.groupby = True col_obj.filterable = True if col_obj.type == 'hyperUnique' or col_obj.type == 'thetaSketch': col_obj.count_distinct = True if col_obj.is_num: col_obj.sum = True col_obj.min = True col_obj.max = True datasource.refresh_metrics() session.commit() @property def perm(self): return '[{obj.cluster_name}].(id:{obj.id})'.format(obj=self) def get_perm(self): return self.perm @property def name(self): return self.verbose_name if self.verbose_name else self.cluster_name @property def unique_name(self): return self.verbose_name if self.verbose_name else self.cluster_name class DruidColumn(Model, BaseColumn): """ORM model for storing Druid datasource column metadata""" __tablename__ = 'columns' __table_args__ = (UniqueConstraint('column_name', 'datasource_id'),) datasource_id = Column( Integer, ForeignKey('datasources.id')) # Setting enable_typechecks=False disables polymorphic inheritance. datasource = relationship( 'DruidDatasource', backref=backref('columns', cascade='all, delete-orphan'), enable_typechecks=False) dimension_spec_json = Column(Text) export_fields = ( 'datasource_id', 'column_name', 'is_active', 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable', 'description', 'dimension_spec_json', 'verbose_name', ) update_from_object_fields = export_fields export_parent = 'datasource' def __repr__(self): return self.column_name @property def expression(self): return self.dimension_spec_json @property def dimension_spec(self): if self.dimension_spec_json: return json.loads(self.dimension_spec_json) def get_metrics(self): metrics = {} metrics['count'] = DruidMetric( metric_name='count', verbose_name='COUNT(*)', metric_type='count', json=json.dumps({'type': 'count', 'name': 'count'}), ) # Somehow we need to reassign this for UDAFs if self.type in ('DOUBLE', 'FLOAT'): corrected_type = 'DOUBLE' else: corrected_type = self.type if self.sum and self.is_num: mt = corrected_type.lower() + 'Sum' name = 'sum__' + self.column_name metrics[name] = DruidMetric( metric_name=name, metric_type='sum', verbose_name='SUM({})'.format(self.column_name), json=json.dumps({ 'type': mt, 'name': name, 'fieldName': self.column_name}), ) if self.avg and self.is_num: mt = corrected_type.lower() + 'Avg' name = 'avg__' + self.column_name metrics[name] = DruidMetric( metric_name=name, metric_type='avg', verbose_name='AVG({})'.format(self.column_name), json=json.dumps({ 'type': mt, 'name': name, 'fieldName': self.column_name}), ) if self.min and self.is_num: mt = corrected_type.lower() + 'Min' name = 'min__' + self.column_name metrics[name] = DruidMetric( metric_name=name, metric_type='min', verbose_name='MIN({})'.format(self.column_name), json=json.dumps({ 'type': mt, 'name': name, 'fieldName': self.column_name}), ) if self.max and self.is_num: mt = corrected_type.lower() + 'Max' name = 'max__' + self.column_name metrics[name] = DruidMetric( metric_name=name, metric_type='max', verbose_name='MAX({})'.format(self.column_name), json=json.dumps({ 'type': mt, 'name': name, 'fieldName': self.column_name}), ) if self.count_distinct: name = 'count_distinct__' + self.column_name if self.type == 'hyperUnique' or self.type == 'thetaSketch': metrics[name] = DruidMetric( metric_name=name, verbose_name='COUNT(DISTINCT {})'.format(self.column_name), metric_type=self.type, json=json.dumps({ 'type': self.type, 'name': name, 'fieldName': self.column_name, }), ) else: metrics[name] = DruidMetric( metric_name=name, verbose_name='COUNT(DISTINCT {})'.format(self.column_name), metric_type='count_distinct', json=json.dumps({ 'type': 'cardinality', 'name': name, 'fieldNames': [self.column_name]}), ) return metrics def refresh_metrics(self): """Refresh metrics based on the column metadata""" metrics = self.get_metrics() dbmetrics = ( db.session.query(DruidMetric) .filter(DruidMetric.datasource_id == self.datasource_id) .filter(DruidMetric.metric_name.in_(metrics.keys())) ) dbmetrics = {metric.metric_name: metric for metric in dbmetrics} for metric in metrics.values(): dbmetric = dbmetrics.get(metric.metric_name) if dbmetric: for attr in ['json', 'metric_type']: setattr(dbmetric, attr, getattr(metric, attr)) else: with db.session.no_autoflush: metric.datasource_id = self.datasource_id db.session.add(metric) @classmethod def import_obj(cls, i_column): def lookup_obj(lookup_column): return db.session.query(DruidColumn).filter( DruidColumn.datasource_id == lookup_column.datasource_id, DruidColumn.column_name == lookup_column.column_name).first() return import_util.import_simple_obj(db.session, i_column, lookup_obj) class DruidMetric(Model, BaseMetric): """ORM object referencing Druid metrics for a datasource""" __tablename__ = 'metrics' __table_args__ = (UniqueConstraint('metric_name', 'datasource_id'),) datasource_id = Column( Integer, ForeignKey('datasources.id')) # Setting enable_typechecks=False disables polymorphic inheritance. datasource = relationship( 'DruidDatasource', backref=backref('metrics', cascade='all, delete-orphan'), enable_typechecks=False) json = Column(Text) export_fields = ( 'metric_name', 'verbose_name', 'metric_type', 'datasource_id', 'json', 'description', 'is_restricted', 'd3format', 'warning_text', ) update_from_object_fields = export_fields export_parent = 'datasource' @property def expression(self): return self.json @property def json_obj(self): try: obj = json.loads(self.json) except Exception: obj = {} return obj @property def perm(self): return ( '{parent_name}.[{obj.metric_name}](id:{obj.id})' ).format(obj=self, parent_name=self.datasource.full_name, ) if self.datasource else None @classmethod def import_obj(cls, i_metric): def lookup_obj(lookup_metric): return db.session.query(DruidMetric).filter( DruidMetric.datasource_id == lookup_metric.datasource_id, DruidMetric.metric_name == lookup_metric.metric_name).first() return import_util.import_simple_obj(db.session, i_metric, lookup_obj) class DruidDatasource(Model, BaseDatasource): """ORM object referencing Druid datasources (tables)""" __tablename__ = 'datasources' __table_args__ = (UniqueConstraint('datasource_name', 'cluster_name'),) type = 'druid' query_language = 'json' cluster_class = DruidCluster metric_class = DruidMetric column_class = DruidColumn baselink = 'druiddatasourcemodelview' # Columns datasource_name = Column(String(255)) is_hidden = Column(Boolean, default=False) filter_select_enabled = Column(Boolean, default=True) # override default fetch_values_from = Column(String(100)) cluster_name = Column( String(250), ForeignKey('clusters.cluster_name')) cluster = relationship( 'DruidCluster', backref='datasources', foreign_keys=[cluster_name]) user_id = Column(Integer, ForeignKey('ab_user.id')) owner = relationship( security_manager.user_model, backref=backref('datasources', cascade='all, delete-orphan'), foreign_keys=[user_id]) UniqueConstraint('cluster_name', 'datasource_name') export_fields = ( 'datasource_name', 'is_hidden', 'description', 'default_endpoint', 'cluster_name', 'offset', 'cache_timeout', 'params', 'filter_select_enabled', ) update_from_object_fields = export_fields export_parent = 'cluster' export_children = ['columns', 'metrics'] @property def database(self): return self.cluster @property def connection(self): return str(self.database) @property def num_cols(self): return [c.column_name for c in self.columns if c.is_num] @property def name(self): return self.datasource_name @property def schema(self): ds_name = self.datasource_name or '' name_pieces = ds_name.split('.') if len(name_pieces) > 1: return name_pieces[0] else: return None @property def schema_perm(self): """Returns schema permission if present, cluster one otherwise.""" return security_manager.get_schema_perm(self.cluster, self.schema) def get_perm(self): return ( '[{obj.cluster_name}].[{obj.datasource_name}]' '(id:{obj.id})').format(obj=self) def update_from_object(self, obj): return NotImplementedError() @property def link(self): name = escape(self.datasource_name) return Markup('{name}').format(**locals()) @property def full_name(self): return utils.get_datasource_full_name( self.cluster_name, self.datasource_name) @property def time_column_grains(self): return { 'time_columns': [ 'all', '5 seconds', '30 seconds', '1 minute', '5 minutes' '30 minutes', '1 hour', '6 hour', '1 day', '7 days', 'week', 'week_starting_sunday', 'week_ending_saturday', 'month', 'quarter', 'year', ], 'time_grains': ['now'], } def __repr__(self): return self.datasource_name @renders('datasource_name') def datasource_link(self): url = '/superset/explore/{obj.type}/{obj.id}/'.format(obj=self) name = escape(self.datasource_name) return Markup('{name}'.format(**locals())) def get_metric_obj(self, metric_name): return [ m.json_obj for m in self.metrics if m.metric_name == metric_name ][0] @classmethod def import_obj(cls, i_datasource, import_time=None): """Imports the datasource from the object to the database. Metrics and columns and datasource will be overridden if exists. This function can be used to import/export dashboards between multiple superset instances. Audit metadata isn't copies over. """ def lookup_datasource(d): return db.session.query(DruidDatasource).filter( DruidDatasource.datasource_name == d.datasource_name, DruidCluster.cluster_name == d.cluster_name, ).first() def lookup_cluster(d): return db.session.query(DruidCluster).filter_by( cluster_name=d.cluster_name).one() return import_util.import_datasource( db.session, i_datasource, lookup_cluster, lookup_datasource, import_time) def latest_metadata(self): """Returns segment metadata from the latest segment""" logging.info('Syncing datasource [{}]'.format(self.datasource_name)) client = self.cluster.get_pydruid_client() try: results = client.time_boundary(datasource=self.datasource_name) except IOError: results = None if results: max_time = results[0]['result']['maxTime'] max_time = dparse(max_time) else: max_time = datetime.now() # Query segmentMetadata for 7 days back. However, due to a bug, # we need to set this interval to more than 1 day ago to exclude # realtime segments, which triggered a bug (fixed in druid 0.8.2). # https://groups.google.com/forum/#!topic/druid-user/gVCqqspHqOQ lbound = (max_time - timedelta(days=7)).isoformat() if LooseVersion(self.cluster.druid_version) < LooseVersion('0.8.2'): rbound = (max_time - timedelta(1)).isoformat() else: rbound = max_time.isoformat() segment_metadata = None try: segment_metadata = client.segment_metadata( datasource=self.datasource_name, intervals=lbound + '/' + rbound, merge=self.merge_flag, analysisTypes=[]) except Exception as e: logging.warning('Failed first attempt to get latest segment') logging.exception(e) if not segment_metadata: # if no segments in the past 7 days, look at all segments lbound = datetime(1901, 1, 1).isoformat()[:10] if LooseVersion(self.cluster.druid_version) < LooseVersion('0.8.2'): rbound = datetime.now().isoformat() else: rbound = datetime(2050, 1, 1).isoformat()[:10] try: segment_metadata = client.segment_metadata( datasource=self.datasource_name, intervals=lbound + '/' + rbound, merge=self.merge_flag, analysisTypes=[]) except Exception as e: logging.warning('Failed 2nd attempt to get latest segment') logging.exception(e) if segment_metadata: return segment_metadata[-1]['columns'] def refresh_metrics(self): for col in self.columns: col.refresh_metrics() @classmethod def sync_to_db_from_config( cls, druid_config, user, cluster, refresh=True): """Merges the ds config from druid_config into one stored in the db.""" session = db.session datasource = ( session.query(cls) .filter_by(datasource_name=druid_config['name']) .first() ) # Create a new datasource. if not datasource: datasource = cls( datasource_name=druid_config['name'], cluster=cluster, owner=user, changed_by_fk=user.id, created_by_fk=user.id, ) session.add(datasource) elif not refresh: return dimensions = druid_config['dimensions'] col_objs = ( session.query(DruidColumn) .filter(DruidColumn.datasource_id == datasource.id) .filter(DruidColumn.column_name.in_(dimensions)) ) col_objs = {col.column_name: col for col in col_objs} for dim in dimensions: col_obj = col_objs.get(dim, None) if not col_obj: col_obj = DruidColumn( datasource_id=datasource.id, column_name=dim, groupby=True, filterable=True, # TODO: fetch type from Hive. type='STRING', datasource=datasource, ) session.add(col_obj) # Import Druid metrics metric_objs = ( session.query(DruidMetric) .filter(DruidMetric.datasource_id == datasource.id) .filter(DruidMetric.metric_name.in_( spec['name'] for spec in druid_config['metrics_spec'] )) ) metric_objs = {metric.metric_name: metric for metric in metric_objs} for metric_spec in druid_config['metrics_spec']: metric_name = metric_spec['name'] metric_type = metric_spec['type'] metric_json = json.dumps(metric_spec) if metric_type == 'count': metric_type = 'longSum' metric_json = json.dumps({ 'type': 'longSum', 'name': metric_name, 'fieldName': metric_name, }) metric_obj = metric_objs.get(metric_name, None) if not metric_obj: metric_obj = DruidMetric( metric_name=metric_name, metric_type=metric_type, verbose_name='%s(%s)' % (metric_type, metric_name), datasource=datasource, json=metric_json, description=( 'Imported from the airolap config dir for %s' % druid_config['name']), ) session.add(metric_obj) session.commit() @staticmethod def time_offset(granularity): if granularity == 'week_ending_saturday': return 6 * 24 * 3600 * 1000 # 6 days return 0 # uses https://en.wikipedia.org/wiki/ISO_8601 # http://druid.io/docs/0.8.0/querying/granularities.html # TODO: pass origin from the UI @staticmethod def granularity(period_name, timezone=None, origin=None): if not period_name or period_name == 'all': return 'all' iso_8601_dict = { '5 seconds': 'PT5S', '30 seconds': 'PT30S', '1 minute': 'PT1M', '5 minutes': 'PT5M', '30 minutes': 'PT30M', '1 hour': 'PT1H', '6 hour': 'PT6H', 'one day': 'P1D', '1 day': 'P1D', '7 days': 'P7D', 'week': 'P1W', 'week_starting_sunday': 'P1W', 'week_ending_saturday': 'P1W', 'month': 'P1M', 'quarter': 'P3M', 'year': 'P1Y', } granularity = {'type': 'period'} if timezone: granularity['timeZone'] = timezone if origin: dttm = utils.parse_human_datetime(origin) granularity['origin'] = dttm.isoformat() if period_name in iso_8601_dict: granularity['period'] = iso_8601_dict[period_name] if period_name in ('week_ending_saturday', 'week_starting_sunday'): # use Sunday as start of the week granularity['origin'] = '2016-01-03T00:00:00' elif not isinstance(period_name, str): granularity['type'] = 'duration' granularity['duration'] = period_name elif period_name.startswith('P'): # identify if the string is the iso_8601 period granularity['period'] = period_name else: granularity['type'] = 'duration' granularity['duration'] = utils.parse_human_timedelta( period_name).total_seconds() * 1000 return granularity @staticmethod def get_post_agg(mconf): """ For a metric specified as `postagg` returns the kind of post aggregation for pydruid. """ if mconf.get('type') == 'javascript': return JavascriptPostAggregator( name=mconf.get('name', ''), field_names=mconf.get('fieldNames', []), function=mconf.get('function', '')) elif mconf.get('type') == 'quantile': return Quantile( mconf.get('name', ''), mconf.get('probability', ''), ) elif mconf.get('type') == 'quantiles': return Quantiles( mconf.get('name', ''), mconf.get('probabilities', ''), ) elif mconf.get('type') == 'fieldAccess': return Field(mconf.get('name')) elif mconf.get('type') == 'constant': return Const( mconf.get('value'), output_name=mconf.get('name', ''), ) elif mconf.get('type') == 'hyperUniqueCardinality': return HyperUniqueCardinality( mconf.get('name'), ) elif mconf.get('type') == 'arithmetic': return Postaggregator( mconf.get('fn', '/'), mconf.get('fields', []), mconf.get('name', '')) else: return CustomPostAggregator( mconf.get('name', ''), mconf) @staticmethod def find_postaggs_for(postagg_names, metrics_dict): """Return a list of metrics that are post aggregations""" postagg_metrics = [ metrics_dict[name] for name in postagg_names if metrics_dict[name].metric_type == POST_AGG_TYPE ] # Remove post aggregations that were found for postagg in postagg_metrics: postagg_names.remove(postagg.metric_name) return postagg_metrics @staticmethod def recursive_get_fields(_conf): _type = _conf.get('type') _field = _conf.get('field') _fields = _conf.get('fields') field_names = [] if _type in ['fieldAccess', 'hyperUniqueCardinality', 'quantile', 'quantiles']: field_names.append(_conf.get('fieldName', '')) if _field: field_names += DruidDatasource.recursive_get_fields(_field) if _fields: for _f in _fields: field_names += DruidDatasource.recursive_get_fields(_f) return list(set(field_names)) @staticmethod def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dict): mconf = postagg.json_obj required_fields = set( DruidDatasource.recursive_get_fields(mconf) + mconf.get('fieldNames', [])) # Check if the fields are already in aggs # or is a previous postagg required_fields = set([ field for field in required_fields if field not in visited_postaggs and field not in agg_names ]) # First try to find postaggs that match if len(required_fields) > 0: missing_postaggs = DruidDatasource.find_postaggs_for( required_fields, metrics_dict) for missing_metric in required_fields: agg_names.add(missing_metric) for missing_postagg in missing_postaggs: # Add to visited first to avoid infinite recursion # if post aggregations are cyclicly dependent visited_postaggs.add(missing_postagg.metric_name) for missing_postagg in missing_postaggs: DruidDatasource.resolve_postagg( missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict) post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj) @staticmethod def metrics_and_post_aggs(metrics, metrics_dict, druid_version=None): # Separate metrics into those that are aggregations # and those that are post aggregations saved_agg_names = set() adhoc_agg_configs = [] postagg_names = [] for metric in metrics: if utils.is_adhoc_metric(metric): adhoc_agg_configs.append(metric) elif metrics_dict[metric].metric_type != POST_AGG_TYPE: saved_agg_names.add(metric) else: postagg_names.append(metric) # Create the post aggregations, maintain order since postaggs # may depend on previous ones post_aggs = OrderedDict() visited_postaggs = set() for postagg_name in postagg_names: postagg = metrics_dict[postagg_name] visited_postaggs.add(postagg_name) DruidDatasource.resolve_postagg( postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict) aggs = DruidDatasource.get_aggregations( metrics_dict, saved_agg_names, adhoc_agg_configs, ) return aggs, post_aggs def values_for_column(self, column_name, limit=10000): """Retrieve some values for the given column""" logging.info( 'Getting values for columns [{}] limited to [{}]' .format(column_name, limit)) # TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid if self.fetch_values_from: from_dttm = utils.parse_human_datetime(self.fetch_values_from) else: from_dttm = datetime(1970, 1, 1) qry = dict( datasource=self.datasource_name, granularity='all', intervals=from_dttm.isoformat() + '/' + datetime.now().isoformat(), aggregations=dict(count=count('count')), dimension=column_name, metric='count', threshold=limit, ) client = self.cluster.get_pydruid_client() client.topn(**qry) df = client.export_pandas() return [row[column_name] for row in df.to_records(index=False)] def get_query_str(self, query_obj, phase=1, client=None): return self.run_query(client=client, phase=phase, **query_obj) def _add_filter_from_pre_query_data(self, df, dimensions, dim_filter): ret = dim_filter if df is not None and not df.empty: new_filters = [] for unused, row in df.iterrows(): fields = [] for dim in dimensions: f = None # Check if this dimension uses an extraction function # If so, create the appropriate pydruid extraction object if isinstance(dim, dict) and 'extractionFn' in dim: (col, extraction_fn) = DruidDatasource._create_extraction_fn(dim) dim_val = dim['outputName'] f = Filter( dimension=col, value=row[dim_val], extraction_function=extraction_fn, ) elif isinstance(dim, dict): dim_val = dim['outputName'] if dim_val: f = Dimension(dim_val) == row[dim_val] else: f = Dimension(dim) == row[dim] if f: fields.append(f) if len(fields) > 1: term = Filter(type='and', fields=fields) new_filters.append(term) elif fields: new_filters.append(fields[0]) if new_filters: ff = Filter(type='or', fields=new_filters) if not dim_filter: ret = ff else: ret = Filter(type='and', fields=[ff, dim_filter]) return ret @staticmethod def druid_type_from_adhoc_metric(adhoc_metric): column_type = adhoc_metric['column']['type'].lower() aggregate = adhoc_metric['aggregate'].lower() if aggregate == 'count': return 'count' if aggregate == 'count_distinct': return 'cardinality' else: return column_type + aggregate.capitalize() @staticmethod def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]): """ Returns a dictionary of aggregation metric names to aggregation json objects :param metrics_dict: dictionary of all the metrics :param saved_metrics: list of saved metric names :param adhoc_metrics: list of adhoc metric names :raise SupersetException: if one or more metric names are not aggregations """ aggregations = OrderedDict() invalid_metric_names = [] for metric_name in saved_metrics: if metric_name in metrics_dict: metric = metrics_dict[metric_name] if metric.metric_type == POST_AGG_TYPE: invalid_metric_names.append(metric_name) else: aggregations[metric_name] = metric.json_obj else: invalid_metric_names.append(metric_name) if len(invalid_metric_names) > 0: raise SupersetException( _('Metric(s) {} must be aggregations.').format(invalid_metric_names)) for adhoc_metric in adhoc_metrics: aggregations[adhoc_metric['label']] = { 'fieldName': adhoc_metric['column']['column_name'], 'fieldNames': [adhoc_metric['column']['column_name']], 'type': DruidDatasource.druid_type_from_adhoc_metric(adhoc_metric), 'name': adhoc_metric['label'], } return aggregations def check_restricted_metrics(self, aggregations): rejected_metrics = [ m.metric_name for m in self.metrics if m.is_restricted and m.metric_name in aggregations.keys() and not security_manager.has_access('metric_access', m.perm) ] if rejected_metrics: raise MetricPermException( 'Access to the metrics denied: ' + ', '.join(rejected_metrics), ) def get_dimensions(self, groupby, columns_dict): dimensions = [] groupby = [gb for gb in groupby if gb in columns_dict] for column_name in groupby: col = columns_dict.get(column_name) dim_spec = col.dimension_spec if col else None if dim_spec: dimensions.append(dim_spec) else: dimensions.append(column_name) return dimensions def intervals_from_dttms(self, from_dttm, to_dttm): # Couldn't find a way to just not filter on time... from_dttm = from_dttm or datetime(1901, 1, 1) to_dttm = to_dttm or datetime(2101, 1, 1) # add tzinfo to native datetime with config from_dttm = from_dttm.replace(tzinfo=DRUID_TZ) to_dttm = to_dttm.replace(tzinfo=DRUID_TZ) return '{}/{}'.format( from_dttm.isoformat() if from_dttm else '', to_dttm.isoformat() if to_dttm else '', ) @staticmethod def _dimensions_to_values(dimensions): """ Replace dimensions specs with their `dimension` values, and ignore those without """ values = [] for dimension in dimensions: if isinstance(dimension, dict): if 'extractionFn' in dimension: values.append(dimension) elif 'dimension' in dimension: values.append(dimension['dimension']) else: values.append(dimension) return values @staticmethod def sanitize_metric_object(metric): """ Update a metric with the correct type if necessary. :param dict metric: The metric to sanitize """ if ( utils.is_adhoc_metric(metric) and metric['column']['type'].upper() == 'FLOAT' ): metric['column']['type'] = 'DOUBLE' def run_query( # noqa / druid self, groupby, metrics, granularity, from_dttm, to_dttm, filter=None, # noqa is_timeseries=True, timeseries_limit=None, timeseries_limit_metric=None, row_limit=None, inner_from_dttm=None, inner_to_dttm=None, orderby=None, extras=None, # noqa columns=None, phase=2, client=None, order_desc=True, prequeries=None, is_prequery=False, ): """Runs a query against Druid and returns a dataframe. """ # TODO refactor into using a TBD Query object client = client or self.cluster.get_pydruid_client() row_limit = row_limit or conf.get('ROW_LIMIT') if not is_timeseries: granularity = 'all' if granularity == 'all': phase = 1 inner_from_dttm = inner_from_dttm or from_dttm inner_to_dttm = inner_to_dttm or to_dttm timezone = from_dttm.replace(tzinfo=DRUID_TZ).tzname() if from_dttm else None query_str = '' metrics_dict = {m.metric_name: m for m in self.metrics} columns_dict = {c.column_name: c for c in self.columns} if ( self.cluster and LooseVersion(self.cluster.get_druid_version()) < LooseVersion('0.11.0') ): for metric in metrics: self.sanitize_metric_object(metric) self.sanitize_metric_object(timeseries_limit_metric) aggregations, post_aggs = DruidDatasource.metrics_and_post_aggs( metrics, metrics_dict) self.check_restricted_metrics(aggregations) # the dimensions list with dimensionSpecs expanded dimensions = self.get_dimensions(groupby, columns_dict) extras = extras or {} qry = dict( datasource=self.datasource_name, dimensions=dimensions, aggregations=aggregations, granularity=DruidDatasource.granularity( granularity, timezone=timezone, origin=extras.get('druid_time_origin'), ), post_aggregations=post_aggs, intervals=self.intervals_from_dttms(from_dttm, to_dttm), ) filters = DruidDatasource.get_filters(filter, self.num_cols, columns_dict) if filters: qry['filter'] = filters having_filters = self.get_having_filters(extras.get('having_druid')) if having_filters: qry['having'] = having_filters order_direction = 'descending' if order_desc else 'ascending' if columns: columns.append('__time') del qry['post_aggregations'] del qry['aggregations'] qry['dimensions'] = columns qry['metrics'] = [] qry['granularity'] = 'all' qry['limit'] = row_limit client.scan(**qry) elif len(groupby) == 0 and not having_filters: logging.info('Running timeseries query for no groupby values') del qry['dimensions'] client.timeseries(**qry) elif ( not having_filters and len(groupby) == 1 and order_desc ): dim = list(qry.get('dimensions'))[0] logging.info('Running two-phase topn query for dimension [{}]'.format(dim)) pre_qry = deepcopy(qry) if timeseries_limit_metric: order_by = utils.get_metric_name(timeseries_limit_metric) aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( [timeseries_limit_metric], metrics_dict) if phase == 1: pre_qry['aggregations'].update(aggs_dict) pre_qry['post_aggregations'].update(post_aggs_dict) else: pre_qry['aggregations'] = aggs_dict pre_qry['post_aggregations'] = post_aggs_dict else: order_by = list(qry['aggregations'].keys())[0] # Limit on the number of timeseries, doing a two-phases query pre_qry['granularity'] = 'all' pre_qry['threshold'] = min(row_limit, timeseries_limit or row_limit) pre_qry['metric'] = order_by pre_qry['dimension'] = self._dimensions_to_values(qry.get('dimensions'))[0] del pre_qry['dimensions'] client.topn(**pre_qry) logging.info('Phase 1 Complete') if phase == 2: query_str += '// Two phase query\n// Phase 1\n' query_str += json.dumps( client.query_builder.last_query.query_dict, indent=2) query_str += '\n' if phase == 1: return query_str query_str += ( "// Phase 2 (built based on phase one's results)\n") df = client.export_pandas() qry['filter'] = self._add_filter_from_pre_query_data( df, [pre_qry['dimension']], filters) qry['threshold'] = timeseries_limit or 1000 if row_limit and granularity == 'all': qry['threshold'] = row_limit qry['dimension'] = dim del qry['dimensions'] qry['metric'] = list(qry['aggregations'].keys())[0] client.topn(**qry) logging.info('Phase 2 Complete') elif len(groupby) > 0 or having_filters: # If grouping on multiple fields or using a having filter # we have to force a groupby query logging.info('Running groupby query for dimensions [{}]'.format(dimensions)) if timeseries_limit and is_timeseries: logging.info('Running two-phase query for timeseries') pre_qry = deepcopy(qry) pre_qry_dims = self._dimensions_to_values(qry['dimensions']) # Can't use set on an array with dicts # Use set with non-dict items only non_dict_dims = list( set([x for x in pre_qry_dims if not isinstance(x, dict)]), ) dict_dims = [x for x in pre_qry_dims if isinstance(x, dict)] pre_qry['dimensions'] = non_dict_dims + dict_dims order_by = None if metrics: order_by = utils.get_metric_name(metrics[0]) else: order_by = pre_qry_dims[0] if timeseries_limit_metric: order_by = utils.get_metric_name(timeseries_limit_metric) aggs_dict, post_aggs_dict = DruidDatasource.metrics_and_post_aggs( [timeseries_limit_metric], metrics_dict) if phase == 1: pre_qry['aggregations'].update(aggs_dict) pre_qry['post_aggregations'].update(post_aggs_dict) else: pre_qry['aggregations'] = aggs_dict pre_qry['post_aggregations'] = post_aggs_dict # Limit on the number of timeseries, doing a two-phases query pre_qry['granularity'] = 'all' pre_qry['limit_spec'] = { 'type': 'default', 'limit': min(timeseries_limit, row_limit), 'intervals': self.intervals_from_dttms( inner_from_dttm, inner_to_dttm), 'columns': [{ 'dimension': order_by, 'direction': order_direction, }], } client.groupby(**pre_qry) logging.info('Phase 1 Complete') query_str += '// Two phase query\n// Phase 1\n' query_str += json.dumps( client.query_builder.last_query.query_dict, indent=2) query_str += '\n' if phase == 1: return query_str query_str += ( "// Phase 2 (built based on phase one's results)\n") df = client.export_pandas() qry['filter'] = self._add_filter_from_pre_query_data( df, pre_qry['dimensions'], filters, ) qry['limit_spec'] = None if row_limit: dimension_values = self._dimensions_to_values(dimensions) qry['limit_spec'] = { 'type': 'default', 'limit': row_limit, 'columns': [{ 'dimension': ( utils.get_metric_name( metrics[0], ) if metrics else dimension_values[0] ), 'direction': order_direction, }], } client.groupby(**qry) logging.info('Query Complete') query_str += json.dumps( client.query_builder.last_query.query_dict, indent=2) return query_str @staticmethod def homogenize_types(df, groupby_cols): """Converting all GROUPBY columns to strings When grouping by a numeric (say FLOAT) column, pydruid returns strings in the dataframe. This creates issues downstream related to having mixed types in the dataframe Here we replace None with and make the whole series a str instead of an object. """ for col in groupby_cols: df[col] = df[col].fillna('').astype('unicode') return df def query(self, query_obj): qry_start_dttm = datetime.now() client = self.cluster.get_pydruid_client() query_str = self.get_query_str( client=client, query_obj=query_obj, phase=2) df = client.export_pandas() if df is None or df.size == 0: return QueryResult( df=pandas.DataFrame([]), query=query_str, duration=datetime.now() - qry_start_dttm) df = self.homogenize_types(df, query_obj.get('groupby', [])) df.columns = [ DTTM_ALIAS if c in ('timestamp', '__time') else c for c in df.columns ] is_timeseries = query_obj['is_timeseries'] \ if 'is_timeseries' in query_obj else True if ( not is_timeseries and DTTM_ALIAS in df.columns): del df[DTTM_ALIAS] # Reordering columns cols = [] if DTTM_ALIAS in df.columns: cols += [DTTM_ALIAS] cols += query_obj.get('groupby') or [] cols += query_obj.get('columns') or [] cols += query_obj.get('metrics') or [] cols = utils.get_metric_names(cols) cols = [col for col in cols if col in df.columns] df = df[cols] time_offset = DruidDatasource.time_offset(query_obj['granularity']) def increment_timestamp(ts): dt = utils.parse_human_datetime(ts).replace( tzinfo=DRUID_TZ) return dt + timedelta(milliseconds=time_offset) if DTTM_ALIAS in df.columns and time_offset: df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(increment_timestamp) return QueryResult( df=df, query=query_str, duration=datetime.now() - qry_start_dttm) @staticmethod def _create_extraction_fn(dim_spec): extraction_fn = None if dim_spec and 'extractionFn' in dim_spec: col = dim_spec['dimension'] fn = dim_spec['extractionFn'] ext_type = fn.get('type') if ext_type == 'lookup' and fn['lookup'].get('type') == 'map': replace_missing_values = fn.get('replaceMissingValueWith') retain_missing_values = fn.get('retainMissingValue', False) injective = fn.get('isOneToOne', False) extraction_fn = MapLookupExtraction( fn['lookup']['map'], replace_missing_values=replace_missing_values, retain_missing_values=retain_missing_values, injective=injective, ) elif ext_type == 'regex': extraction_fn = RegexExtraction(fn['expr']) else: raise Exception(_('Unsupported extraction function: ' + ext_type)) return (col, extraction_fn) @classmethod def get_filters(cls, raw_filters, num_cols, columns_dict): # noqa """Given Superset filter data structure, returns pydruid Filter(s)""" filters = None for flt in raw_filters: col = flt.get('col') op = flt.get('op') eq = flt.get('val') if ( not col or not op or (eq is None and op not in ('IS NULL', 'IS NOT NULL'))): continue # Check if this dimension uses an extraction function # If so, create the appropriate pydruid extraction object column_def = columns_dict.get(col) dim_spec = column_def.dimension_spec if column_def else None extraction_fn = None if dim_spec and 'extractionFn' in dim_spec: (col, extraction_fn) = DruidDatasource._create_extraction_fn(dim_spec) cond = None is_numeric_col = col in num_cols is_list_target = op in ('in', 'not in') eq = cls.filter_values_handler( eq, is_list_target=is_list_target, target_column_is_numeric=is_numeric_col) # For these two ops, could have used Dimension, # but it doesn't support extraction functions if op == '==': cond = Filter(dimension=col, value=eq, extraction_function=extraction_fn) elif op == '!=': cond = ~Filter(dimension=col, value=eq, extraction_function=extraction_fn) elif op in ('in', 'not in'): fields = [] # ignore the filter if it has no value if not len(eq): continue # if it uses an extraction fn, use the "in" operator # as Dimension isn't supported elif extraction_fn is not None: cond = Filter( dimension=col, values=eq, type='in', extraction_function=extraction_fn, ) elif len(eq) == 1: cond = Dimension(col) == eq[0] else: for s in eq: fields.append(Dimension(col) == s) cond = Filter(type='or', fields=fields) if op == 'not in': cond = ~cond elif op == 'regex': cond = Filter( extraction_function=extraction_fn, type='regex', pattern=eq, dimension=col, ) # For the ops below, could have used pydruid's Bound, # but it doesn't support extraction functions elif op == '>=': cond = Filter( type='bound', extraction_function=extraction_fn, dimension=col, lowerStrict=False, upperStrict=False, lower=eq, upper=None, alphaNumeric=is_numeric_col, ) elif op == '<=': cond = Filter( type='bound', extraction_function=extraction_fn, dimension=col, lowerStrict=False, upperStrict=False, lower=None, upper=eq, alphaNumeric=is_numeric_col, ) elif op == '>': cond = Filter( type='bound', extraction_function=extraction_fn, lowerStrict=True, upperStrict=False, dimension=col, lower=eq, upper=None, alphaNumeric=is_numeric_col, ) elif op == '<': cond = Filter( type='bound', extraction_function=extraction_fn, upperStrict=True, lowerStrict=False, dimension=col, lower=None, upper=eq, alphaNumeric=is_numeric_col, ) elif op == 'IS NULL': cond = Dimension(col) == None # NOQA elif op == 'IS NOT NULL': cond = Dimension(col) != None # NOQA if filters: filters = Filter(type='and', fields=[ cond, filters, ]) else: filters = cond return filters def _get_having_obj(self, col, op, eq): cond = None if op == '==': if col in self.column_names: cond = DimSelector(dimension=col, value=eq) else: cond = Aggregation(col) == eq elif op == '>': cond = Aggregation(col) > eq elif op == '<': cond = Aggregation(col) < eq return cond def get_having_filters(self, raw_filters): filters = None reversed_op_map = { '!=': '==', '>=': '<', '<=': '>', } for flt in raw_filters: if not all(f in flt for f in ['col', 'op', 'val']): continue col = flt['col'] op = flt['op'] eq = flt['val'] cond = None if op in ['==', '>', '<']: cond = self._get_having_obj(col, op, eq) elif op in reversed_op_map: cond = ~self._get_having_obj(col, reversed_op_map[op], eq) if filters: filters = filters & cond else: filters = cond return filters @classmethod def query_datasources_by_name( cls, session, database, datasource_name, schema=None): return ( session.query(cls) .filter_by(cluster_name=database.id) .filter_by(datasource_name=datasource_name) .all() ) def external_metadata(self): self.merge_flag = True return [ { 'name': k, 'type': v.get('type'), } for k, v in self.latest_metadata().items() ] sa.event.listen(DruidDatasource, 'after_insert', security_manager.set_perm) sa.event.listen(DruidDatasource, 'after_update', security_manager.set_perm)