12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129 |
- # pylint: disable=C,R,W
- """A collection of ORM sqlalchemy models for Superset"""
- from contextlib import closing
- from copy import copy, deepcopy
- from datetime import datetime
- import functools
- import json
- import logging
- import textwrap
- from flask import escape, g, Markup, request
- from flask_appbuilder import Model
- from flask_appbuilder.models.decorators import renders
- from flask_appbuilder.security.sqla.models import User
- from future.standard_library import install_aliases
- import numpy
- import pandas as pd
- import sqlalchemy as sqla
- from sqlalchemy import (
- Boolean, Column, create_engine, DateTime, ForeignKey, Integer,
- MetaData, String, Table, Text,
- )
- from sqlalchemy.engine import url
- from sqlalchemy.engine.url import make_url
- from sqlalchemy.orm import relationship, sessionmaker, subqueryload
- from sqlalchemy.orm.session import make_transient
- from sqlalchemy.pool import NullPool
- from sqlalchemy.schema import UniqueConstraint
- from sqlalchemy_utils import EncryptedType
- import sqlparse
- from superset import app, db, db_engine_specs, security_manager, utils
- from superset.connectors.connector_registry import ConnectorRegistry
- from superset.legacy import update_time_range
- from superset.models.helpers import AuditMixinNullable, ImportMixin
- from superset.models.user_attributes import UserAttribute
- from superset.utils import MediumText
- from superset.viz import viz_types
- install_aliases()
- from urllib import parse # noqa
# Module-level configuration and constants.
config = app.config
# Optional callable resolving DB passwords from an external secret store.
custom_password_store = config.get('SQLALCHEMY_CUSTOM_PASSWORD_STORE')
# Metrics logger used to count/report user actions.
stats_logger = config.get('STATS_LOGGER')
metadata = Model.metadata  # pylint: disable=no-member
# Fixed-width placeholder shown in place of real passwords in URIs.
PASSWORD_MASK = 'X' * 10
def set_related_perm(mapper, connection, target):  # noqa
    """SQLAlchemy before_insert/before_update hook for Slice.

    Copies the permission string of the slice's datasource onto the slice
    itself so that permission checks on the slice stay consistent with the
    underlying datasource.
    """
    src_class = target.cls_model
    id_ = target.datasource_id
    if id_:
        ds = db.session.query(src_class).filter_by(id=int(id_)).first()
        if ds:
            target.perm = ds.perm
def copy_dashboard(mapper, connection, target):
    """SQLAlchemy after_insert hook on User.

    When ``DASHBOARD_TEMPLATE_ID`` is configured, copies that template
    dashboard for the newly created user and records it as the user's
    welcome dashboard (via a UserAttribute row).

    No-op when the feature is disabled, or when the configured template
    (or the new user row) cannot be resolved.
    """
    dashboard_id = config.get('DASHBOARD_TEMPLATE_ID')
    if dashboard_id is None:
        return
    Session = sessionmaker(autoflush=False)
    session = Session(bind=connection)
    new_user = session.query(User).filter_by(id=target.id).first()
    # copy template dashboard to user
    template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
    if template is None or new_user is None:
        # Fix: previously a misconfigured/deleted template id raised an
        # AttributeError here and failed the whole user-creation flow.
        logging.warning(
            'DASHBOARD_TEMPLATE_ID {} could not be resolved; '
            'skipping welcome dashboard copy'.format(dashboard_id))
        return
    dashboard = Dashboard(
        dashboard_title=template.dashboard_title,
        position_json=template.position_json,
        description=template.description,
        css=template.css,
        json_metadata=template.json_metadata,
        slices=template.slices,
        owners=[new_user],
    )
    session.add(dashboard)
    session.commit()
    # set dashboard as the welcome dashboard
    extra_attributes = UserAttribute(
        user_id=target.id,
        welcome_dashboard_id=dashboard.id,
    )
    session.add(extra_attributes)
    session.commit()


sqla.event.listen(User, 'after_insert', copy_dashboard)
class Url(Model, AuditMixinNullable):

    """Used for the short url feature"""

    __tablename__ = 'url'
    id = Column(Integer, primary_key=True)
    # Full target URL the short link expands to.
    url = Column(Text)
class KeyValue(Model):

    """Used for any type of key-value store"""

    __tablename__ = 'keyvalue'
    # The row id serves as the key; only the value payload is stored.
    id = Column(Integer, primary_key=True)
    value = Column(Text, nullable=False)
class CssTemplate(Model, AuditMixinNullable):

    """CSS templates for dashboards"""

    __tablename__ = 'css_templates'
    id = Column(Integer, primary_key=True)
    template_name = Column(String(250))
    # Raw CSS applied to dashboards that select this template.
    css = Column(Text, default='')
# Many-to-many association table linking slices to their owner users.
slice_user = Table('slice_user', metadata,
                   Column('id', Integer, primary_key=True),
                   Column('user_id', Integer, ForeignKey('ab_user.id')),
                   Column('slice_id', Integer, ForeignKey('slices.id')))
