base_schemas.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Optional, Type

from flask import current_app, g
from flask_appbuilder import Model
from marshmallow import post_load, pre_load, Schema, ValidationError
from sqlalchemy.orm.exc import NoResultFound
  22. def validate_owner(value):
  23. try:
  24. (
  25. current_app.appbuilder.get_session.query(
  26. current_app.appbuilder.sm.user_model.id
  27. )
  28. .filter_by(id=value)
  29. .one()
  30. )
  31. except NoResultFound:
  32. raise ValidationError(f"User {value} does not exist")
  33. class BaseSupersetSchema(Schema):
  34. """
  35. Extends Marshmallow schema so that we can pass a Model to load
  36. (following marshamallow-sqlalchemy pattern). This is useful
  37. to perform partial model merges on HTTP PUT
  38. """
  39. __class_model__: Model = None
  40. def __init__(self, **kwargs):
  41. self.instance: Optional[Model] = None
  42. super().__init__(**kwargs)
  43. def load(
  44. self, data, many=None, partial=None, instance: Model = None, **kwargs
  45. ): # pylint: disable=arguments-differ
  46. self.instance = instance
  47. return super().load(data, many=many, partial=partial, **kwargs)
  48. @post_load
  49. def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model:
  50. """
  51. Creates a Model object from POST or PUT requests. PUT will use self.instance
  52. previously fetched from the endpoint handler
  53. :param data: Schema data payload
  54. :param discard: List of fields to not set on the model
  55. """
  56. discard = discard or []
  57. if not self.instance:
  58. self.instance = self.__class_model__() # pylint: disable=not-callable
  59. for field in data:
  60. if field not in discard:
  61. setattr(self.instance, field, data.get(field))
  62. return self.instance
  63. class BaseOwnedSchema(BaseSupersetSchema):
  64. """
  65. Implements owners validation,pre load and post_load
  66. (to populate the owners field) on Marshmallow schemas
  67. """
  68. owners_field_name = "owners"
  69. @post_load
  70. def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model:
  71. discard = discard or []
  72. discard.append(self.owners_field_name)
  73. instance = super().make_object(data, discard)
  74. if "owners" not in data and g.user not in instance.owners:
  75. instance.owners.append(g.user)
  76. if self.owners_field_name in data:
  77. self.set_owners(instance, data[self.owners_field_name])
  78. return instance
  79. @pre_load
  80. def pre_load(self, data: Dict):
  81. # if PUT request don't set owners to empty list
  82. if not self.instance:
  83. data[self.owners_field_name] = data.get(self.owners_field_name, [])
  84. @staticmethod
  85. def set_owners(instance: Model, owners: List[int]):
  86. owner_objs = list()
  87. if g.user.id not in owners:
  88. owners.append(g.user.id)
  89. for owner_id in owners:
  90. user = current_app.appbuilder.get_session.query(
  91. current_app.appbuilder.sm.user_model
  92. ).get(owner_id)
  93. owner_objs.append(user)
  94. instance.owners = owner_objs