helpers.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. """a collection of model-related helper classes and functions"""
  18. import json
  19. import logging
  20. import re
  21. from datetime import datetime
  22. from typing import List, Optional
  23. # isort and pylint disagree, isort should win
  24. # pylint: disable=ungrouped-imports
  25. import humanize
  26. import pandas as pd
  27. import sqlalchemy as sa
  28. import yaml
  29. from flask import escape, g, Markup
  30. from flask_appbuilder.models.decorators import renders
  31. from flask_appbuilder.models.mixins import AuditMixin
  32. from sqlalchemy import and_, or_, UniqueConstraint
  33. from sqlalchemy.ext.declarative import declared_attr
  34. from sqlalchemy.orm.exc import MultipleResultsFound
  35. from superset.utils.core import QueryStatus
  36. logger = logging.getLogger(__name__)
  37. def json_to_dict(json_str):
  38. if json_str:
  39. val = re.sub(",[ \t\r\n]+}", "}", json_str)
  40. val = re.sub(
  41. ",[ \t\r\n]+\]", "]", val # pylint: disable=anomalous-backslash-in-string
  42. )
  43. return json.loads(val)
  44. return {}
  45. class ImportMixin:
  46. export_parent: Optional[str] = None
  47. # The name of the attribute
  48. # with the SQL Alchemy back reference
  49. export_children: List[str] = []
  50. # List of (str) names of attributes
  51. # with the SQL Alchemy forward references
  52. export_fields: List[str] = []
  53. # The names of the attributes
  54. # that are available for import and export
  55. @classmethod
  56. def _parent_foreign_key_mappings(cls):
  57. """Get a mapping of foreign name to the local name of foreign keys"""
  58. parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
  59. if parent_rel:
  60. return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs}
  61. return {}
  62. @classmethod
  63. def _unique_constrains(cls):
  64. """Get all (single column and multi column) unique constraints"""
  65. unique = [
  66. {c.name for c in u.columns}
  67. for u in cls.__table_args__
  68. if isinstance(u, UniqueConstraint)
  69. ]
  70. unique.extend({c.name} for c in cls.__table__.columns if c.unique)
  71. return unique
  72. @classmethod
  73. def export_schema(cls, recursive=True, include_parent_ref=False):
  74. """Export schema as a dictionary"""
  75. parent_excludes = {}
  76. if not include_parent_ref:
  77. parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
  78. if parent_ref:
  79. parent_excludes = {column.name for column in parent_ref.local_columns}
  80. def formatter(column):
  81. return (
  82. "{0} Default ({1})".format(str(column.type), column.default.arg)
  83. if column.default
  84. else str(column.type)
  85. )
  86. schema = {
  87. column.name: formatter(column)
  88. for column in cls.__table__.columns
  89. if (column.name in cls.export_fields and column.name not in parent_excludes)
  90. }
  91. if recursive:
  92. for column in cls.export_children:
  93. child_class = cls.__mapper__.relationships[column].argument.class_
  94. schema[column] = [
  95. child_class.export_schema(
  96. recursive=recursive, include_parent_ref=include_parent_ref
  97. )
  98. ]
  99. return schema
  100. @classmethod
  101. def import_from_dict(
  102. cls, session, dict_rep, parent=None, recursive=True, sync=None
  103. ): # pylint: disable=too-many-arguments,too-many-locals,too-many-branches
  104. """Import obj from a dictionary"""
  105. if sync is None:
  106. sync = []
  107. parent_refs = cls._parent_foreign_key_mappings()
  108. export_fields = set(cls.export_fields) | set(parent_refs.keys())
  109. new_children = {
  110. c: dict_rep.get(c) for c in cls.export_children if c in dict_rep
  111. }
  112. unique_constrains = cls._unique_constrains()
  113. filters = [] # Using these filters to check if obj already exists
  114. # Remove fields that should not get imported
  115. for k in list(dict_rep):
  116. if k not in export_fields:
  117. del dict_rep[k]
  118. if not parent:
  119. if cls.export_parent:
  120. for prnt in parent_refs.keys():
  121. if prnt not in dict_rep:
  122. raise RuntimeError(
  123. "{0}: Missing field {1}".format(cls.__name__, prnt)
  124. )
  125. else:
  126. # Set foreign keys to parent obj
  127. for k, v in parent_refs.items():
  128. dict_rep[k] = getattr(parent, v)
  129. # Add filter for parent obj
  130. filters.extend([getattr(cls, k) == dict_rep.get(k) for k in parent_refs.keys()])
  131. # Add filter for unique constraints
  132. ucs = [
  133. and_(
  134. *[
  135. getattr(cls, k) == dict_rep.get(k)
  136. for k in cs
  137. if dict_rep.get(k) is not None
  138. ]
  139. )
  140. for cs in unique_constrains
  141. ]
  142. filters.append(or_(*ucs))
  143. # Check if object already exists in DB, break if more than one is found
  144. try:
  145. obj_query = session.query(cls).filter(and_(*filters))
  146. obj = obj_query.one_or_none()
  147. except MultipleResultsFound as e:
  148. logger.error(
  149. "Error importing %s \n %s \n %s",
  150. cls.__name__,
  151. str(obj_query),
  152. yaml.safe_dump(dict_rep),
  153. )
  154. raise e
  155. if not obj:
  156. is_new_obj = True
  157. # Create new DB object
  158. obj = cls(**dict_rep)
  159. logger.info("Importing new %s %s", obj.__tablename__, str(obj))
  160. if cls.export_parent and parent:
  161. setattr(obj, cls.export_parent, parent)
  162. session.add(obj)
  163. else:
  164. is_new_obj = False
  165. logger.info("Updating %s %s", obj.__tablename__, str(obj))
  166. # Update columns
  167. for k, v in dict_rep.items():
  168. setattr(obj, k, v)
  169. # Recursively create children
  170. if recursive:
  171. for child in cls.export_children:
  172. child_class = cls.__mapper__.relationships[child].argument.class_
  173. added = []
  174. for c_obj in new_children.get(child, []):
  175. added.append(
  176. child_class.import_from_dict(
  177. session=session, dict_rep=c_obj, parent=obj, sync=sync
  178. )
  179. )
  180. # If children should get synced, delete the ones that did not
  181. # get updated.
  182. if child in sync and not is_new_obj:
  183. back_refs = (
  184. child_class._parent_foreign_key_mappings() # pylint: disable=protected-access
  185. )
  186. delete_filters = [
  187. getattr(child_class, k) == getattr(obj, back_refs.get(k))
  188. for k in back_refs.keys()
  189. ]
  190. to_delete = set(
  191. session.query(child_class).filter(and_(*delete_filters))
  192. ).difference(set(added))
  193. for o in to_delete:
  194. logger.info("Deleting %s %s", child, str(obj))
  195. session.delete(o)
  196. return obj
  197. def export_to_dict(
  198. self, recursive=True, include_parent_ref=False, include_defaults=False
  199. ):
  200. """Export obj to dictionary"""
  201. cls = self.__class__
  202. parent_excludes = {}
  203. if recursive and not include_parent_ref:
  204. parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
  205. if parent_ref:
  206. parent_excludes = {c.name for c in parent_ref.local_columns}
  207. dict_rep = {
  208. c.name: getattr(self, c.name)
  209. for c in cls.__table__.columns
  210. if (
  211. c.name in self.export_fields
  212. and c.name not in parent_excludes
  213. and (
  214. include_defaults
  215. or (
  216. getattr(self, c.name) is not None
  217. and (not c.default or getattr(self, c.name) != c.default.arg)
  218. )
  219. )
  220. )
  221. }
  222. if recursive:
  223. for cld in self.export_children:
  224. # sorting to make lists of children stable
  225. dict_rep[cld] = sorted(
  226. [
  227. child.export_to_dict(
  228. recursive=recursive,
  229. include_parent_ref=include_parent_ref,
  230. include_defaults=include_defaults,
  231. )
  232. for child in getattr(self, cld)
  233. ],
  234. key=lambda k: sorted(str(k.items())),
  235. )
  236. return dict_rep
  237. def override(self, obj):
  238. """Overrides the plain fields of the dashboard."""
  239. for field in obj.__class__.export_fields:
  240. setattr(self, field, getattr(obj, field))
  241. def copy(self):
  242. """Creates a copy of the dashboard without relationships."""
  243. new_obj = self.__class__()
  244. new_obj.override(self)
  245. return new_obj
  246. def alter_params(self, **kwargs):
  247. d = self.params_dict
  248. d.update(kwargs)
  249. self.params = json.dumps(d)
  250. def remove_params(self, param_to_remove: str) -> None:
  251. d = self.params_dict
  252. d.pop(param_to_remove, None)
  253. self.params = json.dumps(d)
  254. def reset_ownership(self):
  255. """ object will belong to the user the current user """
  256. # make sure the object doesn't have relations to a user
  257. # it will be filled by appbuilder on save
  258. self.created_by = None
  259. self.changed_by = None
  260. # flask global context might not exist (in cli or tests for example)
  261. try:
  262. if g.user:
  263. self.owners = [g.user]
  264. except Exception: # pylint: disable=broad-except
  265. self.owners = []
  266. @property
  267. def params_dict(self):
  268. return json_to_dict(self.params)
  269. @property
  270. def template_params_dict(self):
  271. return json_to_dict(self.template_params)
  272. def _user_link(user): # pylint: disable=no-self-use
  273. if not user:
  274. return ""
  275. url = "/superset/profile/{}/".format(user.username)
  276. return Markup('<a href="{}">{}</a>'.format(url, escape(user) or ""))
  277. class AuditMixinNullable(AuditMixin):
  278. """Altering the AuditMixin to use nullable fields
  279. Allows creating objects programmatically outside of CRUD
  280. """
  281. created_on = sa.Column(sa.DateTime, default=datetime.now, nullable=True)
  282. changed_on = sa.Column(
  283. sa.DateTime, default=datetime.now, onupdate=datetime.now, nullable=True
  284. )
  285. @declared_attr
  286. def created_by_fk(self):
  287. return sa.Column(
  288. sa.Integer,
  289. sa.ForeignKey("ab_user.id"),
  290. default=self.get_user_id,
  291. nullable=True,
  292. )
  293. @declared_attr
  294. def changed_by_fk(self):
  295. return sa.Column(
  296. sa.Integer,
  297. sa.ForeignKey("ab_user.id"),
  298. default=self.get_user_id,
  299. onupdate=self.get_user_id,
  300. nullable=True,
  301. )
  302. def changed_by_name(self):
  303. if self.created_by:
  304. return escape("{}".format(self.created_by))
  305. return ""
  306. @renders("created_by")
  307. def creator(self):
  308. return _user_link(self.created_by)
  309. @property
  310. def changed_by_(self):
  311. return _user_link(self.changed_by)
  312. @renders("changed_on")
  313. def changed_on_(self):
  314. return Markup(f'<span class="no-wrap">{self.changed_on}</span>')
  315. @property
  316. def changed_on_humanized(self):
  317. return humanize.naturaltime(datetime.now() - self.changed_on)
  318. @renders("changed_on")
  319. def modified(self):
  320. return Markup(f'<span class="no-wrap">{self.changed_on_humanized}</span>')
  321. class QueryResult: # pylint: disable=too-few-public-methods
  322. """Object returned by the query interface"""
  323. def __init__( # pylint: disable=too-many-arguments
  324. self, df, query, duration, status=QueryStatus.SUCCESS, error_message=None
  325. ):
  326. self.df: pd.DataFrame = df # pylint: disable=invalid-name
  327. self.query: str = query
  328. self.duration: int = duration
  329. self.status: str = status
  330. self.error_message: Optional[str] = error_message
  331. class ExtraJSONMixin:
  332. """Mixin to add an `extra` column (JSON) and utility methods"""
  333. extra_json = sa.Column(sa.Text, default="{}")
  334. @property
  335. def extra(self):
  336. try:
  337. return json.loads(self.extra_json)
  338. except Exception: # pylint: disable=broad-except
  339. return {}
  340. def set_extra_json(self, d):
  341. self.extra_json = json.dumps(d)
  342. def set_extra_json_key(self, key, value):
  343. extra = self.extra
  344. extra[key] = value
  345. self.extra_json = json.dumps(extra)