class Slice(Model, AuditMixinNullable, ImportMixin):

    """A slice is essentially a report or a view on data"""

    __tablename__ = 'slices'
    id = Column(Integer, primary_key=True)
    slice_name = Column(String(250))
    # Generic (type, id) pointer to a datasource; resolved at runtime via
    # ConnectorRegistry rather than a direct foreign key.
    datasource_id = Column(Integer)
    datasource_type = Column(String(200))
    datasource_name = Column(String(2000))
    viz_type = Column(String(250))
    # JSON-serialized form data driving the visualization.
    params = Column(Text)
    description = Column(Text)
    cache_timeout = Column(Integer)
    # Permission string; kept in sync with the datasource by the
    # set_related_perm before_insert/before_update listeners.
    perm = Column(String(1000))
    owners = relationship(security_manager.user_model, secondary=slice_user)

    # Fields serialized on export (consumed by ImportMixin).
    export_fields = ('slice_name', 'datasource_type', 'datasource_name',
                     'viz_type', 'params', 'cache_timeout')

    def __repr__(self):
        return self.slice_name

    @property
    def cls_model(self):
        """Connector model class registered for this slice's datasource type."""
        return ConnectorRegistry.sources[self.datasource_type]

    @property
    def datasource(self):
        return self.get_datasource

    def clone(self):
        """Return an unsaved copy of this slice (audit fields excluded)."""
        return Slice(
            slice_name=self.slice_name,
            datasource_id=self.datasource_id,
            datasource_type=self.datasource_type,
            datasource_name=self.datasource_name,
            viz_type=self.viz_type,
            params=self.params,
            description=self.description,
            cache_timeout=self.cache_timeout)

    @datasource.getter
    @utils.memoized
    def get_datasource(self):
        # Memoized datasource lookup; returns None if the referenced
        # datasource row no longer exists.
        return (
            db.session.query(self.cls_model)
            .filter_by(id=self.datasource_id)
            .first()
        )

    @renders('datasource_name')
    def datasource_link(self):
        # pylint: disable=no-member
        datasource = self.datasource
        return datasource.link if datasource else None

    def datasource_name_text(self):
        # pylint: disable=no-member
        datasource = self.datasource
        return datasource.name if datasource else None

    @property
    def datasource_edit_url(self):
        # pylint: disable=no-member
        datasource = self.datasource
        return datasource.url if datasource else None

    @property
    @utils.memoized
    def viz(self):
        """Instantiate (memoized) the viz object from this slice's params."""
        d = json.loads(self.params)
        viz_class = viz_types[self.viz_type]
        # pylint: disable=no-member
        return viz_class(self.datasource, form_data=d)

    @property
    def description_markeddown(self):
        return utils.markdown(self.description)

    @property
    def data(self):
        """Data used to render slice in templates"""
        d = {}
        self.token = ''
        try:
            d = self.viz.data
            self.token = d.get('token')
        except Exception as e:
            # Surface viz errors in the payload rather than failing the page.
            logging.exception(e)
            d['error'] = str(e)
        return {
            'datasource': self.datasource_name,
            'description': self.description,
            'description_markeddown': self.description_markeddown,
            'edit_url': self.edit_url,
            'form_data': self.form_data,
            'slice_id': self.id,
            'slice_name': self.slice_name,
            'slice_url': self.slice_url,
            'modified': self.modified(),
            'changed_on': self.changed_on.isoformat(),
        }

    @property
    def json_data(self):
        return json.dumps(self.data)

    @property
    def form_data(self):
        """Slice params merged with identifying fields, as a dict."""
        form_data = {}
        try:
            form_data = json.loads(self.params)
        except Exception as e:
            logging.error("Malformed json in slice's params")
            logging.exception(e)
        form_data.update({
            'slice_id': self.id,
            'viz_type': self.viz_type,
            'datasource': '{}__{}'.format(
                self.datasource_id, self.datasource_type),
        })
        if self.cache_timeout:
            form_data['cache_timeout'] = self.cache_timeout
        # Normalize legacy time-range fields in place.
        update_time_range(form_data)
        return form_data

    def get_explore_url(self, base_url='/superset/explore', overrides=None):
        """Build an explore URL for this slice.

        :param base_url: endpoint to target (explore or explore_json)
        :param overrides: extra form_data entries merged over the slice id
        """
        overrides = overrides or {}
        form_data = {'slice_id': self.id}
        form_data.update(overrides)
        params = parse.quote(json.dumps(form_data))
        return (
            '{base_url}/?form_data={params}'.format(**locals()))

    @property
    def slice_url(self):
        """Defines the url to access the slice"""
        return self.get_explore_url()

    @property
    def explore_json_url(self):
        """Defines the url to access the slice"""
        return self.get_explore_url('/superset/explore_json')

    @property
    def edit_url(self):
        return '/chart/edit/{}'.format(self.id)

    @property
    def slice_link(self):
        # Anchor tag; slice name is escaped against HTML injection.
        url = self.slice_url
        name = escape(self.slice_name)
        return Markup('<a href="{url}">{name}</a>'.format(**locals()))

    def get_viz(self, force=False):
        """Creates :py:class:viz.BaseViz object from the url_params_multidict.

        :return: object of the 'viz_type' type that is taken from the
            url_params_multidict or self.params.
        :rtype: :py:class:viz.BaseViz
        """
        slice_params = json.loads(self.params)
        slice_params['slice_id'] = self.id
        slice_params['json'] = 'false'
        slice_params['slice_name'] = self.slice_name
        slice_params['viz_type'] = self.viz_type if self.viz_type else 'table'
        return viz_types[slice_params.get('viz_type')](
            self.datasource,
            form_data=slice_params,
            force=force,
        )

    @classmethod
    def import_obj(cls, slc_to_import, slc_to_override, import_time=None):
        """Inserts or overrides slc in the database.

        remote_id and import_time fields in params_dict are set to track the
        slice origin and ensure correct overrides for multiple imports.
        Slice.perm is used to find the datasources and connect them.

        :param Slice slc_to_import: Slice object to import
        :param Slice slc_to_override: Slice to replace, id matches remote_id
        :returns: The resulting id for the imported slice
        :rtype: int
        """
        session = db.session
        make_transient(slc_to_import)
        slc_to_import.dashboards = []
        slc_to_import.alter_params(
            remote_id=slc_to_import.id, import_time=import_time)

        slc_to_import = slc_to_import.copy()
        params = slc_to_import.params_dict
        # Re-point the imported slice at the local datasource matching the
        # exported (name, schema, database) triple.
        slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name(
            session, slc_to_import.datasource_type, params['datasource_name'],
            params['schema'], params['database_name']).id
        if slc_to_override:
            slc_to_override.override(slc_to_import)
            session.flush()
            return slc_to_override.id
        session.add(slc_to_import)
        logging.info('Final slice: {}'.format(slc_to_import.to_json()))
        session.flush()
        return slc_to_import.id
