base.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. # Licensed to the Apache Software Foundation (ASF) under one
  2. # or more contributor license agreements. See the NOTICE file
  3. # distributed with this work for additional information
  4. # regarding copyright ownership. The ASF licenses this file
  5. # to you under the Apache License, Version 2.0 (the
  6. # "License"); you may not use this file except in compliance
  7. # with the License. You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing,
  12. # software distributed under the License is distributed on an
  13. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  14. # KIND, either express or implied. See the License for the
  15. # specific language governing permissions and limitations
  16. # under the License.
  17. import functools
  18. import logging
  19. import traceback
  20. from datetime import datetime
  21. from typing import Any, Dict, Optional
  22. import simplejson as json
  23. import yaml
  24. from flask import abort, flash, g, get_flashed_messages, redirect, Response, session
  25. from flask_appbuilder import BaseView, ModelView
  26. from flask_appbuilder.actions import action
  27. from flask_appbuilder.forms import DynamicForm
  28. from flask_appbuilder.models.sqla.filters import BaseFilter
  29. from flask_appbuilder.widgets import ListWidget
  30. from flask_babel import get_locale, gettext as __, lazy_gettext as _
  31. from flask_wtf.form import FlaskForm
  32. from sqlalchemy import or_
  33. from werkzeug.exceptions import HTTPException
  34. from wtforms.fields.core import Field, UnboundField
  35. from superset import appbuilder, conf, db, get_feature_flags, security_manager
  36. from superset.exceptions import SupersetException, SupersetSecurityException
  37. from superset.translations.utils import get_language_pack
  38. from superset.utils import core as utils
  39. from .utils import bootstrap_user_data
  40. FRONTEND_CONF_KEYS = (
  41. "SUPERSET_WEBSERVER_TIMEOUT",
  42. "SUPERSET_DASHBOARD_POSITION_DATA_LIMIT",
  43. "ENABLE_JAVASCRIPT_CONTROLS",
  44. "DEFAULT_SQLLAB_LIMIT",
  45. "SQL_MAX_ROW",
  46. "SUPERSET_WEBSERVER_DOMAINS",
  47. "SQLLAB_SAVE_WARNING_MESSAGE",
  48. "DISPLAY_MAX_ROW",
  49. )
  50. logger = logging.getLogger(__name__)
  51. def get_error_msg():
  52. if conf.get("SHOW_STACKTRACE"):
  53. error_msg = traceback.format_exc()
  54. else:
  55. error_msg = "FATAL ERROR \n"
  56. error_msg += (
  57. "Stacktrace is hidden. Change the SHOW_STACKTRACE "
  58. "configuration setting to enable it"
  59. )
  60. return error_msg
  61. def json_error_response(msg=None, status=500, payload=None, link=None):
  62. if not payload:
  63. payload = {"error": "{}".format(msg)}
  64. if link:
  65. payload["link"] = link
  66. return Response(
  67. json.dumps(payload, default=utils.json_iso_dttm_ser, ignore_nan=True),
  68. status=status,
  69. mimetype="application/json",
  70. )
  71. def json_success(json_msg, status=200):
  72. return Response(json_msg, status=status, mimetype="application/json")
  73. def data_payload_response(payload_json, has_error=False):
  74. status = 400 if has_error else 200
  75. return json_success(payload_json, status=status)
  76. def generate_download_headers(extension, filename=None):
  77. filename = filename if filename else datetime.now().strftime("%Y%m%d_%H%M%S")
  78. content_disp = f"attachment; filename={filename}.{extension}"
  79. headers = {"Content-Disposition": content_disp}
  80. return headers
  81. def api(f):
  82. """
  83. A decorator to label an endpoint as an API. Catches uncaught exceptions and
  84. return the response in the JSON format
  85. """
  86. def wraps(self, *args, **kwargs):
  87. try:
  88. return f(self, *args, **kwargs)
  89. except Exception as e: # pylint: disable=broad-except
  90. logger.exception(e)
  91. return json_error_response(get_error_msg())
  92. return functools.update_wrapper(wraps, f)
  93. def handle_api_exception(f):
  94. """
  95. A decorator to catch superset exceptions. Use it after the @api decorator above
  96. so superset exception handler is triggered before the handler for generic
  97. exceptions.
  98. """
  99. def wraps(self, *args, **kwargs):
  100. try:
  101. return f(self, *args, **kwargs)
  102. except SupersetSecurityException as e:
  103. logger.exception(e)
  104. return json_error_response(
  105. utils.error_msg_from_exception(e), status=e.status, link=e.link
  106. )
  107. except SupersetException as e:
  108. logger.exception(e)
  109. return json_error_response(
  110. utils.error_msg_from_exception(e), status=e.status
  111. )
  112. except HTTPException as e:
  113. logger.exception(e)
  114. return json_error_response(utils.error_msg_from_exception(e), status=e.code)
  115. except Exception as e: # pylint: disable=broad-except
  116. logger.exception(e)
  117. return json_error_response(utils.error_msg_from_exception(e))
  118. return functools.update_wrapper(wraps, f)
  119. def get_datasource_exist_error_msg(full_name):
  120. return __("Datasource %(name)s already exists", name=full_name)
  121. def get_user_roles():
  122. if g.user.is_anonymous:
  123. public_role = conf.get("AUTH_ROLE_PUBLIC")
  124. return [security_manager.find_role(public_role)] if public_role else []
  125. return g.user.roles
  126. class BaseSupersetView(BaseView):
  127. @staticmethod
  128. def json_response(obj, status=200) -> Response: # pylint: disable=no-self-use
  129. return Response(
  130. json.dumps(obj, default=utils.json_int_dttm_ser, ignore_nan=True),
  131. status=status,
  132. mimetype="application/json",
  133. )
  134. def menu_data():
  135. menu = appbuilder.menu.get_data()
  136. root_path = "#"
  137. logo_target_path = ""
  138. if not g.user.is_anonymous:
  139. try:
  140. logo_target_path = (
  141. appbuilder.app.config.get("LOGO_TARGET_PATH")
  142. or f"/profile/{g.user.username}/"
  143. )
  144. # when user object has no username
  145. except NameError as e:
  146. logger.exception(e)
  147. if logo_target_path.startswith("/"):
  148. root_path = f"/superset{logo_target_path}"
  149. else:
  150. root_path = logo_target_path
  151. languages = {}
  152. for lang in appbuilder.languages:
  153. languages[lang] = {
  154. **appbuilder.languages[lang],
  155. "url": appbuilder.get_url_for_locale(lang),
  156. }
  157. return {
  158. "menu": menu,
  159. "brand": {
  160. "path": root_path,
  161. "icon": appbuilder.app_icon,
  162. "alt": appbuilder.app_name,
  163. },
  164. "navbar_right": {
  165. "bug_report_url": appbuilder.app.config.get("BUG_REPORT_URL"),
  166. "documentation_url": appbuilder.app.config.get("DOCUMENTATION_URL"),
  167. "version_string": appbuilder.app.config.get("VERSION_STRING"),
  168. "version_sha": appbuilder.app.config.get("VERSION_SHA"),
  169. "languages": languages,
  170. "show_language_picker": len(languages.keys()) > 1,
  171. "user_is_anonymous": g.user.is_anonymous,
  172. "user_info_url": appbuilder.get_url_for_userinfo,
  173. "user_logout_url": appbuilder.get_url_for_logout,
  174. "user_login_url": appbuilder.get_url_for_login,
  175. "locale": session.get("locale", "en"),
  176. },
  177. }
  178. def common_bootstrap_payload():
  179. """Common data always sent to the client"""
  180. messages = get_flashed_messages(with_categories=True)
  181. locale = str(get_locale())
  182. return {
  183. "flash_messages": messages,
  184. "conf": {k: conf.get(k) for k in FRONTEND_CONF_KEYS},
  185. "locale": locale,
  186. "language_pack": get_language_pack(locale),
  187. "feature_flags": get_feature_flags(),
  188. "menu_data": menu_data(),
  189. }
  190. class SupersetListWidget(ListWidget): # pylint: disable=too-few-public-methods
  191. template = "superset/fab_overrides/list.html"
  192. class SupersetModelView(ModelView):
  193. page_size = 100
  194. list_widget = SupersetListWidget
  195. def render_app_template(self):
  196. payload = {
  197. "user": bootstrap_user_data(g.user),
  198. "common": common_bootstrap_payload(),
  199. }
  200. return self.render_template(
  201. "superset/welcome.html",
  202. entry="welcome",
  203. bootstrap_data=json.dumps(
  204. payload, default=utils.pessimistic_json_iso_dttm_ser
  205. ),
  206. )
  207. class ListWidgetWithCheckboxes(ListWidget): # pylint: disable=too-few-public-methods
  208. """An alternative to list view that renders Boolean fields as checkboxes
  209. Works in conjunction with the `checkbox` view."""
  210. template = "superset/fab_overrides/list_with_checkboxes.html"
  211. def validate_json(_form, field):
  212. try:
  213. json.loads(field.data)
  214. except Exception as e:
  215. logger.exception(e)
  216. raise Exception(_("json isn't valid"))
  217. class YamlExportMixin: # pylint: disable=too-few-public-methods
  218. """
  219. Override this if you want a dict response instead, with a certain key.
  220. Used on DatabaseView for cli compatibility
  221. """
  222. yaml_dict_key: Optional[str] = None
  223. @action("yaml_export", __("Export to YAML"), __("Export to YAML?"), "fa-download")
  224. def yaml_export(self, items):
  225. if not isinstance(items, list):
  226. items = [items]
  227. data = [t.export_to_dict() for t in items]
  228. if self.yaml_dict_key:
  229. data = {self.yaml_dict_key: data}
  230. return Response(
  231. yaml.safe_dump(data),
  232. headers=generate_download_headers("yaml"),
  233. mimetype="application/text",
  234. )
  235. class DeleteMixin: # pylint: disable=too-few-public-methods
  236. def _delete(self, primary_key):
  237. """
  238. Delete function logic, override to implement diferent logic
  239. deletes the record with primary_key = primary_key
  240. :param primary_key:
  241. record primary key to delete
  242. """
  243. item = self.datamodel.get(primary_key, self._base_filters)
  244. if not item:
  245. abort(404)
  246. try:
  247. self.pre_delete(item)
  248. except Exception as e: # pylint: disable=broad-except
  249. flash(str(e), "danger")
  250. else:
  251. view_menu = security_manager.find_view_menu(item.get_perm())
  252. pvs = (
  253. security_manager.get_session.query(
  254. security_manager.permissionview_model
  255. )
  256. .filter_by(view_menu=view_menu)
  257. .all()
  258. )
  259. if self.datamodel.delete(item):
  260. self.post_delete(item)
  261. for pv in pvs:
  262. security_manager.get_session.delete(pv)
  263. if view_menu:
  264. security_manager.get_session.delete(view_menu)
  265. security_manager.get_session.commit()
  266. flash(*self.datamodel.message)
  267. self.update_redirect()
  268. @action(
  269. "muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False
  270. )
  271. def muldelete(self, items):
  272. if not items:
  273. abort(404)
  274. for item in items:
  275. try:
  276. self.pre_delete(item)
  277. except Exception as e: # pylint: disable=broad-except
  278. flash(str(e), "danger")
  279. else:
  280. self._delete(item.id)
  281. self.update_redirect()
  282. return redirect(self.get_redirect())
  283. class DatasourceFilter(BaseFilter): # pylint: disable=too-few-public-methods
  284. def apply(self, query, value):
  285. if security_manager.all_datasource_access():
  286. return query
  287. datasource_perms = security_manager.user_view_menu_names("datasource_access")
  288. schema_perms = security_manager.user_view_menu_names("schema_access")
  289. return query.filter(
  290. or_(
  291. self.model.perm.in_(datasource_perms),
  292. self.model.schema_perm.in_(schema_perms),
  293. )
  294. )
  295. class CsvResponse(Response): # pylint: disable=too-many-ancestors
  296. """
  297. Override Response to take into account csv encoding from config.py
  298. """
  299. charset = conf["CSV_EXPORT"].get("encoding", "utf-8")
  300. def check_ownership(obj, raise_if_false=True):
  301. """Meant to be used in `pre_update` hooks on models to enforce ownership
  302. Admin have all access, and other users need to be referenced on either
  303. the created_by field that comes with the ``AuditMixin``, or in a field
  304. named ``owners`` which is expected to be a one-to-many with the User
  305. model. It is meant to be used in the ModelView's pre_update hook in
  306. which raising will abort the update.
  307. """
  308. if not obj:
  309. return False
  310. security_exception = SupersetSecurityException(
  311. "You don't have the rights to alter [{}]".format(obj)
  312. )
  313. if g.user.is_anonymous:
  314. if raise_if_false:
  315. raise security_exception
  316. return False
  317. roles = [r.name for r in get_user_roles()]
  318. if "Admin" in roles:
  319. return True
  320. scoped_session = db.create_scoped_session()
  321. orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first()
  322. # Making a list of owners that works across ORM models
  323. owners = []
  324. if hasattr(orig_obj, "owners"):
  325. owners += orig_obj.owners
  326. if hasattr(orig_obj, "owner"):
  327. owners += [orig_obj.owner]
  328. if hasattr(orig_obj, "created_by"):
  329. owners += [orig_obj.created_by]
  330. owner_names = [o.username for o in owners if o]
  331. if g.user and hasattr(g.user, "username") and g.user.username in owner_names:
  332. return True
  333. if raise_if_false:
  334. raise security_exception
  335. else:
  336. return False
  337. def bind_field(
  338. _, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any]
  339. ) -> Field:
  340. """
  341. Customize how fields are bound by stripping all whitespace.
  342. :param form: The form
  343. :param unbound_field: The unbound field
  344. :param options: The field options
  345. :returns: The bound field
  346. """
  347. filters = unbound_field.kwargs.get("filters", [])
  348. filters.append(lambda x: x.strip() if isinstance(x, str) else x)
  349. return unbound_field.bind(form=form, filters=filters, **options)
# Monkey-patch FlaskForm's Meta so that forms using it bind their fields via
# ``bind_field`` above, which appends a whitespace-stripping filter.
FlaskForm.Meta.bind_field = bind_field