apache-airflow-providers-fab 2.4.4rc1__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25):
  1. airflow/providers/fab/__init__.py +1 -1
  2. airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py +2 -2
  3. airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py +4 -4
  4. airflow/providers/fab/auth_manager/api_fastapi/routes/login.py +7 -4
  5. airflow/providers/fab/auth_manager/cli_commands/user_command.py +1 -2
  6. airflow/providers/fab/auth_manager/cli_commands/utils.py +7 -3
  7. airflow/providers/fab/auth_manager/fab_auth_manager.py +15 -11
  8. airflow/providers/fab/auth_manager/models/__init__.py +179 -122
  9. airflow/providers/fab/auth_manager/models/db.py +11 -6
  10. airflow/providers/fab/auth_manager/security_manager/override.py +239 -213
  11. airflow/providers/fab/auth_manager/views/user.py +11 -5
  12. airflow/providers/fab/migrations/versions/0001_1_4_0_create_ab_tables_if_missing.py +5 -4
  13. airflow/providers/fab/www/app.py +3 -4
  14. airflow/providers/fab/www/extensions/init_appbuilder.py +26 -39
  15. airflow/providers/fab/www/extensions/init_session.py +2 -2
  16. airflow/providers/fab/www/security_appless.py +6 -1
  17. airflow/providers/fab/www/security_manager.py +4 -14
  18. airflow/providers/fab/www/session.py +26 -3
  19. airflow/providers/fab/www/utils.py +1 -208
  20. {apache_airflow_providers_fab-2.4.4rc1.dist-info → apache_airflow_providers_fab-3.0.0.dist-info}/METADATA +18 -12
  21. {apache_airflow_providers_fab-2.4.4rc1.dist-info → apache_airflow_providers_fab-3.0.0.dist-info}/RECORD +25 -25
  22. {apache_airflow_providers_fab-2.4.4rc1.dist-info → apache_airflow_providers_fab-3.0.0.dist-info}/WHEEL +0 -0
  23. {apache_airflow_providers_fab-2.4.4rc1.dist-info → apache_airflow_providers_fab-3.0.0.dist-info}/entry_points.txt +0 -0
  24. {apache_airflow_providers_fab-2.4.4rc1.dist-info → apache_airflow_providers_fab-3.0.0.dist-info}/licenses/3rd-party-licenses/LICENSES-ui.txt +0 -0
  25. {apache_airflow_providers_fab-2.4.4rc1.dist-info → apache_airflow_providers_fab-3.0.0.dist-info}/licenses/NOTICE +0 -0
@@ -24,7 +24,6 @@ from flask_appbuilder.security.views import (
24
24
  UserDBModelView,
25
25
  UserLDAPModelView,
26
26
  UserOAuthModelView,
27
- UserOIDModelView,
28
27
  UserRemoteUserModelView,
29
28
  )
30
29
 
@@ -122,10 +121,6 @@ class CustomUserOAuthModelView(MultiResourceUserMixin, UserOAuthModelView):
122
121
  """Customize permission names for FAB's builtin UserOAuthModelView."""
123
122
 
124
123
 
125
- class CustomUserOIDModelView(MultiResourceUserMixin, UserOIDModelView):
126
- """Customize permission names for FAB's builtin UserOIDModelView."""
127
-
128
-
129
124
  class CustomUserRemoteUserModelView(MultiResourceUserMixin, UserRemoteUserModelView):
130
125
  """Customize permission names for FAB's builtin UserRemoteUserModelView."""
131
126
 
@@ -180,6 +175,17 @@ class CustomUserDBModelView(MultiResourceUserMixin, UserDBModelView):
180
175
  "userinfoedit": "read",
181
176
  }
182
177
 
178
+ add_columns = [
179
+ "first_name",
180
+ "last_name",
181
+ "username",
182
+ "active",
183
+ "email",
184
+ "roles",
185
+ "password",
186
+ "conf_password",
187
+ ]
188
+
183
189
  base_permissions = [
184
190
  permissions.ACTION_CAN_CREATE,
185
191
  permissions.ACTION_CAN_READ,
@@ -121,6 +121,7 @@ def upgrade() -> None:
121
121
  sa.Column("name", sa.String(length=250), nullable=False),
122
122
  sa.PrimaryKeyConstraint("id", name=op.f("ab_view_menu_pkey")),
123
123
  sa.UniqueConstraint("name", name=op.f("ab_view_menu_name_uq")),
124
+ if_not_exists=True,
124
125
  )
