sqla_models_tests.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
  18. import tests.test_app
  19. from superset.connectors.sqla.models import SqlaTable, TableColumn
  20. from superset.db_engine_specs.druid import DruidEngineSpec
  21. from superset.utils.core import get_example_database
  22. from .base_tests import SupersetTestCase
  23. class DatabaseModelTestCase(SupersetTestCase):
  24. def test_is_time_druid_time_col(self):
  25. """Druid has a special __time column"""
  26. col = TableColumn(column_name="__time", type="INTEGER")
  27. self.assertEqual(col.is_dttm, None)
  28. DruidEngineSpec.alter_new_orm_column(col)
  29. self.assertEqual(col.is_dttm, True)
  30. col = TableColumn(column_name="__not_time", type="INTEGER")
  31. self.assertEqual(col.is_time, False)
  32. def test_is_time_by_type(self):
  33. col = TableColumn(column_name="foo", type="DATE")
  34. self.assertEqual(col.is_time, True)
  35. col = TableColumn(column_name="foo", type="DATETIME")
  36. self.assertEqual(col.is_time, True)
  37. col = TableColumn(column_name="foo", type="STRING")
  38. self.assertEqual(col.is_time, False)
  39. def test_has_extra_cache_keys(self):
  40. query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
  41. table = SqlaTable(
  42. table_name="test_has_extra_cache_keys_table",
  43. sql=query,
  44. database=get_example_database(),
  45. )
  46. query_obj = {
  47. "granularity": None,
  48. "from_dttm": None,
  49. "to_dttm": None,
  50. "groupby": ["user"],
  51. "metrics": [],
  52. "is_timeseries": False,
  53. "filter": [],
  54. "extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"},
  55. }
  56. extra_cache_keys = table.get_extra_cache_keys(query_obj)
  57. self.assertTrue(table.has_calls_to_cache_key_wrapper(query_obj))
  58. self.assertListEqual(extra_cache_keys, ["user_1", "user_2"])
  59. def test_has_no_extra_cache_keys(self):
  60. query = "SELECT 'abc' as user"
  61. table = SqlaTable(
  62. table_name="test_has_no_extra_cache_keys_table",
  63. sql=query,
  64. database=get_example_database(),
  65. )
  66. query_obj = {
  67. "granularity": None,
  68. "from_dttm": None,
  69. "to_dttm": None,
  70. "groupby": ["user"],
  71. "metrics": [],
  72. "is_timeseries": False,
  73. "filter": [],
  74. "extras": {"where": "(user != 'abc')"},
  75. }
  76. extra_cache_keys = table.get_extra_cache_keys(query_obj)
  77. self.assertFalse(table.has_calls_to_cache_key_wrapper(query_obj))
  78. self.assertListEqual(extra_cache_keys, [])