- sqla.event.listen(Slice, 'before_insert', set_related_perm)
- sqla.event.listen(Slice, 'before_update', set_related_perm)
# Many-to-many association between dashboards and their slices.
dashboard_slices = Table(
    'dashboard_slices', metadata,
    Column('id', Integer, primary_key=True),
    Column('dashboard_id', Integer, ForeignKey('dashboards.id')),
    Column('slice_id', Integer, ForeignKey('slices.id')),
)

# Many-to-many association between dashboards and their owner users.
dashboard_user = Table(
    'dashboard_user', metadata,
    Column('id', Integer, primary_key=True),
    Column('user_id', Integer, ForeignKey('ab_user.id')),
    Column('dashboard_id', Integer, ForeignKey('dashboards.id')),
)
class Dashboard(Model, AuditMixinNullable, ImportMixin):

    """The dashboard object!"""

    __tablename__ = 'dashboards'
    id = Column(Integer, primary_key=True)
    dashboard_title = Column(String(500))
    # JSON-encoded layout (see alter_positions in import_obj for a sample).
    position_json = Column(MediumText())
    description = Column(Text)
    css = Column(Text)
    json_metadata = Column(Text)
    # URL-friendly unique identifier; the numeric id is used when absent.
    slug = Column(String(255), unique=True)
    slices = relationship(
        'Slice', secondary=dashboard_slices, backref='dashboards')
    owners = relationship(security_manager.user_model, secondary=dashboard_user)

    # Fields serialized on export (consumed by ImportMixin).
    export_fields = ('dashboard_title', 'position_json', 'json_metadata',
                     'description', 'css', 'slug')

    def __repr__(self):
        return self.dashboard_title

    @property
    def table_names(self):
        # pylint: disable=no-member
        return ', '.join(
            {'{}'.format(s.datasource.full_name) for s in self.slices})

    @property
    def url(self):
        if self.json_metadata:
            # add default_filters to the preselect_filters of dashboard
            json_metadata = json.loads(self.json_metadata)
            default_filters = json_metadata.get('default_filters')
            # make sure default_filters is not empty and is valid
            if default_filters and default_filters != '{}':
                try:
                    if json.loads(default_filters):
                        filters = parse.quote(default_filters.encode('utf8'))
                        return '/superset/dashboard/{}/?preselect_filters={}'.format(
                            self.slug or self.id, filters)
                except Exception:
                    # Malformed default_filters: fall through to the plain URL.
                    pass
        return '/superset/dashboard/{}/'.format(self.slug or self.id)

    @property
    def datasources(self):
        return {slc.datasource for slc in self.slices}

    @property
    def sqla_metadata(self):
        # pylint: disable=no-member
        metadata = MetaData(bind=self.get_sqla_engine())
        return metadata.reflect()

    def dashboard_link(self):
        # Anchor tag; the title is escaped against HTML injection.
        title = escape(self.dashboard_title)
        return Markup(
            '<a href="{self.url}">{title}</a>'.format(**locals()))

    @property
    def data(self):
        """Dict used to render the dashboard in templates."""
        positions = self.position_json
        if positions:
            positions = json.loads(positions)
        return {
            'id': self.id,
            'metadata': self.params_dict,
            'css': self.css,
            'dashboard_title': self.dashboard_title,
            'slug': self.slug,
            'slices': [slc.data for slc in self.slices],
            'position_json': positions,
        }

    @property
    def params(self):
        # ImportMixin expects `params`; dashboards store it in json_metadata.
        return self.json_metadata

    @params.setter
    def params(self, value):
        self.json_metadata = value

    @property
    def position(self):
        if self.position_json:
            return json.loads(self.position_json)
        return {}

    @classmethod
    def import_obj(cls, dashboard_to_import, import_time=None):
        """Imports the dashboard from the object to the database.

        Once dashboard is imported, json_metadata field is extended and stores
        remote_id and import_time. It helps to decide if the dashboard has to
        be overridden or just copies over. Slices that belong to this
        dashboard will be wired to existing tables. This function can be used
        to import/export dashboards between multiple superset instances.
        Audit metadata isn't copied over.
        """
        def alter_positions(dashboard, old_to_new_slc_id_dict):
            """ Updates slice_ids in the position json.

            Sample position_json data:
            {
                "DASHBOARD_VERSION_KEY": "v2",
                "DASHBOARD_ROOT_ID": {
                    "type": "DASHBOARD_ROOT_TYPE",
                    "id": "DASHBOARD_ROOT_ID",
                    "children": ["DASHBOARD_GRID_ID"]
                },
                "DASHBOARD_GRID_ID": {
                    "type": "DASHBOARD_GRID_TYPE",
                    "id": "DASHBOARD_GRID_ID",
                    "children": ["DASHBOARD_CHART_TYPE-2"]
                },
                "DASHBOARD_CHART_TYPE-2": {
                    "type": "DASHBOARD_CHART_TYPE",
                    "id": "DASHBOARD_CHART_TYPE-2",
                    "children": [],
                    "meta": {
                        "width": 4,
                        "height": 50,
                        "chartId": 118
                    }
                },
            }
            """
            position_data = json.loads(dashboard.position_json)
            position_json = position_data.values()
            for value in position_json:
                if (isinstance(value, dict) and value.get('meta') and
                        value.get('meta').get('chartId')):
                    old_slice_id = value.get('meta').get('chartId')
                    if old_slice_id in old_to_new_slc_id_dict:
                        value['meta']['chartId'] = (
                            old_to_new_slc_id_dict[old_slice_id]
                        )
            dashboard.position_json = json.dumps(position_data)

        logging.info('Started import of the dashboard: {}'
                     .format(dashboard_to_import.to_json()))
        session = db.session
        logging.info('Dashboard has {} slices'
                     .format(len(dashboard_to_import.slices)))
        # copy slices object as Slice.import_slice will mutate the slice
        # and will remove the existing dashboard - slice association
        slices = copy(dashboard_to_import.slices)
        old_to_new_slc_id_dict = {}
        new_filter_immune_slices = []
        new_timed_refresh_immune_slices = []
        new_expanded_slices = {}
        i_params_dict = dashboard_to_import.params_dict
        remote_id_slice_map = {
            slc.params_dict['remote_id']: slc
            for slc in session.query(Slice).all()
            if 'remote_id' in slc.params_dict
        }
        for slc in slices:
            logging.info('Importing slice {} from the dashboard: {}'.format(
                slc.to_json(), dashboard_to_import.dashboard_title))
            remote_slc = remote_id_slice_map.get(slc.id)
            new_slc_id = Slice.import_obj(slc, remote_slc, import_time=import_time)
            old_to_new_slc_id_dict[slc.id] = new_slc_id
            # update json metadata that deals with slice ids
            new_slc_id_str = '{}'.format(new_slc_id)
            old_slc_id_str = '{}'.format(slc.id)
            if ('filter_immune_slices' in i_params_dict and
                    old_slc_id_str in i_params_dict['filter_immune_slices']):
                new_filter_immune_slices.append(new_slc_id_str)
            if ('timed_refresh_immune_slices' in i_params_dict and
                    old_slc_id_str in
                    i_params_dict['timed_refresh_immune_slices']):
                new_timed_refresh_immune_slices.append(new_slc_id_str)
            if ('expanded_slices' in i_params_dict and
                    old_slc_id_str in i_params_dict['expanded_slices']):
                new_expanded_slices[new_slc_id_str] = (
                    i_params_dict['expanded_slices'][old_slc_id_str])

        # override the dashboard
        existing_dashboard = None
        for dash in session.query(Dashboard).all():
            if ('remote_id' in dash.params_dict and
                    dash.params_dict['remote_id'] ==
                    dashboard_to_import.id):
                existing_dashboard = dash

        dashboard_to_import.id = None
        alter_positions(dashboard_to_import, old_to_new_slc_id_dict)
        dashboard_to_import.alter_params(import_time=import_time)
        if new_expanded_slices:
            dashboard_to_import.alter_params(
                expanded_slices=new_expanded_slices)
        if new_filter_immune_slices:
            dashboard_to_import.alter_params(
                filter_immune_slices=new_filter_immune_slices)
        if new_timed_refresh_immune_slices:
            dashboard_to_import.alter_params(
                timed_refresh_immune_slices=new_timed_refresh_immune_slices)

        new_slices = session.query(Slice).filter(
            Slice.id.in_(old_to_new_slc_id_dict.values())).all()

        if existing_dashboard:
            existing_dashboard.override(dashboard_to_import)
            existing_dashboard.slices = new_slices
            session.flush()
            return existing_dashboard.id
        else:
            # session.add(dashboard_to_import) causes sqlachemy failures
            # related to the attached users / slices. Creating new object
            # allows to avoid conflicts in the sql alchemy state.
            copied_dash = dashboard_to_import.copy()
            copied_dash.slices = new_slices
            session.add(copied_dash)
            session.flush()
            return copied_dash.id

    @classmethod
    def export_dashboards(cls, dashboard_ids):
        """Serialize the given dashboards (plus their datasources) to JSON."""
        copied_dashboards = []
        datasource_ids = set()
        for dashboard_id in dashboard_ids:
            # make sure that dashboard_id is an integer
            dashboard_id = int(dashboard_id)
            copied_dashboard = (
                db.session.query(Dashboard)
                .options(subqueryload(Dashboard.slices))
                .filter_by(id=dashboard_id).first()
            )
            make_transient(copied_dashboard)
            for slc in copied_dashboard.slices:
                datasource_ids.add((slc.datasource_id, slc.datasource_type))
                # add extra params for the import
                slc.alter_params(
                    remote_id=slc.id,
                    datasource_name=slc.datasource.name,
                    # Fix: was `slc.datasource.name`, which stored the
                    # datasource *name* under the schema key and broke the
                    # export->import round-trip (Slice.import_obj resolves
                    # the datasource via params['schema']).
                    schema=slc.datasource.schema,
                    database_name=slc.datasource.database.name,
                )
            copied_dashboard.alter_params(remote_id=dashboard_id)
            copied_dashboards.append(copied_dashboard)

        eager_datasources = []
        for dashboard_id, dashboard_type in datasource_ids:
            eager_datasource = ConnectorRegistry.get_eager_datasource(
                db.session, dashboard_type, dashboard_id)
            eager_datasource.alter_params(
                remote_id=eager_datasource.id,
                database_name=eager_datasource.database.name,
            )
            make_transient(eager_datasource)
            eager_datasources.append(eager_datasource)

        return json.dumps({
            'dashboards': copied_dashboards,
            'datasources': eager_datasources,
        }, cls=utils.DashboardEncoder, indent=4)