125
126
  op.create_table(
126
127
  "ab_group_role",
@@ -138,8 +139,8 @@ def upgrade() -> None:
138
139
  if_not_exists=True,
139
140
  )
140
141
  with op.batch_alter_table("ab_group_role", schema=None) as batch_op:
141
- batch_op.create_index("idx_group_id", ["group_id"], unique=False)
142
- batch_op.create_index("idx_group_role_id", ["role_id"], unique=False)
142
+ batch_op.create_index("idx_group_id", ["group_id"], unique=False, if_not_exists=True)
143
+ batch_op.create_index("idx_group_role_id", ["role_id"], unique=False, if_not_exists=True)
143
144
 
144
145
  op.create_table(
145
146
  "ab_permission_view",
@@ -174,8 +175,8 @@ def upgrade() -> None:
174
175
  if_not_exists=True,
175
176
  )
176
177
  with op.batch_alter_table("ab_user_group", schema=None) as batch_op:
177
- batch_op.create_index("idx_user_group_id", ["group_id"], unique=False)
178
- batch_op.create_index("idx_user_id", ["user_id"], unique=False)
178
+ batch_op.create_index("idx_user_group_id", ["group_id"], unique=False, if_not_exists=True)
179
+ batch_op.create_index("idx_user_id", ["user_id"], unique=False, if_not_exists=True)
179
180
 
