# pylint: disable=C,R,W
"""This module contains the 'Viz' objects

These objects represent the backend of all the visualizations that
Superset can render.
"""
from collections import defaultdict, OrderedDict
import copy
from datetime import datetime, timedelta
from functools import reduce
import hashlib
import inspect
from itertools import product
import logging
import math
import os
import pickle as pkl
import re
import traceback
import uuid

from dateutil import relativedelta as rdelta
from flask import request
from flask_babel import lazy_gettext as _
import geohash
from geopy.point import Point
from markdown import markdown
import numpy as np
import pandas as pd
from pandas.tseries.frequencies import to_offset
from past.builtins import basestring
import polyline
import simplejson as json

from superset import app, cache, get_css_manifest_files, utils
from superset.exceptions import NullValueException, SpatialException
from superset.utils import (
    DTTM_ALIAS,
    JS_MAX_INTEGER,
    merge_extra_filters,
    to_adhoc,
)

config = app.config
stats_logger = config.get('STATS_LOGGER')

METRIC_KEYS = [
    'metric', 'metrics', 'percent_metrics', 'metric_2', 'secondary_metric',
    'x', 'y', 'size',
]


class BaseViz(object):

    """All visualizations derive this base class"""

    viz_type = None
    verbose_name = 'Base Viz'
    credits = ''
    is_timeseries = False
    default_fillna = 0
    cache_type = 'df'
    enforce_numerical_metrics = True

    def __init__(self, datasource, form_data, force=False):
        if not datasource:
            raise Exception(_('Viz is missing a datasource'))
        self.datasource = datasource
        self.request = request
        self.viz_type = form_data.get('viz_type')
        self.form_data = form_data

        self.query = ''
        self.token = self.form_data.get(
            'token', 'token_' + uuid.uuid4().hex[:8])

        self.groupby = self.form_data.get('groupby') or []
        self.time_shift = timedelta()

        self.status = None
        self.error_message = None
        self.force = force

        # Keeping track of whether some data came from cache.
        # This is useful to trigger the <CachedLabel /> in cases where
        # a visualization runs multiple queries (FilterBox for instance).
        self._some_from_cache = False
        self._any_cache_key = None
        self._any_cached_dttm = None
        self._extra_chart_data = []

        self.process_metrics()

    def process_metrics(self):
        # Metrics in TableViz are order sensitive, so metric_dict should be
        # an OrderedDict
        self.metric_dict = OrderedDict()
        fd = self.form_data
        for mkey in METRIC_KEYS:
            val = fd.get(mkey)
            if val:
                if not isinstance(val, list):
                    val = [val]
                for o in val:
                    label = self.get_metric_label(o)
                    if isinstance(o, dict):
                        o['label'] = label
                    self.metric_dict[label] = o

        # Cast to list needed to return serializable object in py3
        self.all_metrics = list(self.metric_dict.values())
        self.metric_labels = list(self.metric_dict.keys())

    def get_metric_label(self, metric):
        if isinstance(metric, str):
            return metric

        if isinstance(metric, dict):
            metric = metric.get('label')

        if self.datasource.type == 'table':
            db_engine_spec = self.datasource.database.db_engine_spec
            metric = db_engine_spec.mutate_expression_label(metric)
        return metric

    @staticmethod
    def handle_js_int_overflow(data):
        for d in data.get('records', dict()):
            for k, v in list(d.items()):
                if isinstance(v, int):
                    # if an int is too big for JavaScript to handle
                    # convert it to a string
                    if abs(v) > JS_MAX_INTEGER:
                        d[k] = str(v)
        return data
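
    # A minimal sketch of what handle_js_int_overflow does, assuming
    # JS_MAX_INTEGER == 2 ** 53 - 1 (JavaScript's Number.MAX_SAFE_INTEGER,
    # as defined in superset.utils); larger ints lose precision in JS:
    #
    #     BaseViz.handle_js_int_overflow({'records': [{'id': 2 ** 53 + 1}]})
    #     # -> {'records': [{'id': '9007199254740993'}]}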

    def run_extra_queries(self):
        """Lifecycle method to use when more than one query is needed

        In rare-ish cases, a visualization may need to execute multiple
        queries. That is the case for FilterBox or for time comparison
        in Line chart for instance.

        In those cases, we need to make sure these queries run before the
        main `get_payload` method gets called, so that the overall caching
        metadata can be right. The way it works here is that if any of
        the previous `get_df_payload` calls hit the cache, the main
        payload's metadata will reflect that.

        The multi-query support may need more work to become a first class
        use case in the framework, and for the UI to reflect the subtleties
        (show that only some of the queries were served from cache for
        instance). In the meantime, since multi-query is rare, we treat
        it with a bit of a hack. Note that the hack became necessary
        when moving from caching the visualization's data itself, to caching
        the underlying query(ies).
        """
        pass

    def handle_nulls(self, df):
        fillna = self.get_fillna_for_columns(df.columns)
        return df.fillna(fillna)

    def get_fillna_for_col(self, col):
        """Returns the value to use as filler for a specific Column.type"""
        if col:
            if col.is_string:
                return ' NULL'
        return self.default_fillna

    def get_fillna_for_columns(self, columns=None):
        """Returns a dict or scalar that can be passed to DataFrame.fillna"""
        if columns is None:
            return self.default_fillna
        columns_dict = {col.column_name: col for col in self.datasource.columns}
        fillna = {
            c: self.get_fillna_for_col(columns_dict.get(c))
            for c in columns
        }
        return fillna

    def get_samples(self):
        query_obj = self.query_obj()
        query_obj.update({
            'groupby': [],
            'metrics': [],
            'row_limit': 1000,
            'columns': [o.column_name for o in self.datasource.columns],
        })
        df = self.get_df(query_obj)
        return df.to_dict(orient='records')

    def get_df(self, query_obj=None):
        """Returns a pandas dataframe based on the query object"""
        if not query_obj:
            query_obj = self.query_obj()
        if not query_obj:
            return None

        self.error_msg = ''

        timestamp_format = None
        if self.datasource.type == 'table':
            dttm_col = self.datasource.get_col(query_obj['granularity'])
            if dttm_col:
                timestamp_format = dttm_col.python_date_format

        # The datasource here can be a different backend but the interface is common
        self.results = self.datasource.query(query_obj)
        self.query = self.results.query
        self.status = self.results.status
        self.error_message = self.results.error_message

        df = self.results.df
        # Transform the timestamp we received from the database into a
        # pandas-supported datetime format. If no python_date_format is
        # specified, the pattern is assumed to be the default ISO date
        # format. If the datetime format is unix, the corresponding
        # parsing logic is used.
        if df is not None and not df.empty:
            if DTTM_ALIAS in df.columns:
                if timestamp_format in ('epoch_s', 'epoch_ms'):
                    # Column has already been formatted as a timestamp.
                    df[DTTM_ALIAS] = df[DTTM_ALIAS].apply(pd.Timestamp)
                else:
                    df[DTTM_ALIAS] = pd.to_datetime(
                        df[DTTM_ALIAS], utc=False, format=timestamp_format)
                if self.datasource.offset:
                    df[DTTM_ALIAS] += timedelta(hours=self.datasource.offset)
                df[DTTM_ALIAS] += self.time_shift

            if self.enforce_numerical_metrics:
                self.df_metrics_to_num(df)

            # DataFrame.replace is not in-place; assign the result so that
            # infinities are actually mapped to NaN before null handling
            df = df.replace([np.inf, -np.inf], np.nan)
            df = self.handle_nulls(df)
        return df

    def df_metrics_to_num(self, df):
        """Converting metrics to numeric when pandas.read_sql cannot"""
        metrics = self.metric_labels
        for col, dtype in df.dtypes.items():
            if dtype.type == np.object_ and col in metrics:
                df[col] = pd.to_numeric(df[col], errors='coerce')

    def process_query_filters(self):
        utils.convert_legacy_filters_into_adhoc(self.form_data)
        merge_extra_filters(self.form_data)
        utils.split_adhoc_filters_into_base_filters(self.form_data)

    def query_obj(self):
        """Building a query object"""
        form_data = self.form_data
        self.process_query_filters()
        gb = form_data.get('groupby') or []
        metrics = self.all_metrics or []
        columns = form_data.get('columns') or []
        groupby = []
        for o in gb + columns:
            if o not in groupby:
                groupby.append(o)

        is_timeseries = self.is_timeseries
        if DTTM_ALIAS in groupby:
            groupby.remove(DTTM_ALIAS)
            is_timeseries = True

        granularity = (
            form_data.get('granularity') or
            form_data.get('granularity_sqla')
        )
        limit = int(form_data.get('limit') or 0)
        timeseries_limit_metric = form_data.get('timeseries_limit_metric')
        row_limit = int(form_data.get('row_limit') or config.get('ROW_LIMIT'))

        # default order direction
        order_desc = form_data.get('order_desc', True)

        since, until = utils.get_since_until(form_data)
        time_shift = form_data.get('time_shift', '')
        self.time_shift = utils.parse_human_timedelta(time_shift)
        from_dttm = None if since is None else (since - self.time_shift)
        to_dttm = None if until is None else (until - self.time_shift)
        if from_dttm and to_dttm and from_dttm > to_dttm:
            raise Exception(_('From date cannot be larger than to date'))

        self.from_dttm = from_dttm
        self.to_dttm = to_dttm

        # extras are used to query elements specific to a datasource type
        # for instance the extra where clause that applies only to Tables
        extras = {
            'where': form_data.get('where', ''),
            'having': form_data.get('having', ''),
            'having_druid': form_data.get('having_filters', []),
            'time_grain_sqla': form_data.get('time_grain_sqla', ''),
            'druid_time_origin': form_data.get('druid_time_origin', ''),
        }

        d = {
            'granularity': granularity,
            'from_dttm': from_dttm,
            'to_dttm': to_dttm,
            'is_timeseries': is_timeseries,
            'groupby': groupby,
            'metrics': metrics,
            'row_limit': row_limit,
            'filter': self.form_data.get('filters', []),
            'timeseries_limit': limit,
            'extras': extras,
            'timeseries_limit_metric': timeseries_limit_metric,
            'order_desc': order_desc,
            'prequeries': [],
            'is_prequery': False,
        }
        return d

    @property
    def cache_timeout(self):
        if self.form_data.get('cache_timeout') is not None:
            return int(self.form_data.get('cache_timeout'))
        if self.datasource.cache_timeout is not None:
            return self.datasource.cache_timeout
        if (
                hasattr(self.datasource, 'database') and
                self.datasource.database.cache_timeout is not None):
            # note: `is not None` must apply to the timeout itself, not to
            # the boolean conjunction, which can never be None
            return self.datasource.database.cache_timeout
        return config.get('CACHE_DEFAULT_TIMEOUT')

    def get_json(self):
        return json.dumps(
            self.get_payload(),
            default=utils.json_int_dttm_ser, ignore_nan=True)

    def cache_key(self, query_obj, **extra):
        """
        The cache key is made out of the key/values in `query_obj`, plus any
        other key/values in `extra`.

        We remove datetime bounds that are hard values, and replace them with
        the user-provided inputs to bounds, which may be time-relative (as in
        "5 days ago" or "now").

        The `extra` arguments are currently used by time shift queries, since
        different time shifts will differ only in the `from_dttm` and `to_dttm`
        values which are stripped.
        """
        cache_dict = copy.copy(query_obj)
        cache_dict.update(extra)

        for k in ['from_dttm', 'to_dttm']:
            del cache_dict[k]

        cache_dict['time_range'] = self.form_data.get('time_range')
        cache_dict['datasource'] = self.datasource.uid
        json_data = self.json_dumps(cache_dict, sort_keys=True)
        return hashlib.md5(json_data.encode('utf-8')).hexdigest()
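
    # Sketch of the key derivation above, with hypothetical values: two
    # charts issuing the same query against the same datasource share a
    # cache entry, because the key is the md5 of the sorted-key JSON dump
    # of the query dict (hard datetime bounds stripped, relative
    # `time_range` kept):
    #
    #     json_data = self.json_dumps(
    #         {'datasource': '1__table', 'granularity': 'ds',
    #          'time_range': 'Last week'},
    #         sort_keys=True)
    #     cache_key = hashlib.md5(json_data.encode('utf-8')).hexdigest()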

    def get_payload(self, query_obj=None):
        """Returns a payload of metadata and data"""
        self.run_extra_queries()
        payload = self.get_df_payload(query_obj)

        df = payload.get('df')
        if self.status != utils.QueryStatus.FAILED:
            if df is not None and df.empty:
                payload['error'] = 'No data'
            else:
                payload['data'] = self.get_data(df)
        if 'df' in payload:
            del payload['df']
        return payload

    def get_df_payload(self, query_obj=None, **kwargs):
        """Handles caching around the df payload retrieval"""
        if not query_obj:
            query_obj = self.query_obj()

        cache_key = self.cache_key(query_obj, **kwargs) if query_obj else None
        logging.info('Cache key: {}'.format(cache_key))
        is_loaded = False
        stacktrace = None
        df = None
        cached_dttm = datetime.utcnow().isoformat().split('.')[0]
        if cache_key and cache and not self.force:
            cache_value = cache.get(cache_key)
            if cache_value:
                stats_logger.incr('loaded_from_cache')
                try:
                    cache_value = pkl.loads(cache_value)
                    df = cache_value['df']
                    self.query = cache_value['query']
                    self._any_cached_dttm = cache_value['dttm']
                    self._any_cache_key = cache_key
                    self.status = utils.QueryStatus.SUCCESS
                    is_loaded = True
                except Exception as e:
                    logging.exception(e)
                    logging.error('Error reading cache: ' +
                                  utils.error_msg_from_exception(e))
                logging.info('Serving from cache')

        if query_obj and not is_loaded:
            try:
                df = self.get_df(query_obj)
                if self.status != utils.QueryStatus.FAILED:
                    stats_logger.incr('loaded_from_source')
                    is_loaded = True
            except Exception as e:
                logging.exception(e)
                if not self.error_message:
                    self.error_message = '{}'.format(e)
                self.status = utils.QueryStatus.FAILED
                stacktrace = traceback.format_exc()

            if (
                    is_loaded and
                    cache_key and
                    cache and
                    self.status != utils.QueryStatus.FAILED):
                try:
                    cache_value = dict(
                        dttm=cached_dttm,
                        df=df,
                        query=self.query,
                    )
                    cache_value = pkl.dumps(
                        cache_value, protocol=pkl.HIGHEST_PROTOCOL)

                    logging.info('Caching {} chars at key {}'.format(
                        len(cache_value), cache_key))

                    stats_logger.incr('set_cache_key')
                    cache.set(
                        cache_key,
                        cache_value,
                        timeout=self.cache_timeout)
                except Exception as e:
                    # cache.set call can fail if the backend is down or if
                    # the key is too large, among other reasons
                    logging.warning('Could not cache key {}'.format(cache_key))
                    logging.exception(e)
                    cache.delete(cache_key)

        return {
            'cache_key': self._any_cache_key,
            'cached_dttm': self._any_cached_dttm,
            'cache_timeout': self.cache_timeout,
            'df': df,
            'error': self.error_message,
            'form_data': self.form_data,
            'is_cached': self._any_cache_key is not None,
            'query': self.query,
            'status': self.status,
            'stacktrace': stacktrace,
            'rowcount': len(df.index) if df is not None else 0,
        }

    def json_dumps(self, obj, sort_keys=False):
        return json.dumps(
            obj,
            default=utils.json_int_dttm_ser,
            ignore_nan=True,
            sort_keys=sort_keys,
        )

    @property
    def data(self):
        """This is the data object serialized to the js layer"""
        content = {
            'form_data': self.form_data,
            'token': self.token,
            'viz_name': self.viz_type,
            'filter_select_enabled': self.datasource.filter_select_enabled,
        }
        return content

    def get_csv(self):
        df = self.get_df()
        include_index = not isinstance(df.index, pd.RangeIndex)
        csv = df.to_csv(index=include_index, **config.get('CSV_EXPORT'))
        return csv

    def get_xlsx(self):
        df = self.get_df()
        # First remove any previously generated .xlsx files
        for root, dirs, files in os.walk('.', topdown=False):
            for name in files:
                path = os.path.join(root, name)
                if path.split('.')[-1] == 'xlsx':
                    os.remove(path)
        # Write the dataframe to an Excel file
        dt = datetime.now()
        name = dt.strftime('%Y%m%d_%H%M%S.xlsx')
        writer = pd.ExcelWriter(name)
        df.to_excel(writer, 'Sheet1')
        writer.save()
        # Read the Excel file back and return its contents
        with open(name, 'rb') as file:
            file_context = file.read()
        return file_context

    def get_data(self, df):
        return df.to_dict(orient='records')

    @property
    def json_data(self):
        return json.dumps(self.data)


# Custom addition: ECharts polar bar chart
class EchartsBarPolar(BaseViz):

    viz_type = 'echarts_bar_polar'  # must match the viz type name on the frontend
    is_timeseries = False

    def should_be_timeseries(self):
        fd = self.form_data
        conditions_met = (
            (fd.get('granularity') and fd.get('granularity') != 'all') or
            (fd.get('granularity_sqla') and fd.get('time_grain_sqla'))
        )
        if fd.get('include_time') and not conditions_met:
            raise Exception(_(
                'Pick a granularity in the Time section or '
                "uncheck 'Include Time'"))
        return fd.get('include_time')

    def query_obj(self):
        d = super(EchartsBarPolar, self).query_obj()
        fd = self.form_data
        if fd.get('all_columns') and (fd.get('groupby') or fd.get('metrics')):
            raise Exception(_(
                'Choose either fields to [Group By] and [Metrics] or '
                '[Columns], not both'))

        sort_by = fd.get('timeseries_limit_metric')
        if fd.get('all_columns'):
            d['columns'] = fd.get('all_columns')
            d['groupby'] = []
            order_by_cols = fd.get('order_by_cols') or []
            d['orderby'] = [json.loads(t) for t in order_by_cols]
        elif sort_by:
            if sort_by not in d['metrics']:
                d['metrics'] += [sort_by]
            d['orderby'] = [(sort_by, not fd.get('order_desc', True))]

        if 'percent_metrics' in fd:
            d['metrics'] = d['metrics'] + list(filter(
                lambda m: m not in d['metrics'],
                fd['percent_metrics'],
            ))

        d['is_timeseries'] = self.should_be_timeseries()
        return d

    def get_data(self, df):
        fd = self.form_data
        if not self.should_be_timeseries() and DTTM_ALIAS in df:
            del df[DTTM_ALIAS]
        return dict(
            records=df.to_dict(orient='records'),
            columns=list(df.columns),
        )


# Custom addition: China map
class ChinaMap(BaseViz):

    """ChinaMap Viz"""

    viz_type = 'ChinaMap'
    verbose_name = _('ChinaMap')
    is_timeseries = False  # data is not queried through a date column

    def get_data(self, df):
        df.sort_values(by=df.columns[0], inplace=True)
        ori_data = df.values.tolist()
        data = [{'name': row[0], 'value': row[1]} for row in ori_data]
        data_name = [row[0] for row in ori_data]
        max_data = max(row[1] for row in ori_data)
        min_data = min(row[1] for row in ori_data)
        return [data, data_name, max_data, min_data]


# Custom addition: hand-rolled ECharts bar chart
class MyEchartsBar(BaseViz):

    """MyEchartsBar"""

    viz_type = 'MyEchartsBar'
    verbose_name = _('MyEchartsBar')
    is_timeseries = False

    # Build the query from the user's selections (fields, order by, group by, etc.)
    def query_obj(self):
        d = super(MyEchartsBar, self).query_obj()
        fd = self.form_data
        if fd.get('all_columns') and (fd.get('groupby') or fd.get('metrics')):
            raise Exception(_(
                'Choose either fields to [Group By] and [Metrics] or '
                '[Columns], not both'))

        sort_by = fd.get('timeseries_limit_metric')
        if fd.get('all_columns'):
            d['columns'] = fd.get('all_columns')
            d['groupby'] = []
            order_by_cols = fd.get('order_by_cols') or []
            d['orderby'] = [json.loads(t) for t in order_by_cols]
        elif sort_by:
            if sort_by not in d['metrics']:
                d['metrics'] += [sort_by]
            d['orderby'] = [(sort_by, not fd.get('order_desc', True))]

        if 'percent_metrics' in fd:
            d['metrics'] = d['metrics'] + list(filter(
                lambda m: m not in d['metrics'],
                fd['percent_metrics'],
            ))

        d['is_timeseries'] = False
        return d

    # Return the data
    def get_data(self, df):
        return dict(
            records=df.to_dict(orient='records'),
            columns=list(df.columns),
        )


class TableViz(BaseViz):

    """A basic html table that is sortable and searchable"""

    viz_type = 'table'
    verbose_name = _('Table View')
    credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
    is_timeseries = False
    enforce_numerical_metrics = False

    def should_be_timeseries(self):
        fd = self.form_data
        # TODO handle datasource-type-specific code in datasource
        conditions_met = (
            (fd.get('granularity') and fd.get('granularity') != 'all') or
            (fd.get('granularity_sqla') and fd.get('time_grain_sqla'))
        )
        if fd.get('include_time') and not conditions_met:
            raise Exception(_(
                'Pick a granularity in the Time section or '
                "uncheck 'Include Time'"))
        return fd.get('include_time')

    def query_obj(self):
        d = super(TableViz, self).query_obj()
        fd = self.form_data

        if fd.get('all_columns') and (fd.get('groupby') or fd.get('metrics')):
            raise Exception(_(
                'Choose either fields to [Group By] and [Metrics] or '
                '[Columns], not both'))

        sort_by = fd.get('timeseries_limit_metric')
        if fd.get('all_columns'):
            d['columns'] = fd.get('all_columns')
            d['groupby'] = []
            order_by_cols = fd.get('order_by_cols') or []
            d['orderby'] = [json.loads(t) for t in order_by_cols]
        elif sort_by:
            sort_by_label = utils.get_metric_name(sort_by)
            if sort_by_label not in utils.get_metric_names(d['metrics']):
                d['metrics'] += [sort_by]
            d['orderby'] = [(sort_by, not fd.get('order_desc', True))]

        # Add all percent metrics that are not already in the list
        if 'percent_metrics' in fd:
            d['metrics'] = d['metrics'] + list(filter(
                lambda m: m not in d['metrics'],
                fd['percent_metrics'] or [],
            ))

        d['is_timeseries'] = self.should_be_timeseries()
        return d

    def get_data(self, df):
        fd = self.form_data
        if (
                not self.should_be_timeseries() and
                df is not None and
                DTTM_ALIAS in df
        ):
            del df[DTTM_ALIAS]

        # Sum up and compute percentages for all percent metrics
        percent_metrics = fd.get('percent_metrics') or []
        percent_metrics = [self.get_metric_label(m) for m in percent_metrics]

        if len(percent_metrics):
            percent_metrics = list(filter(lambda m: m in df, percent_metrics))
            metric_sums = {
                m: reduce(lambda a, b: a + b, df[m])
                for m in percent_metrics
            }
            metric_percents = {
                m: list(map(
                    lambda a: None if metric_sums[m] == 0 else a / metric_sums[m],
                    df[m]))
                for m in percent_metrics
            }
            for m in percent_metrics:
                m_name = '%' + m
                df[m_name] = pd.Series(metric_percents[m], name=m_name)
            # Remove metrics that are not in the main metrics list
            metrics = fd.get('metrics') or []
            metrics = [self.get_metric_label(m) for m in metrics]
            for m in filter(
                lambda m: m not in metrics and m in df.columns,
                percent_metrics,
            ):
                del df[m]

        data = self.handle_js_int_overflow(
            dict(
                records=df.to_dict(orient='records'),
                columns=list(df.columns),
            ))
        return data
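
    # Worked example of the percentage step above, with hypothetical data:
    # for a percent metric 'sum__num' whose column holds [2, 3, 5],
    # metric_sums['sum__num'] == 10 and the derived '%sum__num' column
    # becomes [0.2, 0.3, 0.5]; a zero sum yields None rather than dividing.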

    def json_dumps(self, obj, sort_keys=False):
        if self.form_data.get('all_columns'):
            return json.dumps(
                obj,
                default=utils.json_iso_dttm_ser,
                sort_keys=sort_keys,
                ignore_nan=True)
        else:
            return super(TableViz, self).json_dumps(obj)


class TimeTableViz(BaseViz):

    """A data table with rich time-series related columns"""

    viz_type = 'time_table'
    verbose_name = _('Time Table View')
    credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
    is_timeseries = True

    def query_obj(self):
        d = super(TimeTableViz, self).query_obj()
        fd = self.form_data

        if not fd.get('metrics'):
            raise Exception(_('Pick at least one metric'))

        if fd.get('groupby') and len(fd.get('metrics')) > 1:
            raise Exception(_(
                "When using 'Group By' you are limited to use a single metric"))
        return d

    def get_data(self, df):
        fd = self.form_data
        columns = None
        values = self.metric_labels
        if fd.get('groupby'):
            values = self.metric_labels[0]
            columns = fd.get('groupby')
        pt = df.pivot_table(
            index=DTTM_ALIAS,
            columns=columns,
            values=values)
        pt.index = pt.index.map(str)
        pt = pt.sort_index()
        return dict(
            records=pt.to_dict(orient='index'),
            columns=list(pt.columns),
            is_group_by=len(fd.get('groupby')) > 0,
        )


class PivotTableViz(BaseViz):

    """A pivot table view, define your rows, columns and metrics"""

    viz_type = 'pivot_table'
    verbose_name = _('Pivot Table')
    credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
    is_timeseries = False

    def query_obj(self):
        d = super(PivotTableViz, self).query_obj()
        groupby = self.form_data.get('groupby')
        columns = self.form_data.get('columns')
        metrics = self.form_data.get('metrics')
        if not columns:
            columns = []
        if not groupby:
            groupby = []
        if not groupby:
            raise Exception(_("Please choose at least one 'Group by' field "))
        if not metrics:
            raise Exception(_('Please choose at least one metric'))
        if (
                any(v in groupby for v in columns) or
                any(v in columns for v in groupby)):
            raise Exception(_("'Group By' and 'Columns' can't overlap"))
        return d

    def get_data(self, df):
        if (
                self.form_data.get('granularity') == 'all' and
                DTTM_ALIAS in df):
            del df[DTTM_ALIAS]
        df = df.pivot_table(
            index=self.form_data.get('groupby'),
            columns=self.form_data.get('columns'),
            values=[self.get_metric_label(m) for m in self.form_data.get('metrics')],
            aggfunc=self.form_data.get('pandas_aggfunc'),
            margins=self.form_data.get('pivot_margins'),
        )
        # Display metrics side by side with each column
        if self.form_data.get('combine_metric'):
            df = df.stack(0).unstack()
        return dict(
            columns=list(df.columns),
            html=df.to_html(
                na_rep='',
                classes=(
                    'dataframe table table-striped table-bordered '
                    'table-condensed table-hover').split(' ')),
        )
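
    # The stack(0)/unstack() step above swaps the two levels of the column
    # MultiIndex, so that columns keyed (hypothetically)
    # ('sum__num', 'CA'), ('sum__num', 'NY') regroup as
    # ('CA', 'sum__num'), ('NY', 'sum__num'): metrics end up nested under
    # each column value instead of the other way around.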


class MarkupViz(BaseViz):

    """Use html or markdown to create a free form widget"""

    viz_type = 'markup'
    verbose_name = _('Markup')
    is_timeseries = False

    def query_obj(self):
        return None

    def get_df(self, query_obj=None):
        return None

    def get_data(self, df):
        markup_type = self.form_data.get('markup_type')
        code = self.form_data.get('code', '')
        if markup_type == 'markdown':
            code = markdown(code)
        return dict(html=code, theme_css=get_css_manifest_files('theme'))


class SeparatorViz(MarkupViz):

    """Use to create section headers in a dashboard, similar to `Markup`"""

    viz_type = 'separator'
    verbose_name = _('Separator')


class WordCloudViz(BaseViz):

    """Build a colorful word cloud

    Uses the nice library at:
    https://github.com/jasondavies/d3-cloud
    """

    viz_type = 'word_cloud'
    verbose_name = _('Word Cloud')
    is_timeseries = False

    def query_obj(self):
        d = super(WordCloudViz, self).query_obj()
        d['groupby'] = [self.form_data.get('series')]
        return d


class TreemapViz(BaseViz):

    """Tree map visualisation for hierarchical data."""

    viz_type = 'treemap'
    verbose_name = _('Treemap')
    credits = '<a href="https://d3js.org">d3.js</a>'
    is_timeseries = False

    def _nest(self, metric, df):
        nlevels = df.index.nlevels
        if nlevels == 1:
            result = [{'name': n, 'value': v}
                      for n, v in zip(df.index, df[metric])]
        else:
            result = [{'name': l, 'children': self._nest(metric, df.loc[l])}
                      for l in df.index.levels[0]]
        return result

    def get_data(self, df):
        df = df.set_index(self.form_data.get('groupby'))
        chart_data = [{'name': metric, 'children': self._nest(metric, df)}
                      for metric in df.columns]
        return chart_data
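
    # Shape of the payload _nest builds, for a hypothetical
    # groupby=['region', 'city'] and a single metric: recursion peels one
    # index level per call until a leaf carries the value, e.g.
    #
    #     [{'name': 'sum__num', 'children': [
    #         {'name': 'West', 'children': [
    #             {'name': 'SF', 'value': 3}]}]}]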


class CalHeatmapViz(BaseViz):

    """Calendar heatmap."""

    viz_type = 'cal_heatmap'
    verbose_name = _('Calendar Heatmap')
    credits = (
        '<a href=https://github.com/wa0x6e/cal-heatmap>cal-heatmap</a>')
    is_timeseries = True

    def get_data(self, df):
        form_data = self.form_data

        data = {}
        records = df.to_dict('records')
        for metric in self.metric_labels:
            data[metric] = {
                str(obj[DTTM_ALIAS].value / 10**9): obj.get(metric)
                for obj in records
            }

        start, end = utils.get_since_until(form_data)
        if not start or not end:
            raise Exception('Please provide both time bounds (Since and Until)')
        domain = form_data.get('domain_granularity')
        diff_delta = rdelta.relativedelta(end, start)
        diff_secs = (end - start).total_seconds()

        if domain == 'year':
            range_ = diff_delta.years + 1
        elif domain == 'month':
            range_ = diff_delta.years * 12 + diff_delta.months + 1
        elif domain == 'week':
            range_ = diff_delta.years * 53 + diff_delta.weeks + 1
        elif domain == 'day':
            range_ = diff_secs // (24 * 60 * 60) + 1
        else:
            range_ = diff_secs // (60 * 60) + 1

        return {
            'data': data,
            'start': start,
            'domain': domain,
            'subdomain': form_data.get('subdomain_granularity'),
            'range': range_,
        }
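
    # Worked example of the `range_` arithmetic above, with hypothetical
    # bounds: start=2018-01-01, end=2018-03-01 and domain='month' gives
    # relativedelta(years=0, months=2), hence
    # range_ = 0 * 12 + 2 + 1 = 3 (January, February and March panels).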

    def query_obj(self):
        d = super(CalHeatmapViz, self).query_obj()
        fd = self.form_data
        d['metrics'] = fd.get('metrics')
        return d


class NVD3Viz(BaseViz):

    """Base class for all nvd3 vizs"""

    credits = '<a href="http://nvd3.org/">NVD3.org</a>'
    viz_type = None
    verbose_name = 'Base NVD3 Viz'
    is_timeseries = False


class BoxPlotViz(NVD3Viz):

    """Box plot viz from NVD3"""

    viz_type = 'box_plot'
    verbose_name = _('Box Plot')
    sort_series = False
    is_timeseries = True

    def to_series(self, df, classed='', title_suffix=''):
        label_sep = ' - '
        chart_data = []
        for index_value, row in zip(df.index, df.to_dict(orient='records')):
            if isinstance(index_value, tuple):
                index_value = label_sep.join(index_value)

            boxes = defaultdict(dict)
            for (label, key), value in row.items():
                if key == 'median':
                    key = 'Q2'
                boxes[label][key] = value

            for label, box in boxes.items():
                if len(self.form_data.get('metrics')) > 1:
                    # need to render data labels with metrics
                    chart_label = label_sep.join([index_value, label])
                else:
                    chart_label = index_value
                chart_data.append({
                    'label': chart_label,
                    'values': box,
                })
        return chart_data

    def get_data(self, df):
        form_data = self.form_data
        df = df.fillna(0)

        # conform to NVD3 names
        def Q1(series):  # need to be named functions - can't use lambdas
            return np.percentile(series, 25)

        def Q3(series):
            return np.percentile(series, 75)

        whisker_type = form_data.get('whisker_options')
        if whisker_type == 'Tukey':

            def whisker_high(series):
                upper_outer_lim = Q3(series) + 1.5 * (Q3(series) - Q1(series))
                series = series[series <= upper_outer_lim]
                return series[np.abs(series - upper_outer_lim).argmin()]

            def whisker_low(series):
                lower_outer_lim = Q1(series) - 1.5 * (Q3(series) - Q1(series))
                # find the closest value above the lower outer limit
                series = series[series >= lower_outer_lim]
                return series[np.abs(series - lower_outer_lim).argmin()]

        elif whisker_type == 'Min/max (no outliers)':

            def whisker_high(series):
                return series.max()

            def whisker_low(series):
                return series.min()

        elif ' percentiles' in whisker_type:
            low, high = whisker_type.replace(' percentiles', '').split('/')

            def whisker_high(series):
                return np.percentile(series, int(high))

            def whisker_low(series):
                return np.percentile(series, int(low))

        else:
            raise ValueError('Unknown whisker type: {}'.format(whisker_type))

        def outliers(series):
            above = series[series > whisker_high(series)]
            below = series[series < whisker_low(series)]
            # pandas sometimes doesn't like getting lists back here
            return set(above.tolist() + below.tolist())

        aggregate = [Q1, np.median, Q3, whisker_high, whisker_low, outliers]
        df = df.groupby(form_data.get('groupby')).agg(aggregate)
        chart_data = self.to_series(df)
        return chart_data
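
    # Worked example of the Tukey whiskers above, with hypothetical
    # quartiles Q1=4 and Q3=8: the IQR is 4, so whisker_high returns the
    # largest observation <= 8 + 1.5 * 4 = 14 and whisker_low the smallest
    # observation >= 4 - 1.5 * 4 = -2; anything beyond those limits is
    # reported by outliers().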


class BubbleViz(NVD3Viz):

    """Based on the NVD3 bubble chart"""

    viz_type = 'bubble'
    verbose_name = _('Bubble Chart')
    is_timeseries = False

    def query_obj(self):
        form_data = self.form_data
        d = super(BubbleViz, self).query_obj()
        d['groupby'] = [
            form_data.get('entity'),
        ]
        if form_data.get('series'):
            d['groupby'].append(form_data.get('series'))
        self.x_metric = form_data.get('x')
        self.y_metric = form_data.get('y')
        self.z_metric = form_data.get('size')
        self.entity = form_data.get('entity')
        self.series = form_data.get('series') or self.entity
        d['row_limit'] = form_data.get('limit')

        d['metrics'] = [
            self.z_metric,
            self.x_metric,
            self.y_metric,
        ]
        if not all(d['metrics'] + [self.entity]):
            raise Exception(_('Pick a metric for x, y and size'))
        return d

    def get_data(self, df):
        df['x'] = df[[utils.get_metric_name(self.x_metric)]]
        df['y'] = df[[utils.get_metric_name(self.y_metric)]]
        df['size'] = df[[utils.get_metric_name(self.z_metric)]]
        df['shape'] = 'circle'
        df['group'] = df[[self.series]]

        series = defaultdict(list)
        for row in df.to_dict(orient='records'):
            series[row['group']].append(row)
        chart_data = []
        for k, v in series.items():
            chart_data.append({
                'key': k,
                'values': v})
        return chart_data


class BulletViz(NVD3Viz):

    """Based on the NVD3 bullet chart"""

    viz_type = 'bullet'
    verbose_name = _('Bullet Chart')
    is_timeseries = False

    def query_obj(self):
        form_data = self.form_data
        d = super(BulletViz, self).query_obj()
        self.metric = form_data.get('metric')

        def as_strings(field):
            value = form_data.get(field)
            return value.split(',') if value else []

        def as_floats(field):
            return [float(x) for x in as_strings(field)]

        self.ranges = as_floats('ranges')
        self.range_labels = as_strings('range_labels')
        self.markers = as_floats('markers')
        self.marker_labels = as_strings('marker_labels')
        self.marker_lines = as_floats('marker_lines')
        self.marker_line_labels = as_strings('marker_line_labels')

        d['metrics'] = [
            self.metric,
        ]
        if not self.metric:
            raise Exception(_('Pick a metric to display'))
        return d

    def get_data(self, df):
        df = df.fillna(0)
        df['metric'] = df[[self.get_metric_label(self.metric)]]
        values = df['metric'].values
        return {
            'measures': values.tolist(),
            'ranges': self.ranges or [0, values.max() * 1.1],
            'rangeLabels': self.range_labels or None,
            'markers': self.markers or None,
            'markerLabels': self.marker_labels or None,
            'markerLines': self.marker_lines or None,
            'markerLineLabels': self.marker_line_labels or None,
        }


class BigNumberViz(BaseViz):

    """Put emphasis on a single metric with this big number viz"""

    viz_type = 'big_number'
    verbose_name = _('Big Number with Trendline')
    credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
    is_timeseries = True

    def query_obj(self):
        d = super(BigNumberViz, self).query_obj()
        metric = self.form_data.get('metric')
        if not metric:
            raise Exception(_('Pick a metric!'))
        d['metrics'] = [self.form_data.get('metric')]
        self.form_data['metric'] = metric
        return d


class BigNumberTotalViz(BaseViz):

    """Put emphasis on a single metric with this big number viz"""

    viz_type = 'big_number_total'
    verbose_name = _('Big Number')
    credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
    is_timeseries = False

    def query_obj(self):
        d = super(BigNumberTotalViz, self).query_obj()
        metric = self.form_data.get('metric')
        if not metric:
            raise Exception(_('Pick a metric!'))
        d['metrics'] = [self.form_data.get('metric')]
        self.form_data['metric'] = metric
        return d


class NVD3TimeSeriesViz(NVD3Viz):

    """A rich line chart component with tons of options"""

    viz_type = 'line'
    verbose_name = _('Time Series - Line Chart')
    sort_series = False
    is_timeseries = True

    def to_series(self, df, classed='', title_suffix=''):
        cols = []
        for col in df.columns:
            if col == '':
                cols.append('N/A')
            elif col is None:
                cols.append('NULL')
            else:
                cols.append(col)
        df.columns = cols
        series = df.to_dict('series')

        chart_data = []
        for name in df.T.index.tolist():
            ys = series[name]
            if df[name].dtype.kind not in 'biufc':
                continue
            if isinstance(name, list):
                series_title = [str(title) for title in name]
            elif isinstance(name, tuple):
                series_title = tuple(str(title) for title in name)
            else:
                series_title = str(name)
            if (
                    isinstance(series_title, (list, tuple)) and
                    len(series_title) > 1 and
                    len(self.metric_labels) == 1):
                # Removing metric from series name if only one metric
                series_title = series_title[1:]

            if title_suffix:
                if isinstance(series_title, str):
                    series_title = (series_title, title_suffix)
                elif isinstance(series_title, (list, tuple)):
                    series_title = series_title + (title_suffix,)

            values = []
            for ds in df.index:
                if ds in ys:
                    d = {
                        'x': ds,
                        'y': ys[ds],
                    }
                else:
                    d = {}
                values.append(d)

            d = {
                'key': series_title,
                'values': values,
            }
            if classed:
                d['classed'] = classed
            chart_data.append(d)
        return chart_data

    def process_data(self, df, aggregate=False):
        fd = self.form_data
        df = df.fillna(0)
        if fd.get('granularity') == 'all':
            raise Exception(_('Pick a time granularity for your time series'))

        if not aggregate:
            df = df.pivot_table(
                index=DTTM_ALIAS,
                columns=fd.get('groupby'),
                values=self.metric_labels)
        else:
            df = df.pivot_table(
                index=DTTM_ALIAS,
                columns=fd.get('groupby'),
                values=self.metric_labels,
                fill_value=0,
                aggfunc=sum)

        fm = fd.get('resample_fillmethod')
        if not fm:
            fm = None
        how = fd.get('resample_how')
        rule = fd.get('resample_rule')
        if how and rule:
            df = df.resample(rule, how=how, fill_method=fm)
            if not fm:
                df = df.fillna(0)

        if self.sort_series:
            dfs = df.sum()
            dfs.sort_values(ascending=False, inplace=True)
            df = df[dfs.index]

        if fd.get('contribution'):
            dft = df.T
            df = (dft / dft.sum()).T

        rolling_type = fd.get('rolling_type')
        rolling_periods = int(fd.get('rolling_periods') or 0)
        min_periods = int(fd.get('min_periods') or 0)

        if rolling_type in ('mean', 'std', 'sum') and rolling_periods:
            kwargs = dict(
                window=rolling_periods,
                min_periods=min_periods)
            if rolling_type == 'mean':
                df = df.rolling(**kwargs).mean()
            elif rolling_type == 'std':
                df = df.rolling(**kwargs).std()
            elif rolling_type == 'sum':
                df = df.rolling(**kwargs).sum()
        elif rolling_type == 'cumsum':
            df = df.cumsum()
        if min_periods:
            df = df[min_periods:]
        return df
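
    # Two of the transforms above, on hypothetical numbers: with
    # 'contribution' enabled, series values (2, 8) at a given timestamp
    # render as (0.2, 0.8), each row being divided by its total; with
    # rolling_type='sum' and rolling_periods=7, each point becomes the sum
    # of the current and preceding observations within the 7-row window.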

    def run_extra_queries(self):
        fd = self.form_data

        time_compare = fd.get('time_compare') or []
        # backwards compatibility
        if not isinstance(time_compare, list):
            time_compare = [time_compare]

        for option in time_compare:
            query_object = self.query_obj()
            delta = utils.parse_human_timedelta(option)
            query_object['inner_from_dttm'] = query_object['from_dttm']
            query_object['inner_to_dttm'] = query_object['to_dttm']

            if not query_object['from_dttm'] or not query_object['to_dttm']:
                raise Exception(_(
                    '`Since` and `Until` time bounds should be specified '
                    'when using the `Time Shift` feature.'))
            query_object['from_dttm'] -= delta
            query_object['to_dttm'] -= delta

            df2 = self.get_df_payload(query_object, time_compare=option).get('df')
            if df2 is not None and DTTM_ALIAS in df2:
                label = '{} offset'.format(option)
                df2[DTTM_ALIAS] += delta
                df2 = self.process_data(df2)
                self._extra_chart_data.append((label, df2))
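
    # Sketch of a time comparison pass, assuming time_compare=['1 week']:
    # the query window is shifted back 7 days, run (and cached) as its own
    # query, then the resulting frame's timestamps are shifted forward
    # again so the '1 week offset' series overlays the current period.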

    def get_data(self, df):
        fd = self.form_data
        comparison_type = fd.get('comparison_type') or 'values'
        df = self.process_data(df)
        if comparison_type == 'values':
            chart_data = self.to_series(df)
            for i, (label, df2) in enumerate(self._extra_chart_data):
                chart_data.extend(
                    self.to_series(
                        df2, classed='time-shift-{}'.format(i), title_suffix=label))
        else:
            chart_data = []
            for i, (label, df2) in enumerate(self._extra_chart_data):
                # reindex df2 onto the union of both indexes, interpolate,
                # then align it back onto df's index
                combined_index = df.index.union(df2.index)
                df2 = df2.reindex(combined_index) \
                    .interpolate(method='time') \
                    .reindex(df.index)

                if comparison_type == 'absolute':
                    diff = df - df2
                elif comparison_type == 'percentage':
                    diff = (df - df2) / df2
                elif comparison_type == 'ratio':
                    diff = df / df2
                else:
                    raise Exception(
                        'Invalid `comparison_type`: {0}'.format(comparison_type))

                # remove leading/trailing NaNs from the time shift difference
                diff = diff[diff.first_valid_index():diff.last_valid_index()]
                chart_data.extend(
                    self.to_series(
                        diff, classed='time-shift-{}'.format(i), title_suffix=label))
        return sorted(chart_data, key=lambda x: tuple(x['key']))
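
    # The comparison arithmetic above, on hypothetical points y=110
    # (current) and y2=100 (shifted series): 'absolute' charts
    # 110 - 100 = 10, 'percentage' charts (110 - 100) / 100 = 0.1, and
    # 'ratio' charts 110 / 100 = 1.1.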


class MultiLineViz(NVD3Viz):

    """Pile on multiple line charts"""

    viz_type = 'line_multi'
    verbose_name = _('Time Series - Multiple Line Charts')
    is_timeseries = True

    def query_obj(self):
        return None

    def get_data(self, df):
        fd = self.form_data
        # Late imports to avoid circular import issues
        from superset.models.core import Slice
        from superset import db
        slice_ids1 = fd.get('line_charts')
        slices1 = db.session.query(Slice).filter(Slice.id.in_(slice_ids1)).all()
        slice_ids2 = fd.get('line_charts_2')
        slices2 = db.session.query(Slice).filter(Slice.id.in_(slice_ids2)).all()
        return {
            'slices': {
                'axis1': [slc.data for slc in slices1],
                'axis2': [slc.data for slc in slices2],
            },
        }


class NVD3DualLineViz(NVD3Viz):

    """A rich line chart with dual axis"""

    viz_type = 'dual_line'
    verbose_name = _('Time Series - Dual Axis Line Chart')
    sort_series = False
    is_timeseries = True

    def query_obj(self):
        d = super(NVD3DualLineViz, self).query_obj()
        m1 = self.form_data.get('metric')
        m2 = self.form_data.get('metric_2')
        d['metrics'] = [m1, m2]
        if not m1:
            raise Exception(_('Pick a metric for left axis!'))
        if not m2:
            raise Exception(_('Pick a metric for right axis!'))
        if m1 == m2:
            raise Exception(_('Please choose different metrics'
                              ' on left and right axis'))
        return d

    def to_series(self, df, classed=''):
        cols = []
        for col in df.columns:
            if col == '':
                cols.append('N/A')
            elif col is None:
                cols.append('NULL')
            else:
                cols.append(col)
        df.columns = cols
        series = df.to_dict('series')
        chart_data = []
        metrics = [
            self.form_data.get('metric'),
            self.form_data.get('metric_2'),
        ]
        for i, m in enumerate(metrics):
            m = utils.get_metric_name(m)
            ys = series[m]
            if df[m].dtype.kind not in 'biufc':
                continue
            series_title = m
            d = {
                'key': series_title,
                'classed': classed,
                'values': [
                    {'x': ds, 'y': ys[ds] if ds in ys else None}
                    for ds in df.index
                ],
                'yAxis': i + 1,
                'type': 'line',
            }
            chart_data.append(d)
        return chart_data
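    # Editor's note (illustrative shape of a single series dict emitted by
    # to_series(); values are assumed):
    #   {'key': 'metric_name', 'classed': '', 'type': 'line', 'yAxis': 1,
    #    'values': [{'x': <timestamp>, 'y': 42.0}, ...]}
    # yAxis is 1 for the left-axis metric and 2 for the right-axis metric.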

    def get_data(self, df):
        fd = self.form_data
        df = df.fillna(0)
        if self.form_data.get('granularity') == 'all':
            raise Exception(_('Pick a time granularity for your time series'))
        metric = self.get_metric_label(fd.get('metric'))
        metric_2 = self.get_metric_label(fd.get('metric_2'))
        df = df.pivot_table(
            index=DTTM_ALIAS,
            values=[metric, metric_2])
        chart_data = self.to_series(df)
        return chart_data


class NVD3TimeSeriesBarViz(NVD3TimeSeriesViz):

    """A bar chart where the x axis is time"""

    viz_type = 'bar'
    sort_series = True
    verbose_name = _('Time Series - Bar Chart')


class NVD3TimePivotViz(NVD3TimeSeriesViz):

    """Time Series - Periodicity Pivot"""

    viz_type = 'time_pivot'
    sort_series = True
    verbose_name = _('Time Series - Period Pivot')

    def query_obj(self):
        d = super(NVD3TimePivotViz, self).query_obj()
        d['metrics'] = [self.form_data.get('metric')]
        return d

    def get_data(self, df):
        fd = self.form_data
        df = self.process_data(df)
        freq = to_offset(fd.get('freq'))
        freq.normalize = True
        df[DTTM_ALIAS] = df.index.map(freq.rollback)
        df['ranked'] = df[DTTM_ALIAS].rank(method='dense', ascending=False) - 1
        df.ranked = df.ranked.map(int)
        df['series'] = '-' + df.ranked.map(str)
        df['series'] = df['series'].str.replace('-0', 'current')
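        # Editor's note (illustrative): `ranked` is 0 for the most recent
        # period, 1 and 2 for earlier ones, so three periods produce the
        # series labels 'current', '-1' and '-2'.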
        rank_lookup = {
            row['series']: row['ranked']
            for row in df.to_dict(orient='records')
        }
        max_ts = df[DTTM_ALIAS].max()
        max_rank = df['ranked'].max()
        df[DTTM_ALIAS] = df.index + (max_ts - df[DTTM_ALIAS])
        df = df.pivot_table(
            index=DTTM_ALIAS,
            columns='series',
            values=self.get_metric_label(fd.get('metric')))
        chart_data = self.to_series(df)
        for serie in chart_data:
            serie['rank'] = rank_lookup[serie['key']]
            serie['perc'] = 1 - (serie['rank'] / (max_rank + 1))
        return chart_data


class NVD3CompareTimeSeriesViz(NVD3TimeSeriesViz):

    """A line chart component where you can compare the % change over time"""

    viz_type = 'compare'
    verbose_name = _('Time Series - Percent Change')


class NVD3TimeSeriesStackedViz(NVD3TimeSeriesViz):

    """A rich stack area chart"""

    viz_type = 'area'
    verbose_name = _('Time Series - Stacked')
    sort_series = True


class DistributionPieViz(NVD3Viz):

    """Annoy visualization snobs with this controversial pie chart"""

    viz_type = 'pie'
    verbose_name = _('Distribution - NVD3 - Pie Chart')
    is_timeseries = False

    def get_data(self, df):
        metric = self.metric_labels[0]
        df = df.pivot_table(
            index=self.groupby,
            values=[metric])
        df.sort_values(by=metric, ascending=False, inplace=True)
        df = df.reset_index()
        df.columns = ['x', 'y']
        return df.to_dict(orient='records')


class HistogramViz(BaseViz):

    """Histogram"""

    viz_type = 'histogram'
    verbose_name = _('Histogram')
    is_timeseries = False

    def query_obj(self):
        """Returns the query object for this visualization"""
        d = super(HistogramViz, self).query_obj()
        d['row_limit'] = self.form_data.get(
            'row_limit', int(config.get('VIZ_ROW_LIMIT')))
        numeric_columns = self.form_data.get('all_columns_x')
        if numeric_columns is None:
            raise Exception(_('Must have at least one numeric column specified'))
        self.columns = numeric_columns
        d['columns'] = numeric_columns + self.groupby
        # override groupby entry to avoid aggregation
        d['groupby'] = []
        return d

    def get_data(self, df):
        """Returns the chart data"""
        chart_data = []
        if len(self.groupby) > 0:
            groups = df.groupby(self.groupby)
        else:
            groups = [((), df)]
        for keys, data in groups:
            if isinstance(keys, str):
                keys = (keys,)
            # removing undesirable characters
            keys = [re.sub(r'\W+', r'_', k) for k in keys]
            chart_data.extend([{
                'key': '__'.join([c] + keys),
                'values': data[c].tolist()}
                for c in self.columns])
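        # Editor's note (illustrative): with self.columns == ['delay'] and a
        # groupby key of ('AA',), the entry key above becomes 'delay__AA'.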
        return chart_data


class DistributionBarViz(DistributionPieViz):

    """A good old bar chart"""

    viz_type = 'dist_bar'
    verbose_name = _('Distribution - Bar Chart')
    is_timeseries = False

    def query_obj(self):
        d = super(DistributionBarViz, self).query_obj()  # noqa
        fd = self.form_data
        if (
            len(d['groupby']) <
            len(fd.get('groupby') or []) + len(fd.get('columns') or [])
        ):
            raise Exception(
                _("Can't have overlap between Series and Breakdowns"))
        if not fd.get('metrics'):
            raise Exception(_('Pick at least one metric'))
        if not fd.get('groupby'):
            raise Exception(_('Pick at least one field for [Series]'))
        return d

    def get_data(self, df):
        fd = self.form_data
        metrics = self.metric_labels
        row = df.groupby(self.groupby).sum()[metrics[0]].copy()
        row.sort_values(ascending=False, inplace=True)
        columns = fd.get('columns') or []
        pt = df.pivot_table(
            index=self.groupby,
            columns=columns,
            values=metrics)
        if fd.get('contribution'):
            pt = pt.fillna(0)
            pt = pt.T
            pt = (pt / pt.sum()).T
        pt = pt.reindex(row.index)
        chart_data = []
        for name, ys in pt.items():
            if pt[name].dtype.kind not in 'biufc' or name in self.groupby:
                continue
            if isinstance(name, str):
                series_title = name
            else:
                offset = 0 if len(metrics) > 1 else 1
                series_title = ', '.join([str(s) for s in name[offset:]])
            values = []
            for i, v in ys.items():
                x = i
                if isinstance(x, (tuple, list)):
                    x = ', '.join([str(s) for s in x])
                else:
                    x = str(x)
                values.append({
                    'x': x,
                    'y': v,
                })
            d = {
                'key': series_title,
                'values': values,
            }
            chart_data.append(d)
        return chart_data


class SunburstViz(BaseViz):

    """A multi level sunburst chart"""

    viz_type = 'sunburst'
    verbose_name = _('Sunburst')
    is_timeseries = False
    credits = (
        'Kerry Rodden '
        '@<a href="https://bl.ocks.org/kerryrodden/7090426">bl.ocks.org</a>')

    def get_data(self, df):
        fd = self.form_data
        cols = fd.get('groupby')
        metric = self.get_metric_label(fd.get('metric'))
        secondary_metric = self.get_metric_label(fd.get('secondary_metric'))
        if metric == secondary_metric or secondary_metric is None:
            df.columns = cols + ['m1']
            df['m2'] = df['m1']
        return json.loads(df.to_json(orient='values'))

    def query_obj(self):
        qry = super(SunburstViz, self).query_obj()
        fd = self.form_data
        qry['metrics'] = [fd['metric']]
        secondary_metric = fd.get('secondary_metric')
        if secondary_metric and secondary_metric != fd['metric']:
            qry['metrics'].append(secondary_metric)
        return qry


class SankeyViz(BaseViz):

    """A Sankey diagram that requires a parent-child dataset"""

    viz_type = 'sankey'
    verbose_name = _('Sankey')
    is_timeseries = False
    credits = '<a href="https://www.npmjs.com/package/d3-sankey">d3-sankey on npm</a>'

    def query_obj(self):
        qry = super(SankeyViz, self).query_obj()
        if len(qry['groupby']) != 2:
            raise Exception(_('Pick exactly 2 columns as [Source / Target]'))
        qry['metrics'] = [
            self.form_data['metric']]
        return qry

    def get_data(self, df):
        df.columns = ['source', 'target', 'value']
        df['source'] = df['source'].astype(basestring)
        df['target'] = df['target'].astype(basestring)
        recs = df.to_dict(orient='records')

        hierarchy = defaultdict(set)
        for row in recs:
            hierarchy[row['source']].add(row['target'])

        def find_cycle(g):
            """Whether there's a cycle in a directed graph"""
            path = set()

            def visit(vertex):
                path.add(vertex)
                for neighbour in g.get(vertex, ()):
                    if neighbour in path or visit(neighbour):
                        return (vertex, neighbour)
                path.remove(vertex)

            for v in g:
                cycle = visit(v)
                if cycle:
                    return cycle
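        # Editor's note (illustrative): find_cycle({'a': {'b'}, 'b': {'a'}})
        # returns the offending edge ('a', 'b'); an acyclic tree returns None.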
        cycle = find_cycle(hierarchy)
        if cycle:
            raise Exception(_(
                "There's a loop in your Sankey, please provide a tree. "
                "Here's a faulty link: {}").format(cycle))
        return recs


class DirectedForceViz(BaseViz):

    """An animated directed force layout graph visualization"""

    viz_type = 'directed_force'
    verbose_name = _('Directed Force Layout')
    credits = 'd3noob @<a href="http://bl.ocks.org/d3noob/5141278">bl.ocks.org</a>'
    is_timeseries = False

    def query_obj(self):
        qry = super(DirectedForceViz, self).query_obj()
        if len(self.form_data['groupby']) != 2:
            raise Exception(_("Pick exactly 2 columns to 'Group By'"))
        qry['metrics'] = [self.form_data['metric']]
        return qry

    def get_data(self, df):
        df.columns = ['source', 'target', 'value']
        return df.to_dict(orient='records')


class ChordViz(BaseViz):

    """A Chord diagram"""

    viz_type = 'chord'
    verbose_name = _('Chord Diagram')
    credits = '<a href="https://github.com/d3/d3-chord">Bostock</a>'
    is_timeseries = False

    def query_obj(self):
        qry = super(ChordViz, self).query_obj()
        fd = self.form_data
        qry['groupby'] = [fd.get('groupby'), fd.get('columns')]
        qry['metrics'] = [self.get_metric_label(fd.get('metric'))]
        return qry

    def get_data(self, df):
        df.columns = ['source', 'target', 'value']

        # Preparing a symmetrical matrix like d3.chord calls for
        nodes = list(set(df['source']) | set(df['target']))
        matrix = {}
        for source, target in product(nodes, nodes):
            matrix[(source, target)] = 0
        for source, target, value in df.to_records(index=False):
            matrix[(source, target)] = value
        m = [[matrix[(n1, n2)] for n1 in nodes] for n2 in nodes]
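        # Editor's note (illustrative, assuming nodes == ['a', 'b'] and a
        # single record ('a', 'b', 5)): m == [[0, 0], [5, 0]], i.e.
        # m[i][j] holds the flow from nodes[j] into nodes[i].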
        return {
            'nodes': list(nodes),
            'matrix': m,
        }


class CountryMapViz(BaseViz):

    """A country-centric choropleth map"""

    viz_type = 'country_map'
    verbose_name = _('Country Map')
    is_timeseries = False
    credits = 'From bl.ocks.org by john-guerra'

    def query_obj(self):
        qry = super(CountryMapViz, self).query_obj()
        qry['metrics'] = [
            self.form_data['metric']]
        qry['groupby'] = [self.form_data['entity']]
        return qry

    def get_data(self, df):
        fd = self.form_data
        cols = [fd.get('entity')]
        metric = self.metric_labels[0]
        cols += [metric]
        ndf = df[cols]
        df = ndf
        df.columns = ['country_id', 'metric']
        d = df.to_dict(orient='records')
        return d


class WorldMapViz(BaseViz):

    """A country centric world map"""

    viz_type = 'world_map'
    verbose_name = _('World Map')
    is_timeseries = False
    credits = 'datamaps on <a href="https://www.npmjs.com/package/datamaps">npm</a>'

    def query_obj(self):
        qry = super(WorldMapViz, self).query_obj()
        qry['groupby'] = [self.form_data['entity']]
        return qry

    def get_data(self, df):
        from superset.data import countries
        fd = self.form_data
        cols = [fd.get('entity')]
        metric = self.get_metric_label(fd.get('metric'))
        secondary_metric = self.get_metric_label(fd.get('secondary_metric'))
        columns = ['country', 'm1', 'm2']
        if metric == secondary_metric:
            ndf = df[cols]
            # df[metric] will be a DataFrame
            # because there are duplicate column names
            ndf['m1'] = df[metric].iloc[:, 0]
            ndf['m2'] = ndf['m1']
        else:
            if secondary_metric:
                cols += [metric, secondary_metric]
            else:
                cols += [metric]
                columns = ['country', 'm1']
            ndf = df[cols]
        df = ndf
        df.columns = columns
        d = df.to_dict(orient='records')
        for row in d:
            country = None
            if isinstance(row['country'], str):
                country = countries.get(
                    fd.get('country_fieldtype'), row['country'])
            if country:
                row['country'] = country['cca3']
                row['latitude'] = country['lat']
                row['longitude'] = country['lng']
                row['name'] = country['name']
            else:
                row['country'] = 'XXX'
        return d


class FilterBoxViz(BaseViz):

    """A multi filter, multi-choice filter box to make dashboards interactive"""

    viz_type = 'filter_box'
    verbose_name = _('Filters')
    is_timeseries = False
    credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
    cache_type = 'get_data'

    def query_obj(self):
        return None

    def run_extra_queries(self):
        qry = self.filter_query_obj()
        filters = [g for g in self.form_data['groupby']]
        self.dataframes = {}
        for flt in filters:
            qry['groupby'] = [flt]
            df = self.get_df_payload(query_obj=qry).get('df')
            self.dataframes[flt] = df

    def filter_query_obj(self):
        qry = super(FilterBoxViz, self).query_obj()
        groupby = self.form_data.get('groupby')
        if len(groupby) < 1 and not self.form_data.get('date_filter'):
            raise Exception(_('Pick at least one filter field'))
        qry['metrics'] = [
            self.form_data['metric']]
        return qry

    def get_data(self, df):
        d = {}
        filters = [g for g in self.form_data['groupby']]
        for flt in filters:
            df = self.dataframes[flt]
            d[flt] = [{
                'id': row[0],
                'text': row[0],
                'filter': flt,
                'metric': row[1]}
                for row in df.itertuples(index=False)
            ]
        return d


class IFrameViz(BaseViz):

    """You can squeeze just about anything in this iFrame component"""

    viz_type = 'iframe'
    verbose_name = _('iFrame')
    credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original'
    is_timeseries = False

    def query_obj(self):
        return None

    def get_df(self, query_obj=None):
        return None


class ParallelCoordinatesViz(BaseViz):

    """Interactive parallel coordinates implementation

    Uses this amazing javascript library
    https://github.com/syntagmatic/parallel-coordinates
    """

    viz_type = 'para'
    verbose_name = _('Parallel Coordinates')
    credits = (
        '<a href="https://syntagmatic.github.io/parallel-coordinates/">'
        "Syntagmatic's library</a>")
    is_timeseries = False

    def query_obj(self):
        d = super(ParallelCoordinatesViz, self).query_obj()
        fd = self.form_data
        d['groupby'] = [fd.get('series')]
        return d

    def get_data(self, df):
        return df.to_dict(orient='records')


class HeatmapViz(BaseViz):

    """A nice heatmap visualization that supports high density through canvas"""

    viz_type = 'heatmap'
    verbose_name = _('Heatmap')
    is_timeseries = False
    credits = (
        'inspired from mbostock @<a href="http://bl.ocks.org/mbostock/3074470">'
        'bl.ocks.org</a>')

    def query_obj(self):
        d = super(HeatmapViz, self).query_obj()
        fd = self.form_data
        d['metrics'] = [fd.get('metric')]
        d['groupby'] = [fd.get('all_columns_x'), fd.get('all_columns_y')]
        return d

    def get_data(self, df):
        fd = self.form_data
        x = fd.get('all_columns_x')
        y = fd.get('all_columns_y')
        v = self.metric_labels[0]
        if x == y:
            df.columns = ['x', 'y', 'v']
        else:
            df = df[[x, y, v]]
            df.columns = ['x', 'y', 'v']
        norm = fd.get('normalize_across')
        overall = False
        max_ = df.v.max()
        min_ = df.v.min()
        if norm == 'heatmap':
            overall = True
        else:
            gb = df.groupby(norm, group_keys=False)
            if len(gb) <= 1:
                overall = True
            else:
                df['perc'] = (
                    gb.apply(
                        lambda x: (x.v - x.v.min()) / (x.v.max() - x.v.min()))
                )
                df['rank'] = gb.apply(lambda x: x.v.rank(pct=True))
        if overall:
            df['perc'] = (df.v - min_) / (max_ - min_)
            df['rank'] = df.v.rank(pct=True)
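        # Editor's note (illustrative): for values [2, 4, 6] normalized
        # overall, perc becomes [0.0, 0.5, 1.0] (min-max scaling) and
        # rank (pct=True) becomes [1/3, 2/3, 1.0].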
        return {
            'records': df.to_dict(orient='records'),
            'extents': [min_, max_],
        }


class HorizonViz(NVD3TimeSeriesViz):

    """Horizon chart

    https://www.npmjs.com/package/d3-horizon-chart
    """

    viz_type = 'horizon'
    verbose_name = _('Horizon Charts')
    credits = (
        '<a href="https://www.npmjs.com/package/d3-horizon-chart">'
        'd3-horizon-chart</a>')


class MapboxViz(BaseViz):

    """Rich maps made with Mapbox"""

    viz_type = 'mapbox'
    verbose_name = _('Mapbox')
    is_timeseries = False
    credits = (
        '<a href="https://www.mapbox.com/mapbox-gl-js/api/">Mapbox GL JS</a>')

    def query_obj(self):
        d = super(MapboxViz, self).query_obj()
        fd = self.form_data
        label_col = fd.get('mapbox_label')

        if not fd.get('groupby'):
            d['columns'] = [fd.get('all_columns_x'), fd.get('all_columns_y')]

            if label_col and len(label_col) >= 1:
                if label_col[0] == 'count':
                    raise Exception(_(
                        "Must have a [Group By] column to have 'count' as the [Label]"))
                d['columns'].append(label_col[0])

            if fd.get('point_radius') != 'Auto':
                d['columns'].append(fd.get('point_radius'))

            d['columns'] = list(set(d['columns']))
        else:
            # Ensuring columns chosen are all in group by
            if (label_col and len(label_col) >= 1 and
                    label_col[0] != 'count' and
                    label_col[0] not in fd.get('groupby')):
                raise Exception(_(
                    'Choice of [Label] must be present in [Group By]'))

            if (fd.get('point_radius') != 'Auto' and
                    fd.get('point_radius') not in fd.get('groupby')):
                raise Exception(_(
                    'Choice of [Point Radius] must be present in [Group By]'))

            if (fd.get('all_columns_x') not in fd.get('groupby') or
                    fd.get('all_columns_y') not in fd.get('groupby')):
                raise Exception(_(
                    '[Longitude] and [Latitude] columns must be present in [Group By]'))
        return d

    def get_data(self, df):
        if df is None:
            return None
        fd = self.form_data
        label_col = fd.get('mapbox_label')
        has_custom_metric = label_col is not None and len(label_col) > 0
        metric_col = [None] * len(df.index)
        if has_custom_metric:
            if label_col[0] == fd.get('all_columns_x'):
                metric_col = df[fd.get('all_columns_x')]
            elif label_col[0] == fd.get('all_columns_y'):
                metric_col = df[fd.get('all_columns_y')]
            else:
                metric_col = df[label_col[0]]
        point_radius_col = (
            [None] * len(df.index)
            if fd.get('point_radius') == 'Auto'
            else df[fd.get('point_radius')])

        # using geoJSON formatting
        geo_json = {
            'type': 'FeatureCollection',
            'features': [
                {
                    'type': 'Feature',
                    'properties': {
                        'metric': metric,
                        'radius': point_radius,
                    },
                    'geometry': {
                        'type': 'Point',
                        'coordinates': [lon, lat],
                    },
                }
                for lon, lat, metric, point_radius
                in zip(
                    df[fd.get('all_columns_x')],
                    df[fd.get('all_columns_y')],
                    metric_col, point_radius_col)
            ],
        }

        x_series, y_series = df[fd.get('all_columns_x')], df[fd.get('all_columns_y')]
        south_west = [x_series.min(), y_series.min()]
        north_east = [x_series.max(), y_series.max()]

        return {
            'geoJSON': geo_json,
            'hasCustomMetric': has_custom_metric,
            'mapboxApiKey': config.get('MAPBOX_API_KEY'),
            'mapStyle': fd.get('mapbox_style'),
            'aggregatorName': fd.get('pandas_aggfunc'),
            'clusteringRadius': fd.get('clustering_radius'),
            'pointRadiusUnit': fd.get('point_radius_unit'),
            'globalOpacity': fd.get('global_opacity'),
            'bounds': [south_west, north_east],
            'renderWhileDragging': fd.get('render_while_dragging'),
            'tooltip': fd.get('rich_tooltip'),
            'color': fd.get('mapbox_color'),
        }


class DeckGLMultiLayer(BaseViz):

    """Pile on multiple DeckGL layers"""

    viz_type = 'deck_multi'
    verbose_name = _('Deck.gl - Multiple Layers')
    is_timeseries = False
    credits = '<a href="https://uber.github.io/deck.gl/">deck.gl</a>'

    def query_obj(self):
        return None

    def get_data(self, df):
        fd = self.form_data
        # Late imports to avoid circular import issues
        from superset.models.core import Slice
        from superset import db
        slice_ids = fd.get('deck_slices')
        slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
        return {
            'mapboxApiKey': config.get('MAPBOX_API_KEY'),
            'slices': [slc.data for slc in slices],
        }


class BaseDeckGLViz(BaseViz):

    """Base class for deck.gl visualizations"""

    is_timeseries = False
    credits = '<a href="https://uber.github.io/deck.gl/">deck.gl</a>'
    spatial_control_keys = []

    def handle_nulls(self, df):
        return df

    def get_metrics(self):
        self.metric = self.form_data.get('size')
        return [self.metric] if self.metric else []

    def process_spatial_query_obj(self, key, group_by):
        group_by.extend(self.get_spatial_columns(key))

    def get_spatial_columns(self, key):
        spatial = self.form_data.get(key)
        if spatial is None:
            raise ValueError(_('Bad spatial key'))

        if spatial.get('type') == 'latlong':
            return [spatial.get('lonCol'), spatial.get('latCol')]
        elif spatial.get('type') == 'delimited':
            return [spatial.get('lonlatCol')]
        elif spatial.get('type') == 'geohash':
            return [spatial.get('geohashCol')]

    @staticmethod
    def parse_coordinates(s):
        if not s:
            return None
        try:
            p = Point(s)
        except Exception:
            raise SpatialException(
                _('Invalid spatial point encountered: %s' % s))
        return (p.latitude, p.longitude)
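    # Editor's note (illustrative, assuming geopy-style Point parsing):
    # parse_coordinates('40.7, -74.0') -> (40.7, -74.0), i.e. (lat, lon),
    # while reverse_geohash_decode below returns (lng, lat).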

    @staticmethod
    def reverse_geohash_decode(geohash_code):
        lat, lng = geohash.decode(geohash_code)
        return (lng, lat)

    @staticmethod
    def reverse_latlong(df, key):
        df[key] = [
            tuple(reversed(o))
            for o in df[key]
            if isinstance(o, (list, tuple))
        ]

    def process_spatial_data_obj(self, key, df):
        spatial = self.form_data.get(key)
        if spatial is None:
            raise ValueError(_('Bad spatial key'))

        if spatial.get('type') == 'latlong':
            df[key] = list(zip(
                pd.to_numeric(df[spatial.get('lonCol')], errors='coerce'),
                pd.to_numeric(df[spatial.get('latCol')], errors='coerce'),
            ))
        elif spatial.get('type') == 'delimited':
            lon_lat_col = spatial.get('lonlatCol')
            df[key] = df[lon_lat_col].apply(self.parse_coordinates)
            del df[lon_lat_col]
        elif spatial.get('type') == 'geohash':
            df[key] = df[spatial.get('geohashCol')].map(self.reverse_geohash_decode)
            del df[spatial.get('geohashCol')]

        if spatial.get('reverseCheckbox'):
            self.reverse_latlong(df, key)

        if df.get(key) is None:
            raise NullValueException(_(
                'Encountered invalid NULL spatial entry, '
                'please consider filtering those out'))
        return df

    def add_null_filters(self):
        fd = self.form_data
        spatial_columns = set()
        for key in self.spatial_control_keys:
            for column in self.get_spatial_columns(key):
                spatial_columns.add(column)

        if fd.get('adhoc_filters') is None:
            fd['adhoc_filters'] = []

        line_column = fd.get('line_column')
        if line_column:
            spatial_columns.add(line_column)

        for column in sorted(spatial_columns):
            filter_ = to_adhoc({
                'col': column,
                'op': 'IS NOT NULL',
                'val': '',
            })
            fd['adhoc_filters'].append(filter_)

    def query_obj(self):
        fd = self.form_data

        # add NULL filters
        if fd.get('filter_nulls', True):
            self.add_null_filters()

        d = super(BaseDeckGLViz, self).query_obj()
        gb = []

        for key in self.spatial_control_keys:
            self.process_spatial_query_obj(key, gb)

        if fd.get('dimension'):
            gb += [fd.get('dimension')]

        if fd.get('js_columns'):
            gb += fd.get('js_columns')
        metrics = self.get_metrics()
        gb = list(set(gb))
        if metrics:
            d['groupby'] = gb
            d['metrics'] = metrics
            d['columns'] = []
        else:
            d['columns'] = gb
        return d

    def get_js_columns(self, d):
        cols = self.form_data.get('js_columns') or []
        return {col: d.get(col) for col in cols}

    def get_data(self, df):
        if df is None:
            return None

        # Processing spatial info
        for key in self.spatial_control_keys:
            df = self.process_spatial_data_obj(key, df)

        features = []
        for d in df.to_dict(orient='records'):
            feature = self.get_properties(d)
            extra_props = self.get_js_columns(d)
            if extra_props:
                feature['extraProps'] = extra_props
            features.append(feature)

        return {
            'features': features,
            'mapboxApiKey': config.get('MAPBOX_API_KEY'),
            'metricLabels': self.metric_labels,
        }

    def get_properties(self, d):
        raise NotImplementedError()


class DeckScatterViz(BaseDeckGLViz):

    """deck.gl's ScatterLayer"""

    viz_type = 'deck_scatter'
    verbose_name = _('Deck.gl - Scatter plot')
    spatial_control_keys = ['spatial']
    is_timeseries = True

    def query_obj(self):
        fd = self.form_data
        self.is_timeseries = bool(
            fd.get('time_grain_sqla') or fd.get('granularity'))
        self.point_radius_fixed = (
            fd.get('point_radius_fixed') or {'type': 'fix', 'value': 500})
        return super(DeckScatterViz, self).query_obj()

    def get_metrics(self):
        self.metric = None
        if self.point_radius_fixed.get('type') == 'metric':
            self.metric = self.point_radius_fixed.get('value')
            return [self.metric]
        # return a list for consistency with the base class
        return []

    def get_properties(self, d):
        return {
            'metric': d.get(self.metric_label),
            'radius': self.fixed_value if self.fixed_value else d.get(self.metric_label),
            'cat_color': d.get(self.dim) if self.dim else None,
            'position': d.get('spatial'),
            DTTM_ALIAS: d.get(DTTM_ALIAS),
        }

    def get_data(self, df):
        fd = self.form_data
        self.metric_label = \
            self.get_metric_label(self.metric) if self.metric else None
        self.point_radius_fixed = fd.get('point_radius_fixed')
        self.fixed_value = None
        self.dim = self.form_data.get('dimension')
        if self.point_radius_fixed.get('type') != 'metric':
            self.fixed_value = self.point_radius_fixed.get('value')
        return super(DeckScatterViz, self).get_data(df)


class DeckScreengrid(BaseDeckGLViz):

    """deck.gl's ScreenGridLayer"""

    viz_type = 'deck_screengrid'
    verbose_name = _('Deck.gl - Screen Grid')
    spatial_control_keys = ['spatial']
    is_timeseries = True

    def query_obj(self):
        fd = self.form_data
        self.is_timeseries = bool(
            fd.get('time_grain_sqla') or fd.get('granularity'))
        return super(DeckScreengrid, self).query_obj()

    def get_properties(self, d):
        return {
            'position': d.get('spatial'),
            'weight': d.get(self.metric_label) or 1,
            '__timestamp': d.get(DTTM_ALIAS) or d.get('__time'),
        }

    def get_data(self, df):
        self.metric_label = self.get_metric_label(self.metric)
        return super(DeckScreengrid, self).get_data(df)


class DeckGrid(BaseDeckGLViz):

    """deck.gl's GridLayer"""

    viz_type = 'deck_grid'
    verbose_name = _('Deck.gl - 3D Grid')
    spatial_control_keys = ['spatial']

    def get_properties(self, d):
        return {
            'position': d.get('spatial'),
            'weight': d.get(self.metric_label) or 1,
        }

    def get_data(self, df):
        self.metric_label = self.get_metric_label(self.metric)
        return super(DeckGrid, self).get_data(df)


def geohash_to_json(geohash_code):
    p = geohash.bbox(geohash_code)
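    # Editor's note: the ring below walks the bounding box NW -> NE -> SE
    # -> SW and closes back at NW, with each vertex in [lon, lat] order.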
    return [
        [p.get('w'), p.get('n')],
        [p.get('e'), p.get('n')],
        [p.get('e'), p.get('s')],
        [p.get('w'), p.get('s')],
        [p.get('w'), p.get('n')],
    ]


class DeckPathViz(BaseDeckGLViz):

    """deck.gl's PathLayer"""

    viz_type = 'deck_path'
    verbose_name = _('Deck.gl - Paths')
    deck_viz_key = 'path'
    deser_map = {
        'json': json.loads,
        'polyline': polyline.decode,
        'geohash': geohash_to_json,
    }

    def query_obj(self):
        d = super(DeckPathViz, self).query_obj()
        line_col = self.form_data.get('line_column')
        if d['metrics']:
            self.has_metrics = True
            d['groupby'].append(line_col)
        else:
            self.has_metrics = False
            d['columns'].append(line_col)
        return d

    def get_properties(self, d):
        fd = self.form_data
        line_type = fd.get('line_type')
        deser = self.deser_map[line_type]
        line_column = fd.get('line_column')
        path = deser(d[line_column])
        if fd.get('reverse_long_lat'):
            path = [(o[1], o[0]) for o in path]
        d[self.deck_viz_key] = path
        if line_type != 'geohash':
            del d[line_column]
        return d


class DeckPolygon(DeckPathViz):

    """deck.gl's Polygon Layer"""

    viz_type = 'deck_polygon'
    deck_viz_key = 'polygon'
    verbose_name = _('Deck.gl - Polygon')


class DeckHex(BaseDeckGLViz):

    """deck.gl's HexagonLayer"""

    viz_type = 'deck_hex'
    verbose_name = _('Deck.gl - 3D HEX')
    spatial_control_keys = ['spatial']

    def get_properties(self, d):
        return {
            'position': d.get('spatial'),
            'weight': d.get(self.metric_label) or 1,
        }

    def get_data(self, df):
        self.metric_label = self.get_metric_label(self.metric)
        return super(DeckHex, self).get_data(df)


class DeckGeoJson(BaseDeckGLViz):

    """deck.gl's GeoJSONLayer"""

    viz_type = 'deck_geojson'
    verbose_name = _('Deck.gl - GeoJSON')

    def query_obj(self):
        d = super(DeckGeoJson, self).query_obj()
        d['columns'] += [self.form_data.get('geojson')]
        d['metrics'] = []
        d['groupby'] = []
        return d

    def get_properties(self, d):
        geojson = d.get(self.form_data.get('geojson'))
        return json.loads(geojson)


class DeckArc(BaseDeckGLViz):

    """deck.gl's Arc Layer"""

    viz_type = 'deck_arc'
    verbose_name = _('Deck.gl - Arc')
    spatial_control_keys = ['start_spatial', 'end_spatial']
    is_timeseries = True

    def query_obj(self):
        fd = self.form_data
        self.is_timeseries = bool(
            fd.get('time_grain_sqla') or fd.get('granularity'))
        return super(DeckArc, self).query_obj()

    def get_properties(self, d):
        dim = self.form_data.get('dimension')
        return {
            'sourcePosition': d.get('start_spatial'),
            'targetPosition': d.get('end_spatial'),
            'cat_color': d.get(dim) if dim else None,
            DTTM_ALIAS: d.get(DTTM_ALIAS),
        }

    def get_data(self, df):
        d = super(DeckArc, self).get_data(df)

        return {
            'features': d['features'],
            'mapboxApiKey': config.get('MAPBOX_API_KEY'),
        }


class EventFlowViz(BaseViz):

    """A visualization to explore patterns in event sequences"""

    viz_type = 'event_flow'
    verbose_name = _('Event flow')
    credits = 'from <a href="https://github.com/williaster/data-ui">@data-ui</a>'
    is_timeseries = True

    def query_obj(self):
        query = super(EventFlowViz, self).query_obj()
        form_data = self.form_data

        event_key = form_data.get('all_columns_x')
        entity_key = form_data.get('entity')
        meta_keys = [
            col for col in form_data.get('all_columns')
            if col != event_key and col != entity_key
        ]

        query['columns'] = [event_key, entity_key] + meta_keys

        if form_data['order_by_entity']:
            query['orderby'] = [(entity_key, True)]

        return query

    def get_data(self, df):
        return df.to_dict(orient='records')


class PairedTTestViz(BaseViz):

    """A table displaying paired t-test values"""

    viz_type = 'paired_ttest'
    verbose_name = _('Time Series - Paired t-test')
    sort_series = False
    is_timeseries = True

    def get_data(self, df):
        """
        Transform received data frame into an object of the form:
        {
            'metric1': [
                {
                    groups: ('groupA', ... ),
                    values: [ {x, y}, ... ],
                }, ...
            ], ...
        }
        """
        fd = self.form_data
        groups = fd.get('groupby')
        metrics = fd.get('metrics')
        df = df.fillna(0)
        df = df.pivot_table(
            index=DTTM_ALIAS,
            columns=groups,
            values=metrics)
        cols = []
        # Be rid of falsey keys
        for col in df.columns:
            if col == '':
                cols.append('N/A')
            elif col is None:
                cols.append('NULL')
            else:
                cols.append(col)
        df.columns = cols
        data = {}
        series = df.to_dict('series')
        for nameSet in df.columns:
            # If no groups are defined, nameSet will be the metric name
            hasGroup = not isinstance(nameSet, str)
            Y = series[nameSet]
            d = {
                'group': nameSet[1:] if hasGroup else 'All',
                'values': [
                    {'x': t, 'y': Y[t] if t in Y else None}
                    for t in df.index
                ],
            }
            key = nameSet[0] if hasGroup else nameSet
            if key in data:
                data[key].append(d)
            else:
                data[key] = [d]
        return data


class RoseViz(NVD3TimeSeriesViz):

    viz_type = 'rose'
    verbose_name = _('Time Series - Nightingale Rose Chart')
    sort_series = False
    is_timeseries = True

    def get_data(self, df):
        data = super(RoseViz, self).get_data(df)
        result = {}
        for datum in data:
            key = datum['key']
            for val in datum['values']:
                timestamp = val['x'].value
                if not result.get(timestamp):
                    result[timestamp] = []
                value = 0 if math.isnan(val['y']) else val['y']
                result[timestamp].append({
                    'key': key,
                    'value': value,
                    'name': ', '.join(key) if isinstance(key, list) else key,
                    'time': val['x'],
                })
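        # Editor's note: keys of `result` are pandas Timestamp.value
        # integers, i.e. nanoseconds since the Unix epoch.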
        return result


class PartitionViz(NVD3TimeSeriesViz):

    """
    A hierarchical data visualization with support for time series.
    """

    viz_type = 'partition'
    verbose_name = _('Partition Diagram')

    def query_obj(self):
        query_obj = super(PartitionViz, self).query_obj()
        time_op = self.form_data.get('time_series_option', 'not_time')
        # Return time series data if the user specifies so
        query_obj['is_timeseries'] = time_op != 'not_time'
        return query_obj

    def levels_for(self, time_op, groups, df):
        """
        Compute the partition at each `level` from the dataframe.
        """
        levels = {}
        for i in range(0, len(groups) + 1):
            agg_df = df.groupby(groups[:i]) if i else df
            levels[i] = (
                agg_df.mean() if time_op == 'agg_mean'
                else agg_df.sum(numeric_only=True))
        return levels
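    # Editor's note (illustrative): for groups ['a', 'b'], levels_for
    # returns {0: overall totals, 1: totals per a, 2: totals per (a, b)}.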

    def levels_for_diff(self, time_op, groups, df):
        # Obtain a unique list of the time grains
        times = list(set(df[DTTM_ALIAS]))
        times.sort()
        until = times[len(times) - 1]
        since = times[0]
        # Function describing how to calculate the difference
        func = {
            'point_diff': [
                pd.Series.sub,
                lambda a, b, fill_value: a - b,
            ],
            'point_factor': [
                pd.Series.div,
                lambda a, b, fill_value: a / float(b),
            ],
            'point_percent': [
                lambda a, b, fill_value=0: a.div(b, fill_value=fill_value) - 1,
                lambda a, b, fill_value: a / float(b) - 1,
            ],
        }[time_op]
        agg_df = df.groupby(DTTM_ALIAS).sum()
        levels = {0: pd.Series({
            m: func[1](agg_df[m][until], agg_df[m][since], 0)
            for m in agg_df.columns})}
        for i in range(1, len(groups) + 1):
            agg_df = df.groupby([DTTM_ALIAS] + groups[:i]).sum()
            levels[i] = pd.DataFrame({
                m: func[0](agg_df[m][until], agg_df[m][since], fill_value=0)
                for m in agg_df.columns})
        return levels
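    # Editor's note (illustrative): with a metric of 100 at `since` and 110
    # at `until`, the operators above yield point_diff -> 10,
    # point_factor -> 1.1, point_percent -> 0.1.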

    def levels_for_time(self, groups, df):
        procs = {}
        for i in range(0, len(groups) + 1):
            self.form_data['groupby'] = groups[:i]
            df_drop = df.drop(groups[i:], axis=1)
            procs[i] = self.process_data(df_drop, aggregate=True).fillna(0)
        self.form_data['groupby'] = groups
        return procs

    def nest_values(self, levels, level=0, metric=None, dims=()):
        """
        Nest values at each level on the back-end with
        access and setting, instead of summing from the bottom.
        """
        if not level:
            return [{
                'name': m,
                'val': levels[0][m],
                'children': self.nest_values(levels, 1, m),
            } for m in levels[0].index]
        if level == 1:
            return [{
                'name': i,
                'val': levels[1][metric][i],
                'children': self.nest_values(levels, 2, metric, (i,)),
            } for i in levels[1][metric].index]
        if level >= len(levels):
            return []
        return [{
            'name': i,
            'val': levels[level][metric][dims][i],
            'children': self.nest_values(
                levels, level + 1, metric, dims + (i,),
            ),
        } for i in levels[level][metric][dims].index]
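    # Editor's note (illustrative tree shape produced by nest_values, with
    # assumed values):
    #   [{'name': 'metric', 'val': 123, 'children': [
    #       {'name': 'group_value', 'val': 45, 'children': [...]}]}]
    # where each level of children drills one group deeper.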

    def nest_procs(self, procs, level=-1, dims=(), time=None):
        if level == -1:
            return [{
                'name': m,
                'children': self.nest_procs(procs, 0, (m,)),
            } for m in procs[0].columns]
        if not level:
            return [{
                'name': t,
                'val': procs[0][dims[0]][t],
                'children': self.nest_procs(procs, 1, dims, t),
            } for t in procs[0].index]
        if level >= len(procs):
            return []
        return [{
            'name': i,
            'val': procs[level][dims][i][time],
            'children': self.nest_procs(procs, level + 1, dims + (i,), time),
        } for i in procs[level][dims].columns]

    def get_data(self, df):
        fd = self.form_data
        groups = fd.get('groupby', [])
        time_op = fd.get('time_series_option', 'not_time')
        if not len(groups):
            raise ValueError('Please choose at least one groupby')
        if time_op == 'not_time':
            levels = self.levels_for('agg_sum', groups, df)
        elif time_op in ['agg_sum', 'agg_mean']:
            levels = self.levels_for(time_op, groups, df)
        elif time_op in ['point_diff', 'point_factor', 'point_percent']:
            levels = self.levels_for_diff(time_op, groups, df)
        elif time_op == 'adv_anal':
            procs = self.levels_for_time(groups, df)
            return self.nest_procs(procs)
        else:
            levels = self.levels_for('agg_sum', [DTTM_ALIAS] + groups, df)
        return self.nest_values(levels)


viz_types = {
    o.viz_type: o for o in globals().values()
    if (
        inspect.isclass(o) and
        issubclass(o, BaseViz) and
        o.viz_type not in config.get('VIZ_TYPE_BLACKLIST'))}
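# Editor's note: `viz_types` maps each registered viz_type string to its
# class, e.g. viz_types['pie'] would be DistributionPieViz, unless that
# type is excluded via the VIZ_TYPE_BLACKLIST config entry.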