- class Database(Model, AuditMixinNullable, ImportMixin):
- """An ORM object that stores Database related information"""
- __tablename__ = 'dbs'
- type = 'table'
- __table_args__ = (UniqueConstraint('database_name'),)
- id = Column(Integer, primary_key=True)
- verbose_name = Column(String(250), unique=True)
- # short unique name, used in permissions
- database_name = Column(String(250), unique=True)
- sqlalchemy_uri = Column(String(1024))
- password = Column(EncryptedType(String(1024), config.get('SECRET_KEY')))
- cache_timeout = Column(Integer)
- select_as_create_table_as = Column(Boolean, default=False)
- expose_in_sqllab = Column(Boolean, default=False)
- allow_run_sync = Column(Boolean, default=True)
- allow_run_async = Column(Boolean, default=False)
- allow_csv_upload = Column(Boolean, default=False)
- allow_ctas = Column(Boolean, default=False)
- allow_dml = Column(Boolean, default=False)
- force_ctas_schema = Column(String(250))
- allow_multi_schema_metadata_fetch = Column(Boolean, default=True)
- extra = Column(Text, default=textwrap.dedent("""\
- {
- "metadata_params": {},
- "engine_params": {},
- "metadata_cache_timeout": {},
- "schemas_allowed_for_csv_upload": []
- }
- """))
- perm = Column(String(1000))
- impersonate_user = Column(Boolean, default=False)
- export_fields = ('database_name', 'sqlalchemy_uri', 'cache_timeout',
- 'expose_in_sqllab', 'allow_run_sync', 'allow_run_async',
- 'allow_ctas', 'allow_csv_upload', 'extra')
- export_children = ['tables']
- def __repr__(self):
- return self.verbose_name if self.verbose_name else self.database_name
- @property
- def name(self):
- return self.verbose_name if self.verbose_name else self.database_name
- @property
- def allows_subquery(self):
- return self.db_engine_spec.allows_subquery
- @property
- def data(self):
- return {
- 'id': self.id,
- 'name': self.database_name,
- 'backend': self.backend,
- 'allow_multi_schema_metadata_fetch':
- self.allow_multi_schema_metadata_fetch,
- 'allows_subquery': self.allows_subquery,
- }
- @property
- def unique_name(self):
- return self.database_name
- @property
- def url_object(self):
- return make_url(self.sqlalchemy_uri_decrypted)
- @property
- def backend(self):
- url = make_url(self.sqlalchemy_uri_decrypted)
- return url.get_backend_name()
- @classmethod
- def get_password_masked_url_from_uri(cls, uri):
- url = make_url(uri)
- return cls.get_password_masked_url(url)
- @classmethod
- def get_password_masked_url(cls, url):
- url_copy = deepcopy(url)
- if url_copy.password is not None and url_copy.password != PASSWORD_MASK:
- url_copy.password = PASSWORD_MASK
- return url_copy
- def set_sqlalchemy_uri(self, uri):
- conn = sqla.engine.url.make_url(uri.strip())
- if conn.password != PASSWORD_MASK and not custom_password_store:
- # do not over-write the password with the password mask
- self.password = conn.password
- conn.password = PASSWORD_MASK if conn.password else None
- self.sqlalchemy_uri = str(conn) # hides the password
- def get_effective_user(self, url, user_name=None):
- """
- Get the effective user, especially during impersonation.
- :param url: SQL Alchemy URL object
- :param user_name: Default username
- :return: The effective username
- """
- effective_username = None
- if self.impersonate_user:
- effective_username = url.username
- if user_name:
- effective_username = user_name
- elif (
- hasattr(g, 'user') and hasattr(g.user, 'username') and
- g.user.username is not None
- ):
- effective_username = g.user.username
- return effective_username
- @utils.memoized(
- watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra'))
- def get_sqla_engine(self, schema=None, nullpool=True, user_name=None):
- extra = self.get_extra()
- url = make_url(self.sqlalchemy_uri_decrypted)
- url = self.db_engine_spec.adjust_database_uri(url, schema)
- effective_username = self.get_effective_user(url, user_name)
- # If using MySQL or Presto for example, will set url.username
- # If using Hive, will not do anything yet since that relies on a
- # configuration parameter instead.
- self.db_engine_spec.modify_url_for_impersonation(
- url,
- self.impersonate_user,
- effective_username)
- masked_url = self.get_password_masked_url(url)
- logging.info('Database.get_sqla_engine(). Masked URL: {0}'.format(masked_url))
- params = extra.get('engine_params', {})
- if nullpool:
- params['poolclass'] = NullPool
- # If using Hive, this will set hive.server2.proxy.user=$effective_username
- configuration = {}
- configuration.update(
- self.db_engine_spec.get_configuration_for_impersonation(
- str(url),
- self.impersonate_user,
- effective_username))
- if configuration:
- params['connect_args'] = {'configuration': configuration}
- DB_CONNECTION_MUTATOR = config.get('DB_CONNECTION_MUTATOR')
- if DB_CONNECTION_MUTATOR:
- url, params = DB_CONNECTION_MUTATOR(
- url, params, effective_username, security_manager)
- return create_engine(url, **params)
- def get_reserved_words(self):
- return self.get_dialect().preparer.reserved_words
- def get_quoter(self):
- return self.get_dialect().identifier_preparer.quote
- def get_df(self, sql, schema):
- sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)]
- engine = self.get_sqla_engine(schema=schema)
- def needs_conversion(df_series):
- if df_series.empty:
- return False
- if isinstance(df_series[0], (list, dict)):
- return True
- return False
- with closing(engine.raw_connection()) as conn:
- with closing(conn.cursor()) as cursor:
- for sql in sqls[:-1]:
- self.db_engine_spec.execute(cursor, sql)
- cursor.fetchall()
- self.db_engine_spec.execute(cursor, sqls[-1])
- if cursor.description is not None:
- columns = [col_desc[0] for col_desc in cursor.description]
- else:
- columns = []
- df = pd.DataFrame.from_records(
- data=list(cursor.fetchall()),
- columns=columns,
- coerce_float=True,
- )
- for k, v in df.dtypes.items():
- if v.type == numpy.object_ and needs_conversion(df[k]):
- df[k] = df[k].apply(utils.json_dumps_w_dates)
- return df
- def compile_sqla_query(self, qry, schema=None):
- engine = self.get_sqla_engine(schema=schema)
- sql = str(
- qry.compile(
- engine,
- compile_kwargs={'literal_binds': True},
- ),
- )
- if engine.dialect.identifier_preparer._double_percents:
- sql = sql.replace('%%', '%')
- return sql
- def select_star(
- self, table_name, schema=None, limit=100, show_cols=False,
- indent=True, latest_partition=False, cols=None):
- """Generates a ``select *`` statement in the proper dialect"""
- eng = self.get_sqla_engine(schema=schema)
- return self.db_engine_spec.select_star(
- self, table_name, schema=schema, engine=eng,
- limit=limit, show_cols=show_cols,
- indent=indent, latest_partition=latest_partition, cols=cols)
- def apply_limit_to_sql(self, sql, limit=1000):
- return self.db_engine_spec.apply_limit_to_sql(sql, limit, self)
- def safe_sqlalchemy_uri(self):
- return self.sqlalchemy_uri
- @property
- def inspector(self):
- engine = self.get_sqla_engine()
- return sqla.inspect(engine)
- def all_table_names(self, schema=None, force=False):
- if not schema:
- if not self.allow_multi_schema_metadata_fetch:
- return []
- tables_dict = self.db_engine_spec.fetch_result_sets(
- self, 'table', force=force)
- return tables_dict.get('', [])
- return sorted(
- self.db_engine_spec.get_table_names(schema, self.inspector))
- def all_view_names(self, schema=None, force=False):
- if not schema:
- if not self.allow_multi_schema_metadata_fetch:
- return []
- views_dict = self.db_engine_spec.fetch_result_sets(
- self, 'view', force=force)
- return views_dict.get('', [])
- views = []
- try:
- views = self.inspector.get_view_names(schema)
- except Exception:
- pass
- return views
- def all_schema_names(self, force_refresh=False):
- extra = self.get_extra()
- medatada_cache_timeout = extra.get('metadata_cache_timeout', {})
- schema_cache_timeout = medatada_cache_timeout.get('schema_cache_timeout')
- enable_cache = 'schema_cache_timeout' in medatada_cache_timeout
- return sorted(self.db_engine_spec.get_schema_names(
- inspector=self.inspector,
- enable_cache=enable_cache,
- cache_timeout=schema_cache_timeout,
- db_id=self.id,
- force=force_refresh))
- @property
- def db_engine_spec(self):
- return db_engine_specs.engines.get(
- self.backend, db_engine_specs.BaseEngineSpec)
- @classmethod
- def get_db_engine_spec_for_backend(cls, backend):
- return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec)
- def grains(self):
- """Defines time granularity database-specific expressions.
- The idea here is to make it easy for users to change the time grain
- form a datetime (maybe the source grain is arbitrary timestamps, daily
- or 5 minutes increments) to another, "truncated" datetime. Since
- each database has slightly different but similar datetime functions,
- this allows a mapping between database engines and actual functions.
- """
- return self.db_engine_spec.get_time_grains()
- def grains_dict(self):
- """Allowing to lookup grain by either label or duration
- For backward compatibility"""
- d = {grain.duration: grain for grain in self.grains()}
- d.update({grain.label: grain for grain in self.grains()})
- return d
- def get_extra(self):
- extra = {}
- if self.extra:
- try:
- extra = json.loads(self.extra)
- except Exception as e:
- logging.error(e)
- raise e
- return extra
- def get_table(self, table_name, schema=None):
- extra = self.get_extra()
- meta = MetaData(**extra.get('metadata_params', {}))
- return Table(
- table_name, meta,
- schema=schema or None,
- autoload=True,
- autoload_with=self.get_sqla_engine())
- def get_columns(self, table_name, schema=None):
- return self.inspector.get_columns(table_name, schema)
- def get_indexes(self, table_name, schema=None):
- return self.inspector.get_indexes(table_name, schema)
- def get_pk_constraint(self, table_name, schema=None):
- return self.inspector.get_pk_constraint(table_name, schema)
- def get_foreign_keys(self, table_name, schema=None):
- return self.inspector.get_foreign_keys(table_name, schema)
- def get_schema_access_for_csv_upload(self):
- return self.get_extra().get('schemas_allowed_for_csv_upload', [])
@property
def sqlalchemy_uri_decrypted(self):
    """SQLAlchemy URI with the real password substituted back in.

    When a ``custom_password_store`` is configured it takes precedence
    over the password stored on the model.
    """
    conn = sqla.engine.url.make_url(self.sqlalchemy_uri)
    conn.password = (
        custom_password_store(conn) if custom_password_store
        else self.password
    )
    return str(conn)
