airflow-ldap-auth-manager 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,733 @@
1
+ """
2
+ LDAPAuthManager for Apache Airflow 3.1+
3
+
4
+ Key notes:
5
+ * Implements all abstract methods from BaseAuthManager, including
6
+ `filter_authorized_menu_items`, `is_authorized_asset[_alias]`, `is_authorized_backfill`,
7
+ `is_authorized_pool`, `is_authorized_variable`, and `is_authorized_custom_view`.
8
+ * LDAP authentication via ldap3.
9
+ * JWT cookie handoff per Airflow 3 spec (`_token`, not httponly, secure if https).
10
+ * Group→role mapping for admin/editor/viewer.
11
+ """
12
+ import json
13
+ import logging
14
+ import re
15
+ import ssl
16
+ from collections.abc import Iterable
17
+ from dataclasses import dataclass, field
18
+ from enum import IntEnum
19
+ from pathlib import Path
20
+ from typing import Any, Optional, TypedDict, override
21
+
22
+ from fastapi import APIRouter, FastAPI, Request
23
+ from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
24
+ from fastapi.staticfiles import StaticFiles
25
+ from jinja2 import Environment, FileSystemLoader
26
+ from ldap3 import (ALL, ALL_ATTRIBUTES, AUTO_BIND_NO_TLS,
27
+ AUTO_BIND_TLS_BEFORE_BIND, ROUND_ROBIN, SUBTREE, Connection,
28
+ Server, ServerPool, Tls)
29
+
30
+ from airflow.api_fastapi.auth.managers.base_auth_manager import (
31
+ COOKIE_NAME_JWT_TOKEN, BaseAuthManager, ResourceMethod)
32
+ from airflow.api_fastapi.auth.managers.models import resource_details as rd
33
+ from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
34
+ from airflow.configuration import conf
35
+ from airflow.sdk import Variable
36
+
37
# Logger shared by every class in this module (name lives under the "airflow." namespace).
log: logging.Logger = logging.getLogger("airflow.auth.ldap")
38
+
39
+
40
+ class AuthenticatedUserData(TypedDict, total=False):
41
+ """LDAP authentication payload produced by :class:`LdapClient`."""
42
+
43
+ dn: str
44
+ attrs: dict[str, Any]
45
+ username: str | None
46
+ email: str | None
47
+ groups: list[str]
48
+
49
+
50
def _get_sensitive(section: str, key: str) -> str | None:
    """Resolve a configuration value that may be a secret.

    Lookup order:

    1. If ``[section] <key>_secret`` is set, it names an Airflow Variable (which may be
       served by a secrets backend); that Variable's value wins when it is non-empty.
    2. Otherwise the plaintext ``[section] <key>`` from airflow.cfg is returned
       (``None`` when unset).
    """
    variable_name = conf.get(section, f"{key}_secret", fallback=None)
    secret_value = Variable.get(variable_name, default=None) if variable_name else None
    if secret_value:
        return secret_value
    return conf.get(section, key, fallback=None)
60
+
61
+
62
+ def _listify_bases(val: str | None) -> list[str]:
63
+ """Parse DN bases supplied either as JSON or separated by semicolons/newlines."""
64
+ if not val:
65
+ return []
66
+ s = val.strip()
67
+ # If someone gave us JSON, use it
68
+ if s.startswith("["):
69
+ try:
70
+ arr = json.loads(s)
71
+ return [x.strip() for x in arr if isinstance(x, str) and x.strip()]
72
+ except Exception:
73
+ pass
74
+ # Otherwise split on semicolons or newlines (never commas!)
75
+ parts = re.split(r"[;\n]+", s)
76
+ return [p.strip() for p in parts if p.strip()]
77
+
78
+
79
+ # -----------------------------
80
+ # User model
81
+ # -----------------------------
82
@dataclass
class LdapUser(BaseUser):
    """Airflow user representation enriched with LDAP metadata.

    ``user_id`` is the identifier stored in the JWT ``sub`` claim (the LDAP entry DN when
    created by this module's login endpoint); ``groups`` drives role mapping.
    """

    user_id: str
    username: str | None = None
    email: str | None = None
    groups: list[str] = field(default_factory=list)

    def get_id(self) -> str:
        """Return the stable user identifier (``user_id``).

        ``BaseUser`` declares this as part of the user interface; without an override
        it would not return the identifier.
        """
        return self.user_id

    def get_name(self) -> str:
        """Return a human-readable name: ``username`` when known, else ``user_id``."""
        return self.username or self.user_id