180
181
  op.create_table(
181
182
  "ab_user_role",
@@ -21,7 +21,7 @@ from datetime import timedelta
21
21
  from os.path import isabs
22
22
 
23
23
  from flask import Flask
24
- from flask_appbuilder import SQLA
24
+ from flask_sqlalchemy import SQLAlchemy
25
25
  from flask_wtf.csrf import CSRFProtect
26
26
  from sqlalchemy.engine.url import make_url
27
27
 
@@ -84,9 +84,8 @@ def create_app(enable_plugins: bool):
84
84
 
85
85
  csrf.init_app(flask_app)
86
86
 
87
- db = SQLA()
87
+ db = SQLAlchemy(flask_app)
88
88
  db.session = settings.Session
89
- db.init_app(flask_app)
90
89
 
91
90
  configure_logging()
92
91
  configure_manifest_files(flask_app)
@@ -107,8 +106,8 @@ def create_app(enable_plugins: bool):
107
106
  elif isinstance(get_auth_manager(), FabAuthManager):
108
107
  init_api_auth_provider(flask_app)
109
108
  init_api_error_handlers(flask_app)
109
+ init_airflow_session_interface(flask_app, db)
110
110
  init_jinja_globals(flask_app, enable_plugins=enable_plugins)
111
- init_airflow_session_interface(flask_app)
112
111
  init_wsgi_middleware(flask_app)
113
112
  return flask_app
114
113
 
@@ -80,8 +80,6 @@ class AirflowAppBuilder:
80
80
  """This is the base class for all the framework."""
81
81
 
82
82
  baseviews: list[BaseView | Session] = []
83
- # Flask app
84
- app = None
85
83
  # Database Session
86
84
  session = None
87
85
  # Security Manager Class
@@ -149,7 +147,6 @@ class AirflowAppBuilder:
149
147
  self.indexview = indexview
150
148
  self.static_folder = static_folder
151
149
  self.static_url_path = static_url_path
152
- self.app = app
153
150
  self.enable_plugins = enable_plugins
154
151
  self.update_perms = conf.getboolean("fab", "UPDATE_FAB_PERMS")
155
152
  self.auth_rate_limited = conf.getboolean("fab", "AUTH_RATE_LIMITED")
@@ -164,6 +161,7 @@ class AirflowAppBuilder:
164
161
  :param app:
165
162
  :param session: The SQLAlchemy session
166
163
  """
164
+ log.info("Initializing AppBuilder")
167
165
  app.config.setdefault("APP_NAME", "F.A.B.")
168
166
  app.config.setdefault("APP_THEME", "")
169
167
  app.config.setdefault("APP_ICON", "")
@@ -176,8 +174,6 @@ class AirflowAppBuilder:
176
174
  app.config.setdefault("AUTH_RATE_LIMITED", self.auth_rate_limited)
177
175
  app.config.setdefault("AUTH_RATE_LIMIT", self.auth_rate_limit)
178
176
 
179
- self.app = app
180
-
181
177
  self.base_template = app.config.get("FAB_BASE_TEMPLATE", self.base_template)
182
178
  self.static_folder = app.config.get("FAB_STATIC_FOLDER", self.static_folder)
183
179
  self.static_url_path = app.config.get("FAB_STATIC_URL_PATH", self.static_url_path)
@@ -228,24 +224,19 @@ class AirflowAppBuilder:
228
224
  app.extensions["appbuilder"] = self
229
225
 
230
226
  @property
231
- def get_app(self):
232
- """
233
- Get current or configured flask app.
234
-
235
- :return: Flask App
236
- """
237
- if self.app:
238
- return self.app
227
+ def app(self) -> Flask:
228
+ log.warning(
229
+ "appbuilder.app is deprecated and will be removed in a future version. Use current_app instead"
230
+ )
239
231
  return current_app
240
232
 
241
233
  @property
242
- def get_session(self):
243
- """
244
- Get the current sqlalchemy session.
245
-
246
- :return: SQLAlchemy Session
247
- """
248
- return self.session
234
+ def get_app(self) -> Flask:
235
+ log.warning(
236
+ "appbuilder.get_app is deprecated and will be removed in a future version. "
237
+ "Use current_app instead"
238
+ )
239
+ return self.app
249
240
 
250
241
  @property
251
242
  def app_name(self):
@@ -254,7 +245,7 @@ class AirflowAppBuilder:
254
245
 
255
246
  :return: String with app name
256
247
  """
257
- return self.get_app.config["APP_NAME"]
248
+ return current_app.config["APP_NAME"]
258
249
 
259
250
  @property
260
251
  def app_theme(self):
@@ -263,7 +254,7 @@ class AirflowAppBuilder:
263
254
 
264
255
  :return: String app theme name
265
256
  """
266
- return self.get_app.config["APP_THEME"]
257
+ return current_app.config["APP_THEME"]
267
258
 
268
259
  @property
269
260
  def app_icon(self):
@@ -272,11 +263,11 @@ class AirflowAppBuilder:
272
263
 
273
264
  :return: String with relative app icon location
274
265
  """
275
- return self.get_app.config["APP_ICON"]
266
+ return current_app.config["APP_ICON"]
276
267
 
277
268
  @property
278
269
  def languages(self):
279
- return self.get_app.config["LANGUAGES"]
270
+ return current_app.config["LANGUAGES"]
280
271
 
281
272
  @property
282
273
  def version(self):
@@ -288,7 +279,7 @@ class AirflowAppBuilder:
288
279
  return __version__
289
280
 
290
281
  def _add_global_filters(self):
291
- self.template_filters = TemplateFilters(self.get_app, self.sm)
282
+ self.template_filters = TemplateFilters(current_app, self.sm)
292
283
 
293
284
  def _add_global_static(self):
294
285
  bp = Blueprint(
@@ -299,7 +290,7 @@ class AirflowAppBuilder:
299
290
  static_folder=self.static_folder,
300
291
  static_url_path=self.static_url_path,
301
292
  )
302
- self.get_app.register_blueprint(bp)
293
+ current_app.register_blueprint(bp)
303
294
 
304
295
  def _add_admin_views(self):
305
296
  """Register indexview, utilview (back function), babel views and Security views."""
@@ -328,8 +319,6 @@ class AirflowAppBuilder:
328
319
  log.error(LOGMSG_ERR_FAB_ADDON_PROCESS, addon, e)
329
320
 
330
321
  def _check_and_init(self, baseview):
331
- if hasattr(baseview, "datamodel"):
332
- baseview.datamodel.session = self.session
333
322
  if callable(baseview):
334
323
  baseview = baseview()
335
324
  return baseview
@@ -409,16 +398,15 @@ class AirflowAppBuilder:
409
398
  appbuilder.add_link("google", href="www.google.com", icon="fa-google-plus")
410
399
  """
411
400
  baseview = self._check_and_init(baseview)
412
- log.info(LOGMSG_INF_FAB_ADD_VIEW, baseview.__class__.__name__, name)
401
+ log.debug(LOGMSG_INF_FAB_ADD_VIEW, baseview.__class__.__name__, name)
413
402
 
414
403
  if not self._view_exists(baseview):
415
404
  baseview.appbuilder = self
416
405
  self.baseviews.append(baseview)
417
406
  self._process_inner_views()
418
- if self.app:
419
- self.register_blueprint(baseview)
420
- self._add_permission(baseview)
421
- self.add_limits(baseview)
407
+ self.register_blueprint(baseview)
408
+ self._add_permission(baseview)
409
+ self.add_limits(baseview)
422
410
  self.add_link(
423
411
  name=name,
424
412
  href=href,
@@ -512,15 +500,14 @@ class AirflowAppBuilder:
512
500
  :param baseview: A BaseView type class instantiated.
513
501
  """
514
502
  baseview = self._check_and_init(baseview)
515
- log.info(LOGMSG_INF_FAB_ADD_VIEW, baseview.__class__.__name__, "")
503
+ log.debug(LOGMSG_INF_FAB_ADD_VIEW, baseview.__class__.__name__, "")
516
504
 
517
505
  if not self._view_exists(baseview):
518
506
  baseview.appbuilder = self
519
507
  self.baseviews.append(baseview)
520
508
  self._process_inner_views()
521
- if self.app:
522
- self.register_blueprint(baseview, endpoint=endpoint, static_folder=static_folder)
523
- self._add_permission(baseview)
509
+ self.register_blueprint(baseview, endpoint=endpoint, static_folder=static_folder)
510
+ self._add_permission(baseview)
524
511
  else:
525
512
  log.warning(LOGMSG_WAR_FAB_VIEW_EXISTS, baseview.__class__.__name__)
526
513
  return baseview
@@ -588,7 +575,7 @@ class AirflowAppBuilder:
588
575
  self._add_permissions_menu(item.name, update_perms=update_perms)
589
576
 
590
577
  def register_blueprint(self, baseview, endpoint=None, static_folder=None):
591
- self.get_app.register_blueprint(
578
+ current_app.register_blueprint(
592
579
  baseview.create_blueprint(self, endpoint=endpoint, static_folder=static_folder)
593
580
  )
594
581
 
@@ -607,7 +594,7 @@ def init_appbuilder(app: Flask, enable_plugins: bool) -> AirflowAppBuilder:
607
594
  """Init `Flask App Builder <https://flask-appbuilder.readthedocs.io/en/latest/>`__."""
608
595
  return AirflowAppBuilder(
609
596
  app=app,
610
- session=settings.Session,
597
+ session=settings.Session(),
611
598
  base_template="airflow/main.html",
612
599
  enable_plugins=enable_plugins,
613
600
  )
@@ -26,7 +26,7 @@ from airflow.providers.fab.www.session import (
26
26
  )
27
27
 
28
28
 
29
- def init_airflow_session_interface(app):
29
+ def init_airflow_session_interface(app, db):
30
30
  """Set airflow session interface."""
31
31
  config = app.config.copy()
32
32
  selected_backend = conf.get("fab", "SESSION_BACKEND")
@@ -47,7 +47,7 @@ def init_airflow_session_interface(app):
47
47
  elif selected_backend == "database":
48
48
  app.session_interface = AirflowDatabaseSessionInterface(
49
49
  app=app,
50
- db=None,
50
+ client=db,
51
51
  permanent=permanent_cookie,
52
52
  # Typically these would be configurable with Flask-Session,
53
53
  # but we will set them explicitly instead as they don't make
@@ -34,7 +34,7 @@ class FakeAppBuilder:
34
34
  """
35
35
 
36
36
  def __init__(self, session: Session | None = None) -> None:
37
- self.get_session = session
37
+ self.session = session
38
38
 
39
39
 
40
40
  class ApplessAirflowSecurityManager(FabAirflowSecurityManagerOverride):
@@ -42,3 +42,8 @@ class ApplessAirflowSecurityManager(FabAirflowSecurityManagerOverride):
42
42
 
43
43
  def __init__(self, session: Session | None = None):
44
44
  self.appbuilder = FakeAppBuilder(session)
45
+ self._session = session
46
+
47
+ @property
48
+ def session(self):
49
+ return self._session
@@ -18,12 +18,12 @@ from __future__ import annotations
18
18
 
19
19
  from collections.abc import Callable
20
20
 
21
- from flask import g
21
+ from flask import current_app, g
22
22
  from flask_limiter import Limiter
23
23
  from flask_limiter.util import get_remote_address
24
24
 
25
25
  from airflow.api_fastapi.app import get_auth_manager
26
- from airflow.providers.fab.www.utils import CustomSQLAInterface, get_method_from_fab_action_map
26
+ from airflow.providers.fab.www.utils import get_method_from_fab_action_map
27
27
  from airflow.utils.log.logging_mixin import LoggingMixin
28
28
 
29
29
  EXISTING_ROLES = {
@@ -50,15 +50,6 @@ class AirflowSecurityManagerV2(LoggingMixin):
50
50
  # Setup Flask-Limiter
51
51
  self.limiter = self.create_limiter()
52
52
 
53
- # Go and fix up the SQLAInterface used from the stock one to our subclass.
54
- # This is needed to support the "hack" where we had to edit
55
- # FieldConverter.conversion_table in place in utils
56
- for attr in dir(self):
57
- if attr.endswith("view"):
58
- view = getattr(self, attr, None)
59
- if view and getattr(view, "datamodel", None):
60
- view.datamodel = CustomSQLAInterface(view.datamodel.obj)
61
-
62
53
  @staticmethod
63
54
  def before_request():
64
55
  """Run hook before request."""
@@ -66,9 +57,8 @@ class AirflowSecurityManagerV2(LoggingMixin):
66
57
  g.user = get_auth_manager().get_user()
67
58
 
68
59
  def create_limiter(self) -> Limiter:
69
- app = self.appbuilder.get_app
70
- limiter = Limiter(key_func=app.config.get("RATELIMIT_KEY_FUNC", get_remote_address))
71
- limiter.init_app(app)
60
+ limiter = Limiter(key_func=current_app.config.get("RATELIMIT_KEY_FUNC", get_remote_address))
61
+ limiter.init_app(current_app)
72
62
  return limiter
73
63
 
74
64
  def has_access(
@@ -16,18 +16,41 @@
16
16
  # under the License.
17
17
  from __future__ import annotations
18
18
 
19
+ import msgspec
19
20
  from flask import request
20
21
  from flask.sessions import SecureCookieSessionInterface
21
- from flask_session.sessions import SqlAlchemySessionInterface
22
+ from flask_babel import LazyString
23
+ from flask_session.sqlalchemy import SqlAlchemySessionInterface
24
+
25
+
26
+ class _LazySafeSerializer:
27
+ def encode(self, session_dict):
28
+ encoder = msgspec.msgpack.Encoder(
29
+ enc_hook=lambda obj: str(obj) if isinstance(obj, LazyString) else obj
30
+ )
31
+
32
+ return encoder.encode(dict(session_dict))
33
+
34
+ def decode(self, data):
35
+ decoder = msgspec.msgpack.Decoder()
36
+
37
+ return decoder.decode(data)
38
+
39
+ def _default(self, obj):
40
+ if isinstance(obj, LazyString):
41
+ return str(obj)
42
+ raise TypeError(f"Unsupported type: {type(obj)}")
22
43
 
23
44
 
24
45
  class SessionExemptMixin:
25
46
  """Exempt certain blueprints/paths from autogenerated sessions."""
26
47
 
48
+ def __init__(self, *args, **kwargs):
49
+ super().__init__(*args, **kwargs)
50
+ self.serializer = _LazySafeSerializer()
51
+
27
52
  def save_session(self, *args, **kwargs):
28
53
  """Prevent creating session from REST API and health requests."""
29
- if request.blueprint == "/api/v1":
30
- return None
31
54
  if request.path == "/health":
32
55
  return None
33
56
  return super().save_session(*args, **kwargs)
@@ -18,15 +18,7 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import logging
21
- from typing import TYPE_CHECKING, Any
22
-
23
- from flask_appbuilder.models.filters import BaseFilter
24
- from flask_appbuilder.models.sqla import filters as fab_sqlafilters
25
- from flask_appbuilder.models.sqla.filters import get_field_setup_query, set_value_to_type
26
- from flask_appbuilder.models.sqla.interface import SQLAInterface
27
- from flask_babel import lazy_gettext
28
- from sqlalchemy import types
29
- from sqlalchemy.ext.associationproxy import AssociationProxy
21
+ from typing import TYPE_CHECKING
30
22
 
31
23
  from airflow.api_fastapi.app import get_auth_manager
32
24
  from airflow.configuration import conf
@@ -37,11 +29,8 @@ from airflow.providers.fab.www.security.permissions import (
37
29
  ACTION_CAN_EDIT,
38
30
  ACTION_CAN_READ,
39
31
  )
40
- from airflow.utils import timezone
41
32
 
42
33
  if TYPE_CHECKING:
43
- from sqlalchemy.orm.session import Session
44
-
45
34
  try:
46
35
  from airflow.api_fastapi.auth.managers.base_auth_manager import ExtendedResourceMethod
47
36
  except ImportError:
@@ -95,199 +84,3 @@ def get_method_from_fab_action_map():
95
84
  return {
96
85
  **{v: k for k, v in _MAP_METHOD_NAME_TO_FAB_ACTION_NAME.items()},
97
86
  }
98
-
99
-
100
- class UtcAwareFilterMixin:
101
- """Mixin for filter for UTC time."""
102
-
103
- def apply(self, query, value):
104
- """Apply the filter."""
105
- if isinstance(value, str) and not value.strip():
106
- value = None
107
- else:
108
- value = timezone.parse(value, timezone=timezone.utc)
109
-
110
- return super().apply(query, value)
111
-
112
-
113
- class FilterIsNull(BaseFilter):
114
- """Is null filter."""
115
-
116
- name = lazy_gettext("Is Null")
117
- arg_name = "emp"
118
-
119
- def apply(self, query, value):
120
- query, field = get_field_setup_query(query, self.model, self.column_name)
121
- value = set_value_to_type(self.datamodel, self.column_name, None)
122
- return query.filter(field == value)
123
-
124
-
125
- class FilterIsNotNull(BaseFilter):
126
- """Is not null filter."""
127
-
128
- name = lazy_gettext("Is not Null")
129
- arg_name = "nemp"
130
-
131
- def apply(self, query, value):
132
- query, field = get_field_setup_query(query, self.model, self.column_name)
133
- value = set_value_to_type(self.datamodel, self.column_name, None)
134
- return query.filter(field != value)
135
-
136
-
137
- class FilterGreaterOrEqual(BaseFilter):
138
- """Greater than or Equal filter."""
139
-
140
- name = lazy_gettext("Greater than or Equal")
141
- arg_name = "gte"
142
-
143
- def apply(self, query, value):
144
- query, field = get_field_setup_query(query, self.model, self.column_name)
145
- value = set_value_to_type(self.datamodel, self.column_name, value)
146
-
147
- if value is None:
148
- return query
149
-
150
- return query.filter(field >= value)
151
-
152
-
153
- class FilterSmallerOrEqual(BaseFilter):
154
- """Smaller than or Equal filter."""
155
-
156
- name = lazy_gettext("Smaller than or Equal")
157
- arg_name = "lte"
158
-
159
- def apply(self, query, value):
160
- query, field = get_field_setup_query(query, self.model, self.column_name)
161
- value = set_value_to_type(self.datamodel, self.column_name, value)
162
-
163
- if value is None:
164
- return query
165
-
166
- return query.filter(field <= value)
167
-
168
-
169
- class UtcAwareFilterSmallerOrEqual(UtcAwareFilterMixin, FilterSmallerOrEqual):
170
- """Smaller than or Equal filter for UTC time."""
171
-
172
-
173
- class UtcAwareFilterGreaterOrEqual(UtcAwareFilterMixin, FilterGreaterOrEqual):
174
- """Greater than or Equal filter for UTC time."""
175
-
176
-
177
- class UtcAwareFilterEqual(UtcAwareFilterMixin, fab_sqlafilters.FilterEqual):
178
- """Equality filter for UTC time."""
179
-
180
-
181
- class UtcAwareFilterGreater(UtcAwareFilterMixin, fab_sqlafilters.FilterGreater):
182
- """Greater Than filter for UTC time."""
183
-
184
-
185
- class UtcAwareFilterSmaller(UtcAwareFilterMixin, fab_sqlafilters.FilterSmaller):
186
- """Smaller Than filter for UTC time."""
187
-
188
-
189
- class UtcAwareFilterNotEqual(UtcAwareFilterMixin, fab_sqlafilters.FilterNotEqual):
190
- """Not Equal To filter for UTC time."""
191
-
192
-
193
- class AirflowFilterConverter(fab_sqlafilters.SQLAFilterConverter):
194
- """Retrieve conversion tables for Airflow-specific filters."""
195
-
196
- conversion_table = (
197
- (
198
- "is_utcdatetime",
199
- [
200
- UtcAwareFilterEqual,
201
- UtcAwareFilterGreater,
202
- UtcAwareFilterSmaller,
203
- UtcAwareFilterNotEqual,
204
- UtcAwareFilterSmallerOrEqual,
205
- UtcAwareFilterGreaterOrEqual,
206
- ],
207
- ),
208
- # FAB will try to create filters for extendedjson fields even though we
209
- # exclude them from all UI, so we add this here to make it ignore them.
210
- ("is_extendedjson", []),
211
- ("is_json", []),
212
- *fab_sqlafilters.SQLAFilterConverter.conversion_table,
213
- )
214
-
215
- def __init__(self, datamodel):
216
- super().__init__(datamodel)
217
-
218
- for _, filters in self.conversion_table:
219
- if FilterIsNull not in filters:
220
- filters.append(FilterIsNull)
221
- if FilterIsNotNull not in filters:
222
- filters.append(FilterIsNotNull)
223
-
224
-
225
- class CustomSQLAInterface(SQLAInterface):
226
- """
227
- FAB does not know how to handle columns with leading underscores because they are not supported by WTForm.
228
-
229
- This hack will remove the leading '_' from the key to lookup the column names.
230
- """
231
-
232
- def __init__(self, obj, session: Session | None = None):
233
- super().__init__(obj, session=session)
234
-
235
- def clean_column_names():
236
- if self.list_properties:
237
- self.list_properties = {k.lstrip("_"): v for k, v in self.list_properties.items()}
238
- if self.list_columns:
239
- self.list_columns = {k.lstrip("_"): v for k, v in self.list_columns.items()}
240
-
241
- clean_column_names()
242
- # Support for AssociationProxy in search and list columns
243
- for obj_attr, desc in self.obj.__mapper__.all_orm_descriptors.items():
244
- if isinstance(desc, AssociationProxy):
245
- proxy_instance = getattr(self.obj, obj_attr)
246
- if hasattr(proxy_instance.remote_attr.prop, "columns"):
247
- self.list_columns[obj_attr] = proxy_instance.remote_attr.prop.columns[0]
248
- self.list_properties[obj_attr] = proxy_instance.remote_attr.prop
249
-
250
- def is_utcdatetime(self, col_name):
251
- """Check if the datetime is a UTC one."""
252
- from airflow.utils.sqlalchemy import UtcDateTime
253
-
254
- if col_name in self.list_columns:
255
- obj = self.list_columns[col_name].type
256
- return (
257
- isinstance(obj, UtcDateTime)
258
- or isinstance(obj, types.TypeDecorator)
259
- and isinstance(obj.impl, UtcDateTime)
260
- )
261
- return False
262
-
263
- def is_extendedjson(self, col_name):
264
- """Check if it is a special extended JSON type."""
265
- from airflow.utils.sqlalchemy import ExtendedJSON
266
-
267
- if col_name in self.list_columns:
268
- obj = self.list_columns[col_name].type
269
- return (
270
- isinstance(obj, ExtendedJSON)
271
- or isinstance(obj, types.TypeDecorator)
272
- and isinstance(obj.impl, ExtendedJSON)
273
- )
274
- return False
275
-
276
- def is_json(self, col_name):
277
- """Check if it is a JSON type."""
278
- from sqlalchemy import JSON
279
-
280
- if col_name in self.list_columns:
281
- obj = self.list_columns[col_name].type
282
- return (
283
- isinstance(obj, JSON) or isinstance(obj, types.TypeDecorator) and isinstance(obj.impl, JSON)
284
- )
285
- return False
286
-
287
- def get_col_default(self, col_name: str) -> Any:
288
- if col_name not in self.list_columns:
289
- # Handle AssociationProxy etc, or anything that isn't a "real" column
290
- return None
291
- return super().get_col_default(col_name)
292
-
293
- filter_converter_class = AirflowFilterConverter