@property
def sql_url(self):
    """SQL Lab URL pointing at this database."""
    return '/superset/sql/' + str(self.id) + '/'
def get_perm(self):
    """Permission string uniquely identifying this database."""
    return '[{}].(id:{})'.format(self.database_name, self.id)
def has_table(self, table):
    """True if *table* (``table_name`` + optional ``schema``) exists."""
    engine = self.get_sqla_engine()
    schema = table.schema or None
    return engine.has_table(table.table_name, schema)
@utils.memoized
def get_dialect(self):
    """Instantiate the SQLAlchemy dialect for this database (memoized)."""
    sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
    dialect_cls = sqla_url.get_dialect()
    return dialect_cls()
- sqla.event.listen(Database, 'after_insert', security_manager.set_perm)
- sqla.event.listen(Database, 'after_update', security_manager.set_perm)
class Log(Model):

    """ORM object used to log Superset actions to the database"""

    __tablename__ = 'logs'

    id = Column(Integer, primary_key=True)
    # name of the decorated view function being logged
    action = Column(String(512))
    user_id = Column(Integer, ForeignKey('ab_user.id'))
    dashboard_id = Column(Integer)
    slice_id = Column(Integer)
    # serialized request payload (one record per row when exploded)
    json = Column(Text)
    user = relationship(
        security_manager.user_model, backref='logs', foreign_keys=[user_id])
    dttm = Column(DateTime, default=datetime.utcnow)
    duration_ms = Column(Integer)
    referrer = Column(String(1024))

    @classmethod
    def log_this(cls, f):
        """Decorator to log user actions

        Wraps a Flask view: merges the POST body, query-string parameters
        and view kwargs into one payload, times the call, and bulk-inserts
        one ``Log`` row per record.
        """
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            user_id = None
            if g.user:
                user_id = g.user.get_id()
            d = request.form.to_dict() or {}
            # request parameters can overwrite post body
            request_params = request.args.to_dict()
            d.update(request_params)
            d.update(kwargs)

            slice_id = d.get('slice_id')
            dashboard_id = d.get('dashboard_id')
            try:
                slice_id = int(
                    slice_id or json.loads(d.get('form_data')).get('slice_id'))
            except (ValueError, TypeError, AttributeError):
                # ValueError: non-numeric slice_id or malformed form_data JSON
                # TypeError: form_data absent (json.loads(None))
                # AttributeError: form_data parsed to a non-dict (no .get);
                #   previously uncaught, which crashed the decorated view
                slice_id = 0

            stats_logger.incr(f.__name__)
            start_dttm = datetime.now()
            value = f(*args, **kwargs)
            duration_ms = (datetime.now() - start_dttm).total_seconds() * 1000

            # bulk insert: when the payload names an 'explode' key whose
            # value is a JSON list, log one row per element; otherwise
            # best-effort fall back to a single row for the whole payload
            try:
                explode_by = d.get('explode')
                records = json.loads(d.get(explode_by))
            except Exception:
                records = [d]

            referrer = request.referrer[:1000] if request.referrer else None
            logs = []
            for record in records:
                try:
                    json_string = json.dumps(record)
                except Exception:
                    # unserializable record: keep the row, drop the payload
                    json_string = None
                log = cls(
                    action=f.__name__,
                    json=json_string,
                    dashboard_id=dashboard_id,
                    slice_id=slice_id,
                    duration_ms=duration_ms,
                    referrer=referrer,
                    user_id=user_id)
                logs.append(log)
            sesh = db.session()
            sesh.bulk_save_objects(logs)
            sesh.commit()
            return value
        return wrapper