90
+
91
+
92
class Role(IntEnum):
    """Ordered privilege ladder: a larger value grants strictly more access."""

    NONE = 0  # no access at all (deny-by-default level)
    VIEWER = 1  # read-only access
    EDITOR = 2  # reads plus DAG-run writes
    ADMIN = 3  # unrestricted access
99
+
100
+
101
+ # -----------------------------
102
+ # Helper: LDAP access
103
+ # -----------------------------
104
class LdapClient:
    """Wrapper around :mod:`ldap3` used to authenticate users and look up their groups.

    All settings come from the ``[ldap_auth_manager]`` section of airflow.cfg and are
    read once, in ``__init__``.
    """

    def __init__(self):
        """Read configuration and build the LDAP server pool.

        Raises:
            ValueError: if ``server_uri`` contains no URI or ``user_search_base`` yields no DN.
        """
        # Several servers may be listed comma-separated; blank entries (e.g. from a
        # trailing comma) are skipped instead of becoming an empty server URI.
        uris = [uri.strip() for uri in conf.get("ldap_auth_manager", "server_uri").split(",") if uri.strip()]
        if not uris:
            raise ValueError("ldap_auth_manager.server_uri must contain at least one LDAP URI")

        # Service-account credentials may come from a Variable/secret backend or airflow.cfg.
        self._bind_dn = _get_sensitive("ldap_auth_manager", "bind_dn")
        self._bind_pw = _get_sensitive("ldap_auth_manager", "bind_password")

        # One or many search bases: a JSON list or ';'/newline separated text.
        bases_raw = conf.get("ldap_auth_manager", "user_search_base", fallback="")
        self._user_bases = _listify_bases(bases_raw)
        if not self._user_bases:
            raise ValueError("ldap_auth_manager.user_search_base must be set to at least one DN")

        self._user_filter_tpl = conf.get(
            "ldap_auth_manager",
            "user_search_filter",
            fallback="(|(uid={username})(sAMAccountName={username})(mail={username}))",
        )
        self._group_base = conf.get("ldap_auth_manager", "group_search_base", fallback=None)
        self._group_member_attr = conf.get("ldap_auth_manager", "group_member_attr", fallback="member")
        self._username_attr = conf.get("ldap_auth_manager", "username_attr", fallback="uid")
        self._email_attr = conf.get("ldap_auth_manager", "email_attr", fallback="mail")
        self._start_tls = conf.getboolean("ldap_auth_manager", "start_tls", fallback=False)
        self._verify_ssl = conf.getboolean("ldap_auth_manager", "verify_ssl", fallback=True)
        self._debug_logging = conf.getboolean("ldap_auth_manager", "debug_logging", fallback=False)

        servers = []
        for uri in uris:
            use_ssl = uri.lower().startswith("ldaps://")
            tls = None
            if use_ssl or self._start_tls:
                # ``Tls`` handles certificate verification; disable only when explicitly requested.
                tls = Tls(
                    validate=ssl.CERT_REQUIRED if self._verify_ssl else ssl.CERT_NONE,
                    version=ssl.PROTOCOL_TLS,
                )
            servers.append(Server(uri, use_ssl=use_ssl, get_info=ALL, tls=tls))

        self._servers = ServerPool(servers, pool_strategy=ROUND_ROBIN)

    def _auto_bind(self):
        """Return the ldap3 ``auto_bind`` mode used for EVERY bind made by this client.

        With ``start_tls`` enabled the connection is upgraded via StartTLS before the
        bind, so neither the service password nor a user's password is sent over
        plain LDAP.
        """
        return AUTO_BIND_TLS_BEFORE_BIND if self._start_tls else AUTO_BIND_NO_TLS

    def _service_conn(self) -> Connection:
        """Return a bound service connection used for privileged LDAP queries.

        ``bind_dn``/``bind_password`` are passed as ``None`` when unset.

        Raises:
            ValueError: if ``start_tls`` is combined with ``ldaps://`` servers.
        """
        if self._start_tls and any(s.ssl for s in self._servers.servers):
            raise ValueError("start_tls=true requires ldap:// (plain) servers, not ldaps://")

        conn = Connection(
            self._servers,  # a ServerPool; ldap3 picks a server per the pool strategy
            user=self._bind_dn or None,
            password=self._bind_pw or None,
            auto_bind=self._auto_bind(),
        )
        if self._debug_logging:
            self._log_binding(conn)
        return conn

    def _log_binding(self, conn: Connection) -> None:
        """Log which pool server ``conn`` is bound to and, when supported, the bound identity."""
        srv = conn.server  # the ldap3.Server chosen from the pool
        use_ssl = getattr(srv, "ssl", None)
        if use_ssl is None:
            use_ssl = getattr(srv, "use_ssl", False)

        scheme = "ldaps" if use_ssl else "ldap"
        port = srv.port or (636 if use_ssl else 389)
        pool_strategy = getattr(self._servers, "strategy", None) or getattr(self._servers, "pool_strategy", "n/a")

        log.info(
            f"LDAP bound to {scheme}://{srv.host}:{port} "
            f"(pool_strategy={pool_strategy}, start_tls={self._start_tls})"
        )

        # Optional: show the authenticated identity (not every server supports WhoAmI).
        try:
            who = conn.extend.standard.who_am_i()
            log.info(f"LDAP whoami: {who}")
        except Exception as e:
            log.warning(f"LDAP whoami not available: {e!r}")

    def _password_ok(self, user_dn: str, password: str) -> bool:
        """Return ``True`` when binding as ``user_dn`` with ``password`` succeeds.

        Any failure (wrong password, locked account, unreachable server, ...) is treated
        as "not authenticated" rather than propagated.
        """
        try:
            with Connection(self._servers, user=user_dn, password=password, auto_bind=self._auto_bind()):
                pass
        except Exception as exc:
            if self._debug_logging:
                log.info(f"LDAP user bind failed for dn={user_dn!r}: {type(exc).__name__}")
            return False
        return True

    @staticmethod
    def _first_value(value: Any, default: Any) -> Any:
        """Return the first item of a list ``value`` (``default`` if the list is empty),
        or a non-list ``value`` itself (``default`` if it is falsy)."""
        if isinstance(value, list):
            return value[0] if value else default
        return value or default

    def _lookup_groups(self, conn: Connection, user_dn: str, *names: Any) -> list[str]:
        """Return the ``cn`` values of groups under ``group_search_base`` that list the user.

        Some directories store the member's DN in the member attribute, others the login
        name, so the filter ORs one clause per distinct identifier (the DN plus ``names``).
        Every value is escaped so it cannot change the filter's structure.
        """
        from ldap3.utils.conv import escape_filter_chars

        member_attr = self._group_member_attr
        identifiers = [ident for ident in dict.fromkeys((user_dn, *names)) if ident]
        clauses = "".join(f"({member_attr}={escape_filter_chars(ident)})" for ident in identifiers)
        conn.search(
            search_base=self._group_base,
            search_filter=f"(|{clauses})",
            search_scope=SUBTREE,
            attributes=["cn"],
        )
        # A group without a readable cn is skipped; a multi-valued cn yields every value.
        return [str(cn) for group in conn.entries for cn in group.entry_attributes_as_dict.get("cn", [])]

    def authenticate(self, username: str, password: str) -> Optional[AuthenticatedUserData]:
        """Authenticate a user by binding with their DN and password.

        Steps: find the user's entry with the service account (first configured base
        with a match wins), verify ``password`` by binding as that entry, then - when
        ``group_search_base`` is set - collect the user's group names.

        Returns:
            The user's DN, raw attributes, normalised username, e-mail and groups, or
            ``None`` if the password is empty, no entry matches, or the bind fails.
        """
        from ldap3.utils.conv import escape_filter_chars

        # An empty password makes a simple bind "unauthenticated" (RFC 4513), which many
        # servers report as success without checking anything; never accept it.
        if not password:
            return None

        with self._service_conn() as svc:
            # ``username`` is user input: escape it so characters such as ``*``, ``(`` and
            # ``)`` cannot inject extra clauses into the search filter.
            flt = self._user_filter_tpl.format(username=escape_filter_chars(username))
            entry = None
            matched_base = None
            for base in self._user_bases:
                if svc.search(
                    search_base=base,
                    search_filter=flt,
                    search_scope=SUBTREE,
                    attributes=ALL_ATTRIBUTES,
                ):
                    entry = svc.entries[0]
                    matched_base = base
                    break

            if entry is None:
                return None

            user_dn = entry.entry_dn
            attrs = entry.entry_attributes_as_dict

            if self._debug_logging:
                log.info(f"LDAP user search matched base={matched_base!r} dn={user_dn!r}")

            if not self._password_ok(user_dn, password):
                return None

            norm_username = self._first_value(attrs.get(self._username_attr), username)
            email = self._first_value(attrs.get(self._email_attr), None)

            groups: list[str] = []
            if self._group_base:
                # Reuse the already-bound service connection instead of opening a second one.
                groups = self._lookup_groups(svc, user_dn, username, norm_username)

        return {
            "dn": user_dn,
            "attrs": attrs,
            "username": norm_username,
            "email": email,
            "groups": groups,
        }
