base.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # pylint: disable=C,R,W
  2. from datetime import datetime
  3. import functools
  4. import logging
  5. import traceback
  6. from flask import abort, flash, g, get_flashed_messages, redirect, Response
  7. from flask_appbuilder import BaseView, ModelView
  8. from flask_appbuilder.actions import action
  9. from flask_appbuilder.models.sqla.filters import BaseFilter
  10. from flask_appbuilder.widgets import ListWidget
  11. from flask_babel import get_locale
  12. from flask_babel import gettext as __
  13. from flask_babel import lazy_gettext as _
  14. import simplejson as json
  15. import yaml
  16. from superset import conf, db, security_manager, utils
  17. from superset.exceptions import SupersetSecurityException
  18. from superset.translations.utils import get_language_pack
  19. FRONTEND_CONF_KEYS = (
  20. 'SUPERSET_WEBSERVER_TIMEOUT',
  21. 'SUPERSET_DASHBOARD_POSITION_DATA_LIMIT',
  22. 'ENABLE_JAVASCRIPT_CONTROLS',
  23. )
  24. def get_error_msg():
  25. if conf.get('SHOW_STACKTRACE'):
  26. error_msg = traceback.format_exc()
  27. else:
  28. error_msg = 'FATAL ERROR \n'
  29. error_msg += (
  30. 'Stacktrace is hidden. Change the SHOW_STACKTRACE '
  31. 'configuration setting to enable it')
  32. return error_msg
  33. def json_error_response(msg=None, status=500, stacktrace=None, payload=None, link=None):
  34. if not payload:
  35. payload = {'error': '{}'.format(msg)}
  36. if stacktrace:
  37. payload['stacktrace'] = stacktrace
  38. if link:
  39. payload['link'] = link
  40. return Response(
  41. json.dumps(payload, default=utils.json_iso_dttm_ser, ignore_nan=True),
  42. status=status, mimetype='application/json')
  43. def generate_download_headers(extension, filename=None):
  44. filename = filename if filename else datetime.now().strftime('%Y%m%d_%H%M%S')
  45. content_disp = 'attachment; filename={}.{}'.format(filename, extension)
  46. headers = {
  47. 'Content-Disposition': content_disp,
  48. }
  49. return headers
  50. def api(f):
  51. """
  52. A decorator to label an endpoint as an API. Catches uncaught exceptions and
  53. return the response in the JSON format
  54. """
  55. def wraps(self, *args, **kwargs):
  56. try:
  57. return f(self, *args, **kwargs)
  58. except Exception as e:
  59. logging.exception(e)
  60. return json_error_response(get_error_msg())
  61. return functools.update_wrapper(wraps, f)
  62. def get_datasource_exist_error_msg(full_name):
  63. return __('Datasource %(name)s already exists', name=full_name)
  64. def get_user_roles():
  65. if g.user.is_anonymous:
  66. public_role = conf.get('AUTH_ROLE_PUBLIC')
  67. return [security_manager.find_role(public_role)] if public_role else []
  68. return g.user.roles
  69. class BaseSupersetView(BaseView):
  70. def json_response(self, obj, status=200):
  71. return Response(
  72. json.dumps(obj, default=utils.json_int_dttm_ser, ignore_nan=True),
  73. status=status,
  74. mimetype='application/json')
  75. def common_bootsrap_payload(self):
  76. """Common data always sent to the client"""
  77. messages = get_flashed_messages(with_categories=True)
  78. locale = str(get_locale())
  79. return {
  80. 'flash_messages': messages,
  81. 'conf': {k: conf.get(k) for k in FRONTEND_CONF_KEYS},
  82. 'locale': locale,
  83. 'language_pack': get_language_pack(locale),
  84. 'feature_flags': conf.get('FEATURE_FLAGS'),
  85. }
  86. class SupersetListWidget(ListWidget):
  87. template = 'superset/fab_overrides/list.html'
  88. class SupersetModelView(ModelView):
  89. page_size = 100
  90. list_widget = SupersetListWidget
  91. class ListWidgetWithCheckboxes(ListWidget):
  92. """An alternative to list view that renders Boolean fields as checkboxes
  93. Works in conjunction with the `checkbox` view."""
  94. template = 'superset/fab_overrides/list_with_checkboxes.html'
  95. def validate_json(form, field): # noqa
  96. try:
  97. json.loads(field.data)
  98. except Exception as e:
  99. logging.exception(e)
  100. raise Exception(_("json isn't valid"))
  101. class YamlExportMixin(object):
  102. @action('yaml_export', __('Export to YAML'), __('Export to YAML?'), 'fa-download')
  103. def yaml_export(self, items):
  104. if not isinstance(items, list):
  105. items = [items]
  106. data = [t.export_to_dict() for t in items]
  107. return Response(
  108. yaml.safe_dump(data),
  109. headers=generate_download_headers('yaml'),
  110. mimetype='application/text')
  111. class DeleteMixin(object):
  112. def _delete(self, pk):
  113. """
  114. Delete function logic, override to implement diferent logic
  115. deletes the record with primary_key = pk
  116. :param pk:
  117. record primary key to delete
  118. """
  119. item = self.datamodel.get(pk, self._base_filters)
  120. if not item:
  121. abort(404)
  122. try:
  123. self.pre_delete(item)
  124. except Exception as e:
  125. flash(str(e), 'danger')
  126. else:
  127. view_menu = security_manager.find_view_menu(item.get_perm())
  128. pvs = security_manager.get_session.query(
  129. security_manager.permissionview_model).filter_by(
  130. view_menu=view_menu).all()
  131. schema_view_menu = None
  132. if hasattr(item, 'schema_perm'):
  133. schema_view_menu = security_manager.find_view_menu(item.schema_perm)
  134. pvs.extend(security_manager.get_session.query(
  135. security_manager.permissionview_model).filter_by(
  136. view_menu=schema_view_menu).all())
  137. if self.datamodel.delete(item):
  138. self.post_delete(item)
  139. for pv in pvs:
  140. security_manager.get_session.delete(pv)
  141. if view_menu:
  142. security_manager.get_session.delete(view_menu)
  143. if schema_view_menu:
  144. security_manager.get_session.delete(schema_view_menu)
  145. security_manager.get_session.commit()
  146. flash(*self.datamodel.message)
  147. self.update_redirect()
  148. @action(
  149. 'muldelete',
  150. __('Delete'),
  151. __('Delete all Really?'),
  152. 'fa-trash',
  153. single=False,
  154. )
  155. def muldelete(self, items):
  156. if not items:
  157. abort(404)
  158. for item in items:
  159. try:
  160. self.pre_delete(item)
  161. except Exception as e:
  162. flash(str(e), 'danger')
  163. else:
  164. self._delete(item.id)
  165. self.update_redirect()
  166. return redirect(self.get_redirect())
  167. class SupersetFilter(BaseFilter):
  168. """Add utility function to make BaseFilter easy and fast
  169. These utility function exist in the SecurityManager, but would do
  170. a database round trip at every check. Here we cache the role objects
  171. to be able to make multiple checks but query the db only once
  172. """
  173. def get_user_roles(self):
  174. return get_user_roles()
  175. def get_all_permissions(self):
  176. """Returns a set of tuples with the perm name and view menu name"""
  177. perms = set()
  178. for role in self.get_user_roles():
  179. for perm_view in role.permissions:
  180. t = (perm_view.permission.name, perm_view.view_menu.name)
  181. perms.add(t)
  182. return perms
  183. def has_role(self, role_name_or_list):
  184. """Whether the user has this role name"""
  185. if not isinstance(role_name_or_list, list):
  186. role_name_or_list = [role_name_or_list]
  187. return any(
  188. [r.name in role_name_or_list for r in self.get_user_roles()])
  189. def has_perm(self, permission_name, view_menu_name):
  190. """Whether the user has this perm"""
  191. return (permission_name, view_menu_name) in self.get_all_permissions()
  192. def get_view_menus(self, permission_name):
  193. """Returns the details of view_menus for a perm name"""
  194. vm = set()
  195. for perm_name, vm_name in self.get_all_permissions():
  196. if perm_name == permission_name:
  197. vm.add(vm_name)
  198. return vm
  199. class DatasourceFilter(SupersetFilter):
  200. def apply(self, query, func): # noqa
  201. if security_manager.all_datasource_access():
  202. return query
  203. perms = self.get_view_menus('datasource_access')
  204. # TODO(bogdan): add `schema_access` support here
  205. return query.filter(self.model.perm.in_(perms))
  206. class CsvResponse(Response):
  207. """
  208. Override Response to take into account csv encoding from config.py
  209. """
  210. charset = conf.get('CSV_EXPORT').get('encoding', 'utf-8')
  211. class XlsxResponse(Response):
  212. """
  213. Override Response to take into account csv encoding from config.py
  214. """
  215. # charset = conf.get('CSV_EXPORT').get('encoding', 'utf-8')
  216. charset = "utf-8"
  217. def check_ownership(obj, raise_if_false=True):
  218. """Meant to be used in `pre_update` hooks on models to enforce ownership
  219. Admin have all access, and other users need to be referenced on either
  220. the created_by field that comes with the ``AuditMixin``, or in a field
  221. named ``owners`` which is expected to be a one-to-many with the User
  222. model. It is meant to be used in the ModelView's pre_update hook in
  223. which raising will abort the update.
  224. """
  225. if not obj:
  226. return False
  227. security_exception = SupersetSecurityException(
  228. "You don't have the rights to alter [{}]".format(obj))
  229. if g.user.is_anonymous:
  230. if raise_if_false:
  231. raise security_exception
  232. return False
  233. roles = [r.name for r in get_user_roles()]
  234. if 'Admin' in roles:
  235. return True
  236. session = db.create_scoped_session()
  237. orig_obj = session.query(obj.__class__).filter_by(id=obj.id).first()
  238. # Making a list of owners that works across ORM models
  239. owners = []
  240. if hasattr(orig_obj, 'owners'):
  241. owners += orig_obj.owners
  242. if hasattr(orig_obj, 'owner'):
  243. owners += [orig_obj.owner]
  244. if hasattr(orig_obj, 'created_by'):
  245. owners += [orig_obj.created_by]
  246. owner_names = [o.username for o in owners if o]
  247. if (
  248. g.user and hasattr(g.user, 'username') and
  249. g.user.username in owner_names):
  250. return True
  251. if raise_if_false:
  252. raise security_exception
  253. else:
  254. return False