class FavStar(Model):
    """ORM model recording a user's "favorite" star on an object.

    ``class_name``/``obj_id`` form a generic, string-typed reference
    (there is no DB-level foreign key on ``obj_id``).
    """
    __tablename__ = 'favstar'
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey('ab_user.id'))
    # type name of the starred object (generic reference, not an FK)
    class_name = Column(String(50))
    obj_id = Column(Integer)
    dttm = Column(DateTime, default=datetime.utcnow)
class DatasourceAccessRequest(Model, AuditMixinNullable):
    """ORM model for the access requests for datasources and dbs."""
    __tablename__ = 'access_request'
    id = Column(Integer, primary_key=True)
    datasource_id = Column(Integer)
    datasource_type = Column(String(200))

    # roles that must never be offered in grant/extend links
    ROLES_BLACKLIST = set(config.get('ROBOT_PERMISSION_ROLES', []))

    @property
    def cls_model(self):
        """Connector model class registered for this datasource type."""
        return ConnectorRegistry.sources[self.datasource_type]

    @property
    def username(self):
        return self.creator()

    @property
    def datasource(self):
        return self.get_datasource

    @datasource.getter
    @utils.memoized
    def get_datasource(self):
        # pylint: disable=no-member
        ds = db.session.query(self.cls_model).filter_by(
            id=self.datasource_id).first()
        return ds

    @property
    def datasource_link(self):
        return self.datasource.link  # pylint: disable=no-member

    def _approve_url(self, role_arg, role_name):
        """Build a /superset/approve URL for *role_name*.

        *role_arg* is the query-string parameter name
        ('role_to_grant' or 'role_to_extend').  Replaces the former
        ``.format(**locals())`` construction, which silently depended on
        local variable names.
        NOTE(review): values are interpolated without URL-encoding, as in
        the original implementation.
        """
        return (
            '/superset/approve?datasource_type={}&'
            'datasource_id={}&'
            'created_by={}&{}={}'.format(
                self.datasource_type,
                self.datasource_id,
                self.created_by.username,  # pylint: disable=no-member
                role_arg,
                role_name,
            )
        )

    @property
    def roles_with_datasource(self):
        """HTML <ul> of links granting roles that already hold access."""
        perm = self.datasource.perm  # pylint: disable=no-member
        pv = security_manager.find_permission_view_menu(
            'datasource_access', perm)
        items = []
        for r in pv.role:
            if r.name in self.ROLES_BLACKLIST:
                continue
            url = self._approve_url('role_to_grant', r.name)
            href = '<a href="{}">Grant {} Role</a>'.format(url, r.name)
            items.append('<li>' + href + '</li>')
        return '<ul>' + ''.join(items) + '</ul>'

    @property
    def user_roles(self):
        """HTML <ul> of links extending the requester's existing roles.

        Blacklisted roles are shown as plain text instead of links.
        """
        items = []
        for r in self.created_by.roles:  # pylint: disable=no-member
            url = self._approve_url('role_to_extend', r.name)
            href = '<a href="{}">Extend {} Role</a>'.format(url, r.name)
            if r.name in self.ROLES_BLACKLIST:
                href = '{} Role'.format(r.name)
            items.append('<li>' + href + '</li>')
        return '<ul>' + ''.join(items) + '</ul>'
|