252
+
253
+
254
+ # -----------------------------
255
+ # AuthZ policy helpers
256
+ # -----------------------------
257
class Policy:
    """
    Resolve LDAP group memberships to one effective :class:`Role`.

    Group names come from the comma-separated ``admin_groups``, ``editor_groups`` and
    ``viewer_groups`` options of ``[ldap_auth_manager]`` and are matched
    case-insensitively. The most privileged matching role wins; no match yields
    ``Role.NONE`` (deny by default).
    """

    def __init__(self):
        # Lowercased once here so each lookup only lowercases the user's own groups.
        self.admin_groups = self._load_group_config("admin_groups")
        self.editor_groups = self._load_group_config("editor_groups")
        self.viewer_groups = self._load_group_config("viewer_groups")
        # Role used when no configured group matches: least privilege.
        self._default_role = Role.NONE

    def _load_group_config(self, option: str) -> set[str]:
        """Return the comma-separated group list stored in ``option``, lowercased."""
        configured = _csv_to_set(conf.get("ldap_auth_manager", option, fallback=""))
        return set(map(str.lower, configured))

    def role_for(self, groups: Iterable[str]) -> Role:
        """Return the highest role granted by ``groups`` (``None``/empty is allowed)."""
        member_of = {name.lower() for name in (groups or [])}
        # Checked from most to least privileged; the first overlap decides.
        ladder = (
            (Role.ADMIN, self.admin_groups),
            (Role.EDITOR, self.editor_groups),
            (Role.VIEWER, self.viewer_groups),
        )
        for role, configured in ladder:
            if configured and not member_of.isdisjoint(configured):
                return role
        return self._default_role

    def at_least(self, groups: Iterable[str], min_role: Role) -> bool:
        """Return ``True`` when ``groups`` resolve to ``min_role`` or higher."""
        effective = self.role_for(groups)
        return effective >= min_role

    # Named shortcuts for the three thresholds.
    def is_admin(self, groups: Iterable[str]) -> bool:
        """Return ``True`` when ``groups`` grant administrator privileges."""
        return self.at_least(groups, Role.ADMIN)

    def is_editor(self, groups: Iterable[str]) -> bool:
        """Return ``True`` when ``groups`` grant at least editor privileges."""
        return self.at_least(groups, Role.EDITOR)

    def is_viewer(self, groups: Iterable[str]) -> bool:
        """Return ``True`` when ``groups`` grant at least viewer privileges."""
        return self.at_least(groups, Role.VIEWER)
