jinja_context.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # pylint: disable=C,R,W
  2. """Defines the templating context for SQL Lab"""
  3. from datetime import datetime, timedelta
  4. import inspect
  5. import json
  6. import random
  7. import time
  8. import uuid
  9. from dateutil.relativedelta import relativedelta
  10. from flask import g, request
  11. from jinja2.sandbox import SandboxedEnvironment
  12. from superset import app
  13. config = app.config
  14. BASE_CONTEXT = {
  15. 'datetime': datetime,
  16. 'random': random,
  17. 'relativedelta': relativedelta,
  18. 'time': time,
  19. 'timedelta': timedelta,
  20. 'uuid': uuid,
  21. }
  22. BASE_CONTEXT.update(config.get('JINJA_CONTEXT_ADDONS', {}))
  23. def url_param(param, default=None):
  24. """Get a url or post data parameter
  25. :param param: the parameter to lookup
  26. :type param: str
  27. :param default: the value to return in the absence of the parameter
  28. :type default: str
  29. """
  30. if request.args.get(param):
  31. return request.args.get(param, default)
  32. # Supporting POST as well as get
  33. if request.form.get('form_data'):
  34. form_data = json.loads(request.form.get('form_data'))
  35. url_params = form_data['url_params'] or {}
  36. return url_params.get(param, default)
  37. return default
  38. def current_user_id():
  39. """The id of the user who is currently logged in"""
  40. if hasattr(g, 'user') and g.user:
  41. return g.user.id
  42. def current_username():
  43. """The username of the user who is currently logged in"""
  44. if g.user:
  45. return g.user.username
  46. def filter_values(column, default=None):
  47. """ Gets a values for a particular filter as a list
  48. This is useful if:
  49. - you want to use a filter box to filter a query where the name of filter box
  50. column doesn't match the one in the select statement
  51. - you want to have the ability for filter inside the main query for speed purposes
  52. This searches for "filters" and "extra_filters" in form_data for a match
  53. Usage example:
  54. SELECT action, count(*) as times
  55. FROM logs
  56. WHERE action in ( {{ "'" + "','".join(filter_values('action_type')) + "'" )
  57. GROUP BY 1
  58. :param column: column/filter name to lookup
  59. :type column: str
  60. :param default: default value to return if there's no matching columns
  61. :type default: str
  62. :return: returns a list of filter values
  63. :type: list
  64. """
  65. form_data = json.loads(request.form.get('form_data', '{}'))
  66. return_val = []
  67. for filter_type in ['filters', 'extra_filters']:
  68. if filter_type not in form_data:
  69. continue
  70. for f in form_data[filter_type]:
  71. if f['col'] == column:
  72. for v in f['val']:
  73. return_val.append(v)
  74. if return_val:
  75. return return_val
  76. if default:
  77. return [default]
  78. else:
  79. return []
  80. class BaseTemplateProcessor(object):
  81. """Base class for database-specific jinja context
  82. There's this bit of magic in ``process_template`` that instantiates only
  83. the database context for the active database as a ``models.Database``
  84. object binds it to the context object, so that object methods
  85. have access to
  86. that context. This way, {{ hive.latest_partition('mytable') }} just
  87. knows about the database it is operating in.
  88. This means that object methods are only available for the active database
  89. and are given access to the ``models.Database`` object and schema
  90. name. For globally available methods use ``@classmethod``.
  91. """
  92. engine = None
  93. def __init__(self, database=None, query=None, table=None, **kwargs):
  94. self.database = database
  95. self.query = query
  96. self.schema = None
  97. if query and query.schema:
  98. self.schema = query.schema
  99. elif table:
  100. self.schema = table.schema
  101. self.context = {
  102. 'url_param': url_param,
  103. 'current_user_id': current_user_id,
  104. 'current_username': current_username,
  105. 'filter_values': filter_values,
  106. 'form_data': {},
  107. }
  108. self.context.update(kwargs)
  109. self.context.update(BASE_CONTEXT)
  110. if self.engine:
  111. self.context[self.engine] = self
  112. self.env = SandboxedEnvironment()
  113. def process_template(self, sql, **kwargs):
  114. """Processes a sql template
  115. >>> sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'"
  116. >>> process_template(sql)
  117. "SELECT '2017-01-01T00:00:00'"
  118. """
  119. template = self.env.from_string(sql)
  120. kwargs.update(self.context)
  121. return template.render(kwargs)
  122. class PrestoTemplateProcessor(BaseTemplateProcessor):
  123. """Presto Jinja context
  124. The methods described here are namespaced under ``presto`` in the
  125. jinja context as in ``SELECT '{{ presto.some_macro_call() }}'``
  126. """
  127. engine = 'presto'
  128. @staticmethod
  129. def _schema_table(table_name, schema):
  130. if '.' in table_name:
  131. schema, table_name = table_name.split('.')
  132. return table_name, schema
  133. def latest_partition(self, table_name):
  134. table_name, schema = self._schema_table(table_name, self.schema)
  135. return self.database.db_engine_spec.latest_partition(
  136. table_name, schema, self.database)[1]
  137. def latest_sub_partition(self, table_name, **kwargs):
  138. table_name, schema = self._schema_table(table_name, self.schema)
  139. return self.database.db_engine_spec.latest_sub_partition(
  140. table_name=table_name,
  141. schema=schema,
  142. database=self.database,
  143. **kwargs)
  144. class HiveTemplateProcessor(PrestoTemplateProcessor):
  145. engine = 'hive'
  146. template_processors = {}
  147. keys = tuple(globals().keys())
  148. for k in keys:
  149. o = globals()[k]
  150. if o and inspect.isclass(o) and issubclass(o, BaseTemplateProcessor):
  151. template_processors[o.engine] = o
  152. def get_template_processor(database, table=None, query=None, **kwargs):
  153. TP = template_processors.get(database.backend, BaseTemplateProcessor)
  154. return TP(database=database, table=table, query=query, **kwargs)