helpers.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # pylint: disable=C,R,W
  2. """a collection of model-related helper classes and functions"""
  3. from datetime import datetime
  4. import json
  5. import logging
  6. import re
  7. from flask import escape, Markup
  8. from flask_appbuilder.models.decorators import renders
  9. from flask_appbuilder.models.mixins import AuditMixin
  10. import humanize
  11. import sqlalchemy as sa
  12. from sqlalchemy import and_, or_, UniqueConstraint
  13. from sqlalchemy.ext.declarative import declared_attr
  14. from sqlalchemy.orm.exc import MultipleResultsFound
  15. import yaml
  16. from superset.utils import QueryStatus
  17. def json_to_dict(json_str):
  18. if json_str:
  19. val = re.sub(',[ \t\r\n]+}', '}', json_str)
  20. val = re.sub(',[ \t\r\n]+\]', ']', val)
  21. return json.loads(val)
  22. else:
  23. return {}
  24. class ImportMixin(object):
  25. export_parent = None
  26. # The name of the attribute
  27. # with the SQL Alchemy back reference
  28. export_children = []
  29. # List of (str) names of attributes
  30. # with the SQL Alchemy forward references
  31. export_fields = []
  32. # The names of the attributes
  33. # that are available for import and export
  34. @classmethod
  35. def _parent_foreign_key_mappings(cls):
  36. """Get a mapping of foreign name to the local name of foreign keys"""
  37. parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
  38. if parent_rel:
  39. return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs}
  40. return {}
  41. @classmethod
  42. def _unique_constrains(cls):
  43. """Get all (single column and multi column) unique constraints"""
  44. unique = [{c.name for c in u.columns} for u in cls.__table_args__
  45. if isinstance(u, UniqueConstraint)]
  46. unique.extend({c.name} for c in cls.__table__.columns if c.unique)
  47. return unique
  48. @classmethod
  49. def export_schema(cls, recursive=True, include_parent_ref=False):
  50. """Export schema as a dictionary"""
  51. parent_excludes = {}
  52. if not include_parent_ref:
  53. parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
  54. if parent_ref:
  55. parent_excludes = {c.name for c in parent_ref.local_columns}
  56. def formatter(c):
  57. return ('{0} Default ({1})'.format(
  58. str(c.type), c.default.arg) if c.default else str(c.type))
  59. schema = {c.name: formatter(c) for c in cls.__table__.columns
  60. if (c.name in cls.export_fields and
  61. c.name not in parent_excludes)}
  62. if recursive:
  63. for c in cls.export_children:
  64. child_class = cls.__mapper__.relationships[c].argument.class_
  65. schema[c] = [child_class.export_schema(recursive=recursive,
  66. include_parent_ref=include_parent_ref)]
  67. return schema
  68. @classmethod
  69. def import_from_dict(cls, session, dict_rep, parent=None,
  70. recursive=True, sync=[]):
  71. """Import obj from a dictionary"""
  72. parent_refs = cls._parent_foreign_key_mappings()
  73. export_fields = set(cls.export_fields) | set(parent_refs.keys())
  74. new_children = {c: dict_rep.get(c) for c in cls.export_children
  75. if c in dict_rep}
  76. unique_constrains = cls._unique_constrains()
  77. filters = [] # Using these filters to check if obj already exists
  78. # Remove fields that should not get imported
  79. for k in list(dict_rep):
  80. if k not in export_fields:
  81. del dict_rep[k]
  82. if not parent:
  83. if cls.export_parent:
  84. for p in parent_refs.keys():
  85. if p not in dict_rep:
  86. raise RuntimeError(
  87. '{0}: Missing field {1}'.format(cls.__name__, p))
  88. else:
  89. # Set foreign keys to parent obj
  90. for k, v in parent_refs.items():
  91. dict_rep[k] = getattr(parent, v)
  92. # Add filter for parent obj
  93. filters.extend([getattr(cls, k) == dict_rep.get(k)
  94. for k in parent_refs.keys()])
  95. # Add filter for unique constraints
  96. ucs = [and_(*[getattr(cls, k) == dict_rep.get(k)
  97. for k in cs if dict_rep.get(k) is not None])
  98. for cs in unique_constrains]
  99. filters.append(or_(*ucs))
  100. # Check if object already exists in DB, break if more than one is found
  101. try:
  102. obj_query = session.query(cls).filter(and_(*filters))
  103. obj = obj_query.one_or_none()
  104. except MultipleResultsFound as e:
  105. logging.error('Error importing %s \n %s \n %s', cls.__name__,
  106. str(obj_query),
  107. yaml.safe_dump(dict_rep))
  108. raise e
  109. if not obj:
  110. is_new_obj = True
  111. # Create new DB object
  112. obj = cls(**dict_rep)
  113. logging.info('Importing new %s %s', obj.__tablename__, str(obj))
  114. if cls.export_parent and parent:
  115. setattr(obj, cls.export_parent, parent)
  116. session.add(obj)
  117. else:
  118. is_new_obj = False
  119. logging.info('Updating %s %s', obj.__tablename__, str(obj))
  120. # Update columns
  121. for k, v in dict_rep.items():
  122. setattr(obj, k, v)
  123. # Recursively create children
  124. if recursive:
  125. for c in cls.export_children:
  126. child_class = cls.__mapper__.relationships[c].argument.class_
  127. added = []
  128. for c_obj in new_children.get(c, []):
  129. added.append(child_class.import_from_dict(session=session,
  130. dict_rep=c_obj,
  131. parent=obj,
  132. sync=sync))
  133. # If children should get synced, delete the ones that did not
  134. # get updated.
  135. if c in sync and not is_new_obj:
  136. back_refs = child_class._parent_foreign_key_mappings()
  137. delete_filters = [getattr(child_class, k) ==
  138. getattr(obj, back_refs.get(k))
  139. for k in back_refs.keys()]
  140. to_delete = set(session.query(child_class).filter(
  141. and_(*delete_filters))).difference(set(added))
  142. for o in to_delete:
  143. logging.info('Deleting %s %s', c, str(obj))
  144. session.delete(o)
  145. return obj
  146. def export_to_dict(self, recursive=True, include_parent_ref=False,
  147. include_defaults=False):
  148. """Export obj to dictionary"""
  149. cls = self.__class__
  150. parent_excludes = {}
  151. if recursive and not include_parent_ref:
  152. parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
  153. if parent_ref:
  154. parent_excludes = {c.name for c in parent_ref.local_columns}
  155. dict_rep = {c.name: getattr(self, c.name)
  156. for c in cls.__table__.columns
  157. if (c.name in self.export_fields and
  158. c.name not in parent_excludes and
  159. (include_defaults or (
  160. getattr(self, c.name) is not None and
  161. (not c.default or
  162. getattr(self, c.name) != c.default.arg))))
  163. }
  164. if recursive:
  165. for c in self.export_children:
  166. # sorting to make lists of children stable
  167. dict_rep[c] = sorted(
  168. [
  169. child.export_to_dict(
  170. recursive=recursive,
  171. include_parent_ref=include_parent_ref,
  172. include_defaults=include_defaults,
  173. ) for child in getattr(self, c)
  174. ],
  175. key=lambda k: sorted(k.items()))
  176. return dict_rep
  177. def override(self, obj):
  178. """Overrides the plain fields of the dashboard."""
  179. for field in obj.__class__.export_fields:
  180. setattr(self, field, getattr(obj, field))
  181. def copy(self):
  182. """Creates a copy of the dashboard without relationships."""
  183. new_obj = self.__class__()
  184. new_obj.override(self)
  185. return new_obj
  186. def alter_params(self, **kwargs):
  187. d = self.params_dict
  188. d.update(kwargs)
  189. self.params = json.dumps(d)
  190. @property
  191. def params_dict(self):
  192. return json_to_dict(self.params)
  193. @property
  194. def template_params_dict(self):
  195. return json_to_dict(self.template_params)
  196. class AuditMixinNullable(AuditMixin):
  197. """Altering the AuditMixin to use nullable fields
  198. Allows creating objects programmatically outside of CRUD
  199. """
  200. created_on = sa.Column(sa.DateTime, default=datetime.now, nullable=True)
  201. changed_on = sa.Column(
  202. sa.DateTime, default=datetime.now,
  203. onupdate=datetime.now, nullable=True)
  204. @declared_attr
  205. def created_by_fk(self): # noqa
  206. return sa.Column(
  207. sa.Integer, sa.ForeignKey('ab_user.id'),
  208. default=self.get_user_id, nullable=True)
  209. @declared_attr
  210. def changed_by_fk(self): # noqa
  211. return sa.Column(
  212. sa.Integer, sa.ForeignKey('ab_user.id'),
  213. default=self.get_user_id, onupdate=self.get_user_id, nullable=True)
  214. def _user_link(self, user):
  215. if not user:
  216. return ''
  217. url = '/superset/profile/{}/'.format(user.username)
  218. return Markup('<a href="{}">{}</a>'.format(url, escape(user) or ''))
  219. def changed_by_name(self):
  220. if self.created_by:
  221. return escape('{}'.format(self.created_by))
  222. return ''
  223. @renders('created_by')
  224. def creator(self): # noqa
  225. return self._user_link(self.created_by)
  226. @property
  227. def changed_by_(self):
  228. return self._user_link(self.changed_by)
  229. @renders('changed_on')
  230. def changed_on_(self):
  231. return Markup(
  232. '<span class="no-wrap">{}</span>'.format(self.changed_on))
  233. @renders('changed_on')
  234. def modified(self):
  235. return humanize.naturaltime(datetime.now() - self.changed_on)
  236. @property
  237. def icons(self):
  238. return """
  239. <a
  240. href="{self.datasource_edit_url}"
  241. data-toggle="tooltip"
  242. title="{self.datasource}">
  243. <i class="fa fa-database"></i>
  244. </a>
  245. """.format(**locals())
  246. class QueryResult(object):
  247. """Object returned by the query interface"""
  248. def __init__( # noqa
  249. self,
  250. df,
  251. query,
  252. duration,
  253. status=QueryStatus.SUCCESS,
  254. error_message=None):
  255. self.df = df
  256. self.query = query
  257. self.duration = duration
  258. self.status = status
  259. self.error_message = error_message