304
+
305
+
306
def _csv_to_set(val: str) -> set[str]:
    """Split ``val`` on commas and return the set of non-blank, whitespace-trimmed pieces."""
    tokens = (piece.strip() for piece in val.split(","))
    return {token for token in tokens if token}
309
+
310
+
311
+ # -----------------------------
312
+ # LDAP Auth Manager
313
+ # -----------------------------
314
class LdapAuthManager(BaseAuthManager[LdapUser]):
    """LDAP backed implementation of Airflow's :class:`BaseAuthManager`.

    * Authentication: :class:`LdapClient` verifies credentials; ``POST /token`` on the
      bundled FastAPI app mints a JWT embedding the serialized user (with LDAP groups).
    * Authorization: :class:`Policy` maps groups to a :class:`Role`. ``GET`` needs
      Viewer+, writes targeting DAG runs need Editor+, every other write needs Admin.
    """

    def __init__(self, context=None):
        """Create the manager with configured LDAP client and authorization policy."""
        super().__init__(context=context)
        self._ldap = LdapClient()
        self._policy = Policy()
        self._debug_logging = conf.getboolean("ldap_auth_manager", "debug_logging", fallback=False)

    # --- JWT settings shared by token minting and validation ---
    @staticmethod
    def _jwt_settings() -> tuple[str, str | None, str | None]:
        """Return ``(algorithm, audience, issuer)`` from ``[api_auth]``.

        When ``jwt_audience`` lists several comma-separated values, only the first is used.
        """
        alg = conf.get("api_auth", "jwt_algorithm", fallback="HS512") or "HS512"
        audience = conf.get("api_auth", "jwt_audience", fallback="urn:airflow.apache.org:api")
        aud = audience.split(",")[0].strip() if audience and "," in audience else audience
        issuer = conf.get("api_auth", "jwt_issuer", fallback=None)
        return alg, aud, issuer

    # --- Authentication surface ---
    @override
    def get_url_login(self, **kwargs) -> str:
        """Return the login URL including a URL-encoded ``next`` redirect parameter.

        ``/`` is kept literal so the default result stays ``/auth/login?next=/``; other
        reserved characters (``&``, ``?``, ``#`` ...) are encoded so they cannot break
        out of the ``next`` parameter.
        """
        from urllib.parse import quote

        next_url = kwargs.get("next") or "/"
        return f"/auth/login?next={quote(str(next_url), safe='/')}"

    @override
    def get_url_logout(self) -> Optional[str]:
        """Return the configured logout redirect target (default ``/``)."""
        return conf.get("ldap_auth_manager", "logout_redirect", fallback="/")

    @override
    def serialize_user(self, user: LdapUser) -> dict:
        """Serialize ``LdapUser`` instances for storage inside JWT payloads."""
        return {
            "user_id": user.user_id,
            "username": user.username,
            "email": user.email,
            "groups": user.groups,
        }

    @override
    def deserialize_user(self, data: dict) -> LdapUser:
        """Recreate an ``LdapUser`` from serialized data.

        ``user_id`` falls back to the JWT ``sub`` claim when absent.
        """
        user_id = data.get("user_id") or data.get("sub")
        return LdapUser(
            user_id=str(user_id),
            username=data.get("username"),
            email=data.get("email"),
            groups=list(data.get("groups") or []),
        )

    @override
    async def get_user_from_token(self, token: str) -> LdapUser | None:
        """Decode ``token`` and return its user only if policy still permits.

        Returns ``None`` when ``jwt_secret`` is unset, the token fails validation
        (signature, expiry, audience, issuer, required ``sub``/``exp``/``iat`` claims),
        or the user's groups no longer map to any role.
        """
        import jwt

        secret = conf.get("api_auth", "jwt_secret")
        if not secret:
            return None

        alg, aud, issuer = self._jwt_settings()
        try:
            claims = jwt.decode(
                token,
                secret,
                algorithms=[alg],
                audience=aud if aud else None,
                issuer=issuer if issuer else None,
                options={"require": ["sub", "exp", "iat"], "verify_aud": bool(aud)},
            )
        except jwt.InvalidTokenError:
            # Base class of PyJWT's expired-signature, audience and signature errors.
            return None

        # Tokens minted by /auth/token nest the user under "user"; otherwise fall back to
        # the top-level claims (tokens carrying the serialized fields directly).
        data = claims.get("user") or claims
        if not isinstance(data, dict):
            return None
        user = self.deserialize_user({"sub": claims.get("sub"), **data})

        # Enforce the group whitelist again at request time, not only at login.
        if self._policy.role_for(user.groups) is Role.NONE:
            return None

        return user

    # --- Menu filtering ---
    @override
    def filter_authorized_menu_items(self, menu_items, *, user: LdapUser):
        """Return the menu unchanged; permissions are enforced per endpoint."""
        # Everyone can see the whole menu; the endpoints themselves enforce write restrictions.
        return list(menu_items or [])

    # --- Authorization surface ---
    def _norm_method(self, method) -> str:
        """Coerce ``method`` (enum or string) into an uppercase HTTP verb."""
        name = getattr(method, "name", None)
        if name:
            return str(name).upper()
        value = getattr(method, "value", None)
        if value is not None:
            return str(value).upper()
        return str(method).upper()

    def _is_dag_run_scoped(
        self, write_scope: str | None, access_entity: "rd.DagAccessEntity | None"
    ) -> bool:
        """Return ``True`` when the request targets DAG run level operations.

        True for ``write_scope == "dag_run"`` or an ``access_entity`` whose name is
        ``DAG_RUN``, ``DAGRUN`` or ``RUN`` (case-insensitive).
        """
        if write_scope == "dag_run":
            return True

        if access_entity is None:
            return False

        try:
            ae_name = getattr(access_entity, "name", None)
            if ae_name is None:
                ae_name = str(access_entity)
            return str(ae_name).upper() in {"DAG_RUN", "DAGRUN", "RUN"}
        except Exception:
            return False

    def _is_authorized(
        self,
        *,
        method: ResourceMethod | str,
        user: LdapUser,
        access_entity: "rd.DagAccessEntity | None" = None,
        write_scope: str | None = None,
    ) -> bool:
        """
        Central authorization rule-set.

        Rules:
          - no matching group (Role.NONE) -> deny
          - GET -> Viewer+
          - Non-GET:
              - If write_scope == "dag_run" (or DagAccessEntity.DAG_RUN/RUN), Editor+
              - Else Admin only
        """
        # Resolve the role once; every rule below compares against it.
        role = self._policy.role_for(user.groups)
        if role == Role.NONE:
            return False  # deny outright

        if self._norm_method(method) == "GET":
            return role >= Role.VIEWER

        # Non-GET (write-ish): DAG-run scoped writes are open to editors.
        if self._is_dag_run_scoped(write_scope, access_entity):
            return role >= Role.EDITOR

        # Everything else requires full admin.
        return role >= Role.ADMIN

    @staticmethod
    def _single_request(method, **extra) -> list[dict[str, Any]]:
        """Wrap legacy ``method=``/``details=`` batch keywords into a one-item request list.

        Raises:
            TypeError: if neither ``requests`` nor ``method`` was supplied (an empty
                batch would otherwise be vacuously authorized).
        """
        if method is None:
            raise TypeError("either 'requests' or 'method' must be provided")
        return [{"method": method, **extra}]

    @override
    def is_authorized_configuration(
        self, *, method: ResourceMethod, user: LdapUser, details: rd.ConfigurationDetails | None = None
    ) -> bool:
        """Allow configuration access for admins, read-only for others."""
        # Configuration changes are admin-only; reads are fine for all.
        return self._is_authorized(method=method, user=user)

    @override
    def is_authorized_connection(
        self, *, method: ResourceMethod, user: LdapUser, details: rd.ConnectionDetails | None = None
    ) -> bool:
        """Authorize access to individual connections using the global policy."""
        return self._is_authorized(method=method, user=user)

    @override
    def batch_is_authorized_connection(
        self,
        requests: Iterable[dict[str, Any]] | None = None,
        *,
        user: LdapUser,
        method: ResourceMethod | None = None,
        details: rd.ConnectionDetails | None = None,
    ) -> bool:
        """Authorize a batch of connection requests; every request must pass.

        Follows ``BaseAuthManager``'s ``(requests, *, user)`` calling convention; the
        legacy ``method=``/``details=`` keywords still describe a single request.
        """
        if requests is None:
            requests = self._single_request(method, details=details)
        return all(
            self.is_authorized_connection(method=req["method"], user=user, details=req.get("details"))
            for req in requests
        )

    @override
    def batch_is_authorized_variable(
        self,
        requests: Iterable[dict[str, Any]] | None = None,
        *,
        user: LdapUser,
        method: ResourceMethod | None = None,
        details: rd.VariableDetails | None = None,
    ) -> bool:
        """Authorize a batch of variable requests; every request must pass.

        Follows ``BaseAuthManager``'s ``(requests, *, user)`` calling convention; the
        legacy ``method=``/``details=`` keywords still describe a single request.
        """
        if requests is None:
            requests = self._single_request(method, details=details)
        return all(
            self.is_authorized_variable(method=req["method"], user=user, details=req.get("details"))
            for req in requests
        )

    @override
    def is_authorized_variable(
        self, *, method: ResourceMethod, user: LdapUser, details: rd.VariableDetails | None = None
    ) -> bool:
        """Authorize individual variable operations via the central policy."""
        return self._is_authorized(method=method, user=user)

    @override
    def batch_is_authorized_pool(
        self,
        requests: Iterable[dict[str, Any]] | None = None,
        *,
        user: LdapUser,
        method: ResourceMethod | None = None,
        details: rd.PoolDetails | None = None,
    ) -> bool:
        """Authorize a batch of pool requests; every request must pass.

        Follows ``BaseAuthManager``'s ``(requests, *, user)`` calling convention; the
        legacy ``method=``/``details=`` keywords still describe a single request.
        """
        if requests is None:
            requests = self._single_request(method, details=details)
        return all(
            self.is_authorized_pool(method=req["method"], user=user, details=req.get("details"))
            for req in requests
        )

    @override
    def is_authorized_pool(
        self, *, method: ResourceMethod, user: LdapUser, details: rd.PoolDetails | None = None
    ) -> bool:
        """Authorize individual pool operations via the shared policy."""
        return self._is_authorized(method=method, user=user)

    @override
    def is_authorized_asset(
        self, *, method: ResourceMethod, user: LdapUser, details: rd.AssetDetails | None = None
    ) -> bool:
        """Authorize asset interactions via the shared policy."""
        return self._is_authorized(method=method, user=user)

    @override
    def is_authorized_asset_alias(
        self, *, method: ResourceMethod, user: LdapUser, details: rd.AssetAliasDetails | None = None
    ) -> bool:
        """Authorize asset alias interactions via the shared policy."""
        return self._is_authorized(method=method, user=user)

    @override
    def is_authorized_backfill(
        self, *, method: ResourceMethod, user: LdapUser, details: rd.BackfillDetails | None = None
    ) -> bool:
        """Require admin access for disruptive backfill operations."""
        # Backfills are disruptive -> admin-only for writes; reads for all
        return self._is_authorized(method=method, user=user, write_scope=None)

    @override
    def is_authorized_dag(
        self,
        *,
        method: ResourceMethod,
        user: LdapUser,
        access_entity: rd.DagAccessEntity | None = None,
        details: rd.DagDetails | None = None,
    ) -> bool:
        """Authorize DAG operations, allowing editors to manage DAG runs."""
        # Let editor write only when the operation targets DAG RUNs
        return self._is_authorized(method=method, user=user, access_entity=access_entity, write_scope=None)

    @override
    def batch_is_authorized_dag(
        self,
        requests: Iterable[dict[str, Any]] | None = None,
        *,
        user: LdapUser,
        method: ResourceMethod | None = None,
        access_entity: rd.DagAccessEntity | None = None,
        details: rd.DagDetails | None = None,
    ) -> bool:
        """Authorize a batch of DAG requests with the same rules as single DAG operations.

        Follows ``BaseAuthManager``'s ``(requests, *, user)`` calling convention; the
        legacy ``method=``/``access_entity=``/``details=`` keywords still describe a
        single request.
        """
        if requests is None:
            requests = self._single_request(method, access_entity=access_entity, details=details)
        return all(
            self.is_authorized_dag(
                method=req["method"],
                user=user,
                access_entity=req.get("access_entity"),
                details=req.get("details"),
            )
            for req in requests
        )

    @override
    def is_authorized_view(self, *, access_view: rd.AccessView, user: LdapUser) -> bool:
        """Permit viewer-level access to read-only views."""
        # Views are read-only by design
        return self._policy.at_least(user.groups, Role.VIEWER)

    @override
    def is_authorized_custom_view(self, *, method: ResourceMethod | str, resource_name: str, user: LdapUser) -> bool:
        """Authorize arbitrary custom views using the default policy."""
        return self._is_authorized(method=method, user=user)

    # -----------------------------
    # FastAPI extension for login/token/logout
    # -----------------------------
    def get_fastapi_app(self) -> Optional[FastAPI]:
        """Build the sub-application serving ``/login``, ``/token``, ``/logout`` and ``/static``.

        The redirect URLs built here assume the app is mounted at ``/auth``.
        """
        from urllib.parse import parse_qs, urlencode, urlparse

        router = APIRouter()

        base_dir = Path(__file__).parent
        template_dir = base_dir / "templates"
        static_dir = base_dir / "static"
        jinja_env = Environment(loader=FileSystemLoader(template_dir), autoescape=True)

        instance_name = conf.get("api", "instance_name", fallback="Airflow")
        login_tip = conf.get("ldap_auth_manager", "login_tip", fallback="")

        jinja_env.globals.update(instance_name=instance_name, login_tip=login_tip)

        def render(name: str, **ctx) -> HTMLResponse:
            """Render ``name`` with the provided context."""
            tpl = jinja_env.get_template(name)
            return HTMLResponse(tpl.render(**ctx))

        def _sanitize_next(raw_next: Optional[str], request: Request) -> str:
            """Return a safe, same-origin path for redirect.

            - unwrap nested next parameters once (e.g. "/?next=/graph")
            - reject targets with backslashes or control characters, which browsers may
              normalise into a protocol-relative "//host" URL
            - accept relative paths starting with a single "/"
            - accept absolute http(s) URLs only for this request's host (path kept)
            - default to '/'
            """
            if not isinstance(raw_next, str):
                return "/"
            target = raw_next.strip() or "/"
            try:
                nested = parse_qs(urlparse(target).query).get("next", [])
                if nested:
                    target = nested[0]

                if "\\" in target or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in target):
                    return "/"

                # allow only same-origin relative paths
                if target.startswith("/") and not target.startswith("//"):
                    return target

                # if absolute, ensure same host
                parsed = urlparse(target)
                base = urlparse(str(request.base_url))
                if parsed.scheme in ("http", "https") and parsed.netloc and parsed.netloc == base.netloc:
                    path = parsed.path or "/"
                    # A path such as "//evil.example" would itself be protocol-relative.
                    return "/" if path.startswith("//") else path
            except ValueError:  # e.g. urlparse rejecting a malformed IPv6 netloc
                pass
            return "/"

        def _login_redirect(next_param: Optional[str], request: Request, error: str) -> RedirectResponse:
            """303-redirect back to the login form with ``next`` and ``error`` URL-encoded."""
            query = urlencode({"next": _sanitize_next(next_param, request), "error": error})
            return RedirectResponse(url=f"/auth/login?{query}", status_code=303)

        async def _read_credentials(request: Request) -> tuple[Optional[str], Optional[str], Optional[str]]:
            """Extract ``(username, password, next)`` from a form or JSON body.

            Values that are not strings (e.g. uploaded files, JSON numbers) become ``None``.
            """
            ct = (request.headers.get("content-type") or "").lower()
            if ct.startswith(("application/x-www-form-urlencoded", "multipart/form-data")):
                source = await request.form()
            else:
                # Try JSON regardless of Content-Type (some clients lie)
                try:
                    source = await request.json()
                except ValueError:  # covers JSON and Unicode decode errors
                    source = None
                if not isinstance(source, dict):
                    return None, None, None

            def _text(key: str) -> Optional[str]:
                value = source.get(key)
                return value if isinstance(value, str) else None

            return _text("username"), _text("password"), _text("next")

        @router.get("/login", response_class=HTMLResponse)
        def login_form(next: str = "/", error: str | None = None):
            """Serve the HTML login form."""
            return render("ldap_login.html", next=next, error=error)

        @router.post("/token")
        async def create_token(request: Request):
            """Authenticate the user and return/issue a JWT token.

            Accepts form-encoded or JSON bodies with ``username``, ``password`` and an
            optional ``next``. Clients sending ``Accept: application/json`` get JSON
            (token or error); browsers get the JWT cookie plus a 303 redirect.
            """
            import asyncio
            from datetime import datetime, timedelta, timezone

            import jwt

            wants_json = "application/json" in (request.headers.get("accept") or "")

            username, password, next_param = await _read_credentials(request)
            if not username or not password:
                return JSONResponse({"detail": "username and password are required"}, status_code=422)

            # ldap3 calls block; run them in a worker thread so the event loop stays free.
            info = await asyncio.to_thread(self._ldap.authenticate, username=username, password=password)
            if not info:
                if wants_json:
                    return JSONResponse({"detail": "Invalid credentials"}, status_code=401)
                return _login_redirect(next_param, request, "Invalid credentials")

            user = LdapUser(
                user_id=info["dn"],
                username=info.get("username"),
                email=info.get("email"),
                groups=info.get("groups", []),
            )

            # Refuse to mint a token for users outside every configured access group.
            role = self._policy.role_for(user.groups)
            if self._debug_logging:
                log.info(f"LDAP login: user={user.username} groups={user.groups} role={role.name}")
            if role is Role.NONE:
                msg = "You are not a member of any Airflow access group"
                if wants_json:
                    return JSONResponse({"detail": msg}, status_code=403)
                return _login_redirect(next_param, request, msg)

            # --- mint JWT ---
            secret = conf.get("api_auth", "jwt_secret")
            if not secret:
                return JSONResponse({"detail": "api_auth.jwt_secret is not set"}, status_code=500)

            alg, aud, issuer = self._jwt_settings()
            kid = conf.get("api_auth", "jwt_kid", fallback=None)
            exp_secs = conf.getint("api_auth", "jwt_expiration_time", fallback=36000)

            now = datetime.now(timezone.utc)
            claims = {
                "sub": user.user_id,
                "iat": int(now.timestamp()),
                "nbf": int(now.timestamp()) - 5,  # backdated 5 s, presumably for clock skew
                "exp": int((now + timedelta(seconds=exp_secs)).timestamp()),
                "aud": aud,
                "user": self.serialize_user(user),
            }
            if issuer:
                claims["iss"] = issuer
            headers = {"kid": kid} if kid else None
            token = jwt.encode(claims, secret, algorithm=alg, headers=headers)

            # --- respond JSON for API callers, cookie+303 for browsers ---
            if wants_json:
                return JSONResponse({
                    "access_token": token,
                    "token_type": "Bearer",
                    "expires_in": exp_secs,
                })

            target = _sanitize_next(next_param, request)
            resp = RedirectResponse(url=target, status_code=303)
            secure = (request.base_url.scheme == "https") or bool(conf.get("api", "ssl_cert", fallback=""))
            # httponly=False is deliberate: Airflow 3 cookie hand-off (see module docstring).
            resp.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure, httponly=False, samesite="lax", path="/")
            return resp

        @router.get("/logout")
        def logout(request: Request, next: str = "/"):
            """Clear the JWT cookie and redirect to the configured URL or a same-origin ``next``."""
            # ``next`` is user-controlled: sanitize it so logout cannot become an open redirect.
            target = conf.get("ldap_auth_manager", "logout_redirect", fallback=None) or _sanitize_next(next, request)
            resp = RedirectResponse(url=target)
            resp.delete_cookie(COOKIE_NAME_JWT_TOKEN, path="/")
            return resp

        app = FastAPI()
        app.mount("/static", StaticFiles(directory=str(static_dir)), name="auth-static")
        app.include_router(router)
        return app