airflow-ldap-auth-manager 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow_ldap_auth_manager/__init__.py +5 -0
- airflow_ldap_auth_manager/ldap_auth_manager.py +733 -0
- airflow_ldap_auth_manager/static/airflow.svg +13 -0
- airflow_ldap_auth_manager/static/style.css +138 -0
- airflow_ldap_auth_manager/templates/ldap_login.html +41 -0
- airflow_ldap_auth_manager-0.1.0.dist-info/METADATA +333 -0
- airflow_ldap_auth_manager-0.1.0.dist-info/RECORD +10 -0
- airflow_ldap_auth_manager-0.1.0.dist-info/WHEEL +5 -0
- airflow_ldap_auth_manager-0.1.0.dist-info/licenses/LICENSE +201 -0
- airflow_ldap_auth_manager-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,733 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LDAPAuthManager for Apache Airflow 3.1+
|
|
3
|
+
|
|
4
|
+
Key notes:
|
|
5
|
+
* Implements all abstract methods from BaseAuthManager, including
|
|
6
|
+
`filter_authorized_menu_items`, `is_authorized_asset[_alias]`, `is_authorized_backfill`,
|
|
7
|
+
`is_authorized_pool`, `is_authorized_variable`, and `is_authorized_custom_view`.
|
|
8
|
+
* LDAP authentication via ldap3.
|
|
9
|
+
* JWT cookie handoff per Airflow 3 spec (`_token`, not httponly, secure if https).
|
|
10
|
+
* Group→role mapping for admin/editor/viewer.
|
|
11
|
+
"""
|
|
12
|
+
import json
|
|
13
|
+
import logging
|
|
14
|
+
import re
|
|
15
|
+
import ssl
|
|
16
|
+
from collections.abc import Iterable
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from enum import IntEnum
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Any, Optional, TypedDict, override
|
|
21
|
+
|
|
22
|
+
from fastapi import APIRouter, FastAPI, Request
|
|
23
|
+
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
|
24
|
+
from fastapi.staticfiles import StaticFiles
|
|
25
|
+
from jinja2 import Environment, FileSystemLoader
|
|
26
|
+
from ldap3 import (ALL, ALL_ATTRIBUTES, AUTO_BIND_NO_TLS,
|
|
27
|
+
AUTO_BIND_TLS_BEFORE_BIND, ROUND_ROBIN, SUBTREE, Connection,
|
|
28
|
+
Server, ServerPool, Tls)
|
|
29
|
+
|
|
30
|
+
from airflow.api_fastapi.auth.managers.base_auth_manager import (
|
|
31
|
+
COOKIE_NAME_JWT_TOKEN, BaseAuthManager, ResourceMethod)
|
|
32
|
+
from airflow.api_fastapi.auth.managers.models import resource_details as rd
|
|
33
|
+
from airflow.api_fastapi.auth.managers.models.base_user import BaseUser
|
|
34
|
+
from airflow.configuration import conf
|
|
35
|
+
from airflow.sdk import Variable
|
|
36
|
+
|
|
37
|
+
log = logging.getLogger("airflow.auth.ldap")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class AuthenticatedUserData(TypedDict, total=False):
|
|
41
|
+
"""LDAP authentication payload produced by :class:`LdapClient`."""
|
|
42
|
+
|
|
43
|
+
dn: str
|
|
44
|
+
attrs: dict[str, Any]
|
|
45
|
+
username: str | None
|
|
46
|
+
email: str | None
|
|
47
|
+
groups: list[str]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _get_sensitive(section: str, key: str) -> str | None:
    """Look up a possibly-secret setting, preferring an Airflow Variable.

    If ``<key>_secret`` is set in ``section``, its value is treated as the
    name of an Airflow Variable and that Variable's value is returned when it
    is non-empty. Otherwise the plain ``<key>`` option from ``section`` is
    returned, or ``None`` when it is not set.
    """
    variable_name = conf.get(section, f"{key}_secret", fallback=None)
    resolved = Variable.get(variable_name, default=None) if variable_name else None
    if resolved:
        return resolved
    # No indirection configured (or the Variable was empty): use the
    # plaintext option from airflow.cfg.
    return conf.get(section, key, fallback=None)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _listify_bases(val: str | None) -> list[str]:
|
|
63
|
+
"""Parse DN bases supplied either as JSON or separated by semicolons/newlines."""
|
|
64
|
+
if not val:
|
|
65
|
+
return []
|
|
66
|
+
s = val.strip()
|
|
67
|
+
# If someone gave us JSON, use it
|
|
68
|
+
if s.startswith("["):
|
|
69
|
+
try:
|
|
70
|
+
arr = json.loads(s)
|
|
71
|
+
return [x.strip() for x in arr if isinstance(x, str) and x.strip()]
|
|
72
|
+
except Exception:
|
|
73
|
+
pass
|
|
74
|
+
# Otherwise split on semicolons or newlines (never commas!)
|
|
75
|
+
parts = re.split(r"[;\n]+", s)
|
|
76
|
+
return [p.strip() for p in parts if p.strip()]
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# -----------------------------
|
|
80
|
+
# User model
|
|
81
|
+
# -----------------------------
|
|
82
|
+
@dataclass
class LdapUser(BaseUser):
    """Airflow user representation enriched with LDAP metadata.

    Besides the data fields, this class provides ``get_id`` and ``get_name``,
    the two accessors Airflow's ``BaseUser`` interface declares. Without them
    callers that go through the interface would get no usable identity.
    """

    # Stable identifier for the user.
    user_id: str
    # Login name, when known.
    username: str | None = None
    # E-mail address, when known.
    email: str | None = None
    # Names of the LDAP groups the user belongs to.
    groups: list[str] = field(default_factory=list)

    def get_id(self) -> str:
        """Return the unique identifier of the user."""
        return self.user_id

    def get_name(self) -> str:
        """Return a display name: the username, or the user id as a fallback."""
        return self.username or self.user_id
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class Role(IntEnum):
    """Ordered privilege levels; a larger value means more privileges.

    Being an ``IntEnum``, members can be compared with ``>=`` to express
    "at least this role".
    """

    # No access at all.
    NONE = 0
    # Lowest level that grants access.
    VIEWER = 1
    # Middle level.
    EDITOR = 2
    # Highest level.
    ADMIN = 3
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# -----------------------------
|
|
102
|
+
# Helper: LDAP access
|
|
103
|
+
# -----------------------------
|
|
104
|
+
class LdapClient:
|
|
105
|
+
"""Wrapper around :mod:`ldap3` interactions for authentication searches."""
|
|
106
|
+
|
|
107
|
+
def __init__(self):
    """Read the ``[ldap_auth_manager]`` configuration and build the server pool.

    Raises:
        ValueError: If ``server_uri`` contains no usable URI or
            ``user_search_base`` does not yield at least one DN.
    """
    section = "ldap_auth_manager"

    # ``server_uri`` is a comma-separated list. Blank entries (for example
    # from a trailing comma) are skipped instead of being handed to ldap3.
    raw_uris = conf.get(section, "server_uri") or ""
    uris = [uri.strip() for uri in raw_uris.split(",") if uri.strip()]
    if not uris:
        raise ValueError("ldap_auth_manager.server_uri must be set to at least one LDAP URI")

    self._bind_dn = _get_sensitive(section, "bind_dn")
    self._bind_pw = _get_sensitive(section, "bind_password")

    # One or many search bases (JSON array, or ';' / newline separated).
    bases_raw = conf.get(section, "user_search_base", fallback="")
    self._user_bases = _listify_bases(bases_raw)
    if not self._user_bases:
        raise ValueError("ldap_auth_manager.user_search_base must be set to at least one DN")

    self._user_filter_tpl = conf.get(
        section,
        "user_search_filter",
        fallback="(|(uid={username})(sAMAccountName={username})(mail={username}))",
    )
    self._group_base = conf.get(section, "group_search_base", fallback=None)
    self._group_member_attr = conf.get(section, "group_member_attr", fallback="member")
    self._username_attr = conf.get(section, "username_attr", fallback="uid")
    self._email_attr = conf.get(section, "email_attr", fallback="mail")
    self._start_tls = conf.getboolean(section, "start_tls", fallback=False)
    self._verify_ssl = conf.getboolean(section, "verify_ssl", fallback=True)
    self._debug_logging = conf.getboolean(section, "debug_logging", fallback=False)

    servers = []
    for uri in uris:
        use_ssl = uri.lower().startswith("ldaps://")
        tls = None
        if use_ssl or self._start_tls:
            # Certificate verification stays on unless ``verify_ssl`` is
            # explicitly turned off in the configuration.
            validate = ssl.CERT_REQUIRED if self._verify_ssl else ssl.CERT_NONE
            tls = Tls(validate=validate, version=ssl.PROTOCOL_TLS)
        servers.append(Server(uri, use_ssl=use_ssl, get_info=ALL, tls=tls))

    self._servers = ServerPool(servers, pool_strategy=ROUND_ROBIN)
|
|
144
|
+
|
|
145
|
+
def _service_conn(self) -> Connection:
    """Open and bind a connection using the configured service account.

    The connection is bound automatically on creation; StartTLS is
    negotiated before the bind when ``start_tls`` is enabled. When debug
    logging is on, the chosen server and the bound identity are logged.

    Raises:
        ValueError: If ``start_tls`` is enabled while any pooled server uses
            ``ldaps://``.
    """
    if self._start_tls and any(server.ssl for server in self._servers.servers):
        raise ValueError("start_tls=true requires ldap:// (plain) servers, not ldaps://")

    if self._start_tls:
        bind_mode = AUTO_BIND_TLS_BEFORE_BIND
    else:
        bind_mode = AUTO_BIND_NO_TLS

    conn = Connection(
        self._servers,  # a ServerPool; ldap3 picks the concrete server
        user=self._bind_dn or None,
        password=self._bind_pw or None,
        auto_bind=bind_mode,
    )

    if not self._debug_logging:
        return conn

    # The bind has succeeded at this point; report which server was used.
    chosen = conn.server
    over_ssl = getattr(chosen, "ssl", None)
    if over_ssl is None:
        over_ssl = getattr(chosen, "use_ssl", False)

    if over_ssl:
        scheme, default_port = "ldaps", 636
    else:
        scheme, default_port = "ldap", 389
    port = chosen.port or default_port
    strategy = getattr(self._servers, "strategy", None) or getattr(self._servers, "pool_strategy", "n/a")

    log.info(
        f"LDAP bound to {scheme}://{chosen.host}:{port} "
        f"(pool_strategy={strategy}, start_tls={self._start_tls})"
    )

    # "Who am I" is optional on the server side, so a failure is only a warning.
    try:
        identity = conn.extend.standard.who_am_i()
        log.info(f"LDAP whoami: {identity}")
    except Exception as e:
        log.warning(f"LDAP whoami not available: {e!r}")

    return conn
|
|
185
|
+
|
|
186
|
+
def authenticate(self, username: str, password: str) -> Optional[AuthenticatedUserData]:
    """Authenticate a user by locating their entry and binding with their password.

    Steps:

    1. Search each configured user base, in order, with the service account
       and take the first entry that matches the user filter.
    2. Bind as that entry's DN with the supplied password.
    3. If a group base is configured, collect the ``cn`` of every group
       whose member attribute holds the user's DN or the supplied username.

    Args:
        username: Login name typed by the user (untrusted input).
        password: Password typed by the user.

    Returns:
        The user's DN, attributes, normalised username, e-mail and group
        names, or ``None`` when the user is not found or the bind fails.
    """
    # ldap3 is already a dependency of this module; imported locally so this
    # method carries everything it needs.
    from ldap3.utils.conv import escape_filter_chars

    # Never send an empty password: a simple bind with a DN but no password
    # is an "unauthenticated bind" (RFC 4513 section 5.1.2), which must not
    # count as a successful login.
    if not username or not password:
        return None

    # Escape filter metacharacters so input such as ``*`` or ``)(uid=*``
    # cannot change the structure of the search filter.
    safe_username = escape_filter_chars(username)
    user_filter = self._user_filter_tpl.format(username=safe_username)

    # Use the same TLS behaviour for the password check as for the service
    # connection, so credentials are not sent in clear text when StartTLS
    # is configured.
    bind_mode = AUTO_BIND_TLS_BEFORE_BIND if self._start_tls else AUTO_BIND_NO_TLS

    def _first(value: Any, default: Any) -> Any:
        """Return the first item of a list value, the value itself, or ``default``."""
        if isinstance(value, list):
            return value[0] if value else default
        return value or default

    with self._service_conn() as svc:
        entry = None
        matched_base = None
        for base in self._user_bases:
            found = svc.search(
                search_base=base,
                search_filter=user_filter,
                search_scope=SUBTREE,
                attributes=ALL_ATTRIBUTES,
            )
            if found and svc.entries:
                entry = svc.entries[0]
                matched_base = base
                break

        if entry is None:
            return None

        user_dn = entry.entry_dn
        attrs = entry.entry_attributes_as_dict

        if self._debug_logging:
            log.info(f"LDAP user search matched base={matched_base!r} dn={user_dn!r}")

        # Verify the password by binding as the user. Any failure (wrong
        # password, locked account, connection problem) is reported to the
        # caller as a failed authentication.
        try:
            with Connection(self._servers, user=user_dn, password=password, auto_bind=bind_mode):
                pass
        except Exception as exc:
            if self._debug_logging:
                log.info(f"LDAP bind failed for dn={user_dn!r}: {exc!r}")
            return None

        groups: list[str] = []
        if self._group_base:
            member_attr = self._group_member_attr
            # Directories store either the DN or the bare username in the
            # member attribute, so match both; both values are escaped.
            group_filter = (
                f"(|({member_attr}={escape_filter_chars(user_dn)})"
                f"({member_attr}={safe_username}))"
            )
            svc.search(
                search_base=self._group_base,
                search_filter=group_filter,
                search_scope=SUBTREE,
                attributes=["cn"],
            )
            groups = [str(group_entry.cn) for group_entry in svc.entries]

    return {
        "dn": user_dn,
        "attrs": attrs,
        "username": _first(attrs.get(self._username_attr), username),
        "email": _first(attrs.get(self._email_attr), None),
        "groups": groups,
    }
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
# -----------------------------
|
|
255
|
+
# AuthZ policy helpers
|
|
256
|
+
# -----------------------------
|
|
257
|
+
class Policy:
    """Map LDAP group names onto a single effective :class:`Role`.

    Group names are read from the ``[ldap_auth_manager]`` options
    ``admin_groups``, ``editor_groups`` and ``viewer_groups`` and are compared
    case-insensitively.
    """

    def __init__(self):
        # Group names are lowercased once here; lookups lowercase the user's
        # groups and intersect.
        load = self._load_group_config
        self.admin_groups = load("admin_groups")
        self.editor_groups = load("editor_groups")
        self.viewer_groups = load("viewer_groups")
        # Role assigned when none of the configured groups match: no access.
        self._default_role = Role.NONE

    def _load_group_config(self, option: str) -> set[str]:
        """Return the lowercased group names configured under ``option``."""
        configured = conf.get("ldap_auth_manager", option, fallback="")
        return {name.lower() for name in _csv_to_set(configured)}

    def role_for(self, groups: Iterable[str]) -> Role:
        """Return the highest role granted by any of ``groups``."""
        memberships = {name.lower() for name in (groups or [])}
        # Checked from most to least privileged so the first hit wins.
        ladder = (
            (Role.ADMIN, self.admin_groups),
            (Role.EDITOR, self.editor_groups),
            (Role.VIEWER, self.viewer_groups),
        )
        for role, allowed in ladder:
            if allowed and (memberships & allowed):
                return role
        return self._default_role

    def at_least(self, groups: Iterable[str], min_role: Role) -> bool:
        """Return ``True`` when ``groups`` resolve to ``min_role`` or higher."""
        effective = self.role_for(groups)
        return effective >= min_role

    # Convenience predicates, one per role level.
    def is_admin(self, groups: Iterable[str]) -> bool:
        """Return ``True`` when ``groups`` resolve to the admin role."""
        return self.at_least(groups, Role.ADMIN)

    def is_editor(self, groups: Iterable[str]) -> bool:
        """Return ``True`` when ``groups`` resolve to editor or higher."""
        return self.at_least(groups, Role.EDITOR)

    def is_viewer(self, groups: Iterable[str]) -> bool:
        """Return ``True`` when ``groups`` resolve to viewer or higher."""
        return self.at_least(groups, Role.VIEWER)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _csv_to_set(val: str) -> set[str]:
|
|
307
|
+
"""Convert a comma separated string into a set of trimmed tokens."""
|
|
308
|
+
return {x.strip() for x in val.split(",") if x.strip()}
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
# -----------------------------
|
|
312
|
+
# LDAP Auth Manager
|
|
313
|
+
# -----------------------------
|
|
314
|
+
class LdapAuthManager(BaseAuthManager[LdapUser]):
|
|
315
|
+
"""LDAP backed implementation of Airflow's :class:`BaseAuthManager`."""
|
|
316
|
+
|
|
317
|
+
def __init__(self, context=None):
    """Initialise the base manager, then the LDAP client and the role policy.

    ``context`` is forwarded unchanged to the base class constructor.
    """
    super().__init__(context=context)
    # Mirrors the flag the LDAP client reads; used for login logging here.
    self._debug_logging = conf.getboolean("ldap_auth_manager", "debug_logging", fallback=False)
    self._ldap = LdapClient()
    self._policy = Policy()
|
|
323
|
+
|
|
324
|
+
# --- Authentication surface ---
|
|
325
|
+
@override
|
|
326
|
+
def get_url_login(self, **kwargs) -> str:
|
|
327
|
+
"""Return the login URL including the ``next`` redirect parameter."""
|
|
328
|
+
next_url = kwargs.get("next", "/")
|
|
329
|
+
return f"/auth/login?next={next_url}"
|
|
330
|
+
|
|
331
|
+
def get_url_logout(self) -> Optional[str]:
    """Return ``[ldap_auth_manager] logout_redirect``, defaulting to ``/``."""
    logout_target = conf.get("ldap_auth_manager", "logout_redirect", fallback="/")
    return logout_target
|
|
334
|
+
|
|
335
|
+
@override
def serialize_user(self, user: LdapUser) -> dict:
    """Turn ``user`` into a plain dict with the keys ``deserialize_user`` reads.

    The result contains ``user_id``, ``username``, ``email`` and ``groups``,
    in that order.
    """
    exported_fields = ("user_id", "username", "email", "groups")
    return {name: getattr(user, name) for name in exported_fields}
|
|
344
|
+
|
|
345
|
+
@override
def deserialize_user(self, data: dict) -> LdapUser:
    """Build an ``LdapUser`` from a dict produced by ``serialize_user``.

    Missing keys are tolerated: ``username`` and ``email`` become ``None``
    and ``groups`` becomes an empty list. ``user_id`` is always passed
    through ``str()``.
    """
    stored_groups = data.get("groups")
    group_names = list(stored_groups) if stored_groups else []
    identifier = str(data.get("user_id"))
    return LdapUser(
        user_id=identifier,
        username=data.get("username"),
        email=data.get("email"),
        groups=group_names,
    )
|
|
354
|
+
|
|
355
|
+
@override
async def get_user_from_token(self, token: str) -> LdapUser | None:
    """Validate ``token`` and return the user it carries, if still permitted.

    The token is verified with the ``[api_auth]`` secret, algorithm,
    audience (first entry when a comma-separated list is configured) and
    issuer; ``sub``, ``exp`` and ``iat`` must be present. The user is rebuilt
    from the ``user`` claim and dropped when their groups no longer map to
    any role.

    Returns:
        The user, or ``None`` when no secret is configured, the token fails
        validation, or the user resolves to ``Role.NONE``.
    """
    import jwt
    from jwt import (ExpiredSignatureError, InvalidAudienceError,
                     InvalidSignatureError, InvalidTokenError)

    signing_secret = conf.get("api_auth", "jwt_secret")
    if not signing_secret:
        return None

    algorithm = conf.get("api_auth", "jwt_algorithm", fallback="HS512") or "HS512"

    configured_audience = conf.get("api_auth", "jwt_audience", fallback="urn:airflow.apache.org:api")
    if configured_audience and "," in configured_audience:
        # Several audiences configured: validate against the first one.
        expected_audience = (configured_audience.split(",")[0]).strip()
    else:
        expected_audience = configured_audience

    expected_issuer = conf.get("api_auth", "jwt_issuer", fallback=None)

    decode_options = {"require": ["sub", "exp", "iat"], "verify_aud": bool(expected_audience)}
    try:
        claims = jwt.decode(
            token,
            signing_secret,
            algorithms=[algorithm],
            audience=expected_audience or None,
            issuer=expected_issuer or None,
            options=decode_options,
        )
    except (ExpiredSignatureError, InvalidAudienceError, InvalidSignatureError, InvalidTokenError):
        return None

    user = self.deserialize_user(claims.get("user") or {})

    # Re-check group membership on every request, not only at login.
    if self._policy.role_for(user.groups) is Role.NONE:
        return None
    return user
|
|
391
|
+
|
|
392
|
+
# --- Menu filtering ---
|
|
393
|
+
def filter_authorized_menu_items(self, menu_items, *, user: LdapUser):
    """Return every menu item as a new list; ``user`` is not consulted.

    No filtering happens here: write restrictions are enforced by the
    individual authorization checks instead.
    """
    if not menu_items:
        return []
    return list(menu_items)
|
|
397
|
+
|
|
398
|
+
# --- Authorization surface ---
|
|
399
|
+
def _norm_method(self, method) -> str:
|
|
400
|
+
"""Coerce ``method`` (enum or string) into an uppercase HTTP verb."""
|
|
401
|
+
name = getattr(method, "name", None)
|
|
402
|
+
if name:
|
|
403
|
+
return str(name).upper()
|
|
404
|
+
value = getattr(method, "value", None)
|
|
405
|
+
if value is not None:
|
|
406
|
+
return str(value).upper()
|
|
407
|
+
return str(method).upper()
|
|
408
|
+
|
|
409
|
+
def _is_dag_run_scoped(
|
|
410
|
+
self, write_scope: str | None, access_entity: "rd.DagAccessEntity | None"
|
|
411
|
+
) -> bool:
|
|
412
|
+
"""Return ``True`` when the request targets DAG run level operations."""
|
|
413
|
+
|
|
414
|
+
if write_scope == "dag_run":
|
|
415
|
+
return True
|
|
416
|
+
|
|
417
|
+
if access_entity is None:
|
|
418
|
+
return False
|
|
419
|
+
|
|
420
|
+
try:
|
|
421
|
+
ae_name = getattr(access_entity, "name", None)
|
|
422
|
+
if ae_name is None:
|
|
423
|
+
ae_name = str(access_entity)
|
|
424
|
+
return str(ae_name).upper() in {"DAG_RUN", "DAGRUN", "RUN"}
|
|
425
|
+
except Exception:
|
|
426
|
+
return False
|
|
427
|
+
|
|
428
|
+
def _is_authorized(
    self,
    *,
    method: ResourceMethod | str,
    user: LdapUser,
    access_entity: "rd.DagAccessEntity | None" = None,
    write_scope: str | None = None,
) -> bool:
    """Apply the shared authorization rules to a single request.

    * Users whose groups map to ``Role.NONE`` are always denied.
    * ``GET`` requires viewer or higher.
    * Any other method requires editor or higher when the request is DAG-run
      scoped (see ``_is_dag_run_scoped``), and admin otherwise.
    """
    groups = user.groups
    if self._policy.role_for(groups) == Role.NONE:
        return False

    verb = self._norm_method(method)
    if verb == "GET":
        required = Role.VIEWER
    elif self._is_dag_run_scoped(write_scope, access_entity):
        required = Role.EDITOR
    else:
        required = Role.ADMIN
    return self._policy.at_least(groups, required)
|
|
462
|
+
|
|
463
|
+
@override
def is_authorized_configuration(
    self,
    *,
    method: ResourceMethod,
    user: LdapUser,
    details: rd.ConfigurationDetails | None = None,
) -> bool:
    """Decide whether ``user`` may perform ``method`` on the configuration.

    Delegates to ``_is_authorized`` with no DAG-run scope; ``details`` is
    accepted but not consulted.
    """
    allowed = self._is_authorized(method=method, user=user)
    return allowed
|
|
470
|
+
|
|
471
|
+
@override
def is_authorized_connection(
    self,
    *,
    method: ResourceMethod,
    user: LdapUser,
    details: rd.ConnectionDetails | None = None,
) -> bool:
    """Decide whether ``user`` may perform ``method`` on a connection.

    Delegates to ``_is_authorized`` with no DAG-run scope; ``details`` is
    accepted but not consulted.
    """
    allowed = self._is_authorized(method=method, user=user)
    return allowed
|
|
477
|
+
|
|
478
|
+
@override
def batch_is_authorized_connection(
    self,
    requests: Iterable[dict[str, Any]] | None = None,
    *,
    method: ResourceMethod | None = None,
    user: LdapUser,
    details: rd.ConnectionDetails | None = None,
) -> bool:
    """Authorize a batch of connection requests, or a single ``method``.

    Airflow's batch API passes a sequence of request mappings, each with a
    ``"method"`` key; the batch is allowed only if every request is. The
    keyword-only ``method`` form this method originally offered is still
    supported.

    Args:
        requests: Request mappings to check; ``None`` selects the ``method`` form.
        method: Single method to check when ``requests`` is not given.
        user: The user being authorized.
        details: Accepted but not consulted.

    Raises:
        TypeError: If neither ``requests`` nor ``method`` is provided.
    """
    if requests is not None:
        return all(self._is_authorized(method=request["method"], user=user) for request in requests)
    if method is None:
        raise TypeError("batch_is_authorized_connection() requires 'requests' or 'method'")
    return self._is_authorized(method=method, user=user)
|
|
484
|
+
|
|
485
|
+
@override
def batch_is_authorized_variable(
    self,
    requests: Iterable[dict[str, Any]] | None = None,
    *,
    method: ResourceMethod | None = None,
    user: LdapUser,
    details: rd.VariableDetails | None = None,
) -> bool:
    """Authorize a batch of variable requests, or a single ``method``.

    Airflow's batch API passes a sequence of request mappings, each with a
    ``"method"`` key; the batch is allowed only if every request is. The
    keyword-only ``method`` form this method originally offered is still
    supported.

    Args:
        requests: Request mappings to check; ``None`` selects the ``method`` form.
        method: Single method to check when ``requests`` is not given.
        user: The user being authorized.
        details: Accepted but not consulted.

    Raises:
        TypeError: If neither ``requests`` nor ``method`` is provided.
    """
    if requests is not None:
        return all(self._is_authorized(method=request["method"], user=user) for request in requests)
    if method is None:
        raise TypeError("batch_is_authorized_variable() requires 'requests' or 'method'")
    return self._is_authorized(method=method, user=user)
|
|
491
|
+
|
|
492
|
+
@override
def is_authorized_variable(
    self,
    *,
    method: ResourceMethod,
    user: LdapUser,
    details: rd.VariableDetails | None = None,
) -> bool:
    """Decide whether ``user`` may perform ``method`` on a variable.

    Delegates to ``_is_authorized`` with no DAG-run scope; ``details`` is
    accepted but not consulted.
    """
    allowed = self._is_authorized(method=method, user=user)
    return allowed
|
|
498
|
+
|
|
499
|
+
@override
def batch_is_authorized_pool(
    self,
    requests: Iterable[dict[str, Any]] | None = None,
    *,
    method: ResourceMethod | None = None,
    user: LdapUser,
    details: rd.PoolDetails | None = None,
) -> bool:
    """Authorize a batch of pool requests, or a single ``method``.

    Airflow's batch API passes a sequence of request mappings, each with a
    ``"method"`` key; the batch is allowed only if every request is. The
    keyword-only ``method`` form this method originally offered is still
    supported.

    Args:
        requests: Request mappings to check; ``None`` selects the ``method`` form.
        method: Single method to check when ``requests`` is not given.
        user: The user being authorized.
        details: Accepted but not consulted.

    Raises:
        TypeError: If neither ``requests`` nor ``method`` is provided.
    """
    if requests is not None:
        return all(self._is_authorized(method=request["method"], user=user) for request in requests)
    if method is None:
        raise TypeError("batch_is_authorized_pool() requires 'requests' or 'method'")
    return self._is_authorized(method=method, user=user)
|
|
505
|
+
|
|
506
|
+
@override
def is_authorized_pool(
    self,
    *,
    method: ResourceMethod,
    user: LdapUser,
    details: rd.PoolDetails | None = None,
) -> bool:
    """Decide whether ``user`` may perform ``method`` on a pool.

    Delegates to ``_is_authorized`` with no DAG-run scope; ``details`` is
    accepted but not consulted.
    """
    allowed = self._is_authorized(method=method, user=user)
    return allowed
|
|
512
|
+
|
|
513
|
+
@override
def is_authorized_asset(
    self,
    *,
    method: ResourceMethod,
    user: LdapUser,
    details: rd.AssetDetails | None = None,
) -> bool:
    """Decide whether ``user`` may perform ``method`` on an asset.

    Delegates to ``_is_authorized`` with no DAG-run scope; ``details`` is
    accepted but not consulted.
    """
    allowed = self._is_authorized(method=method, user=user)
    return allowed
|
|
519
|
+
|
|
520
|
+
@override
def is_authorized_asset_alias(
    self,
    *,
    method: ResourceMethod,
    user: LdapUser,
    details: rd.AssetAliasDetails | None = None,
) -> bool:
    """Decide whether ``user`` may perform ``method`` on an asset alias.

    Delegates to ``_is_authorized`` with no DAG-run scope; ``details`` is
    accepted but not consulted.
    """
    allowed = self._is_authorized(method=method, user=user)
    return allowed
|
|
526
|
+
|
|
527
|
+
@override
def is_authorized_backfill(
    self,
    *,
    method: ResourceMethod,
    user: LdapUser,
    details: rd.BackfillDetails | None = None,
) -> bool:
    """Decide whether ``user`` may perform ``method`` on a backfill.

    ``write_scope`` is passed as ``None``, so the request is never treated
    as DAG-run scoped: non-``GET`` methods need admin, ``GET`` needs viewer.
    ``details`` is accepted but not consulted.
    """
    allowed = self._is_authorized(method=method, user=user, write_scope=None)
    return allowed
|
|
534
|
+
|
|
535
|
+
@override
def is_authorized_dag(
    self,
    *,
    method: ResourceMethod,
    user: LdapUser,
    access_entity: rd.DagAccessEntity | None = None,
    details: rd.DagDetails | None = None,
) -> bool:
    """Decide whether ``user`` may perform ``method`` on a DAG or a sub-entity.

    ``access_entity`` is forwarded to ``_is_authorized``, which lets editors
    perform non-``GET`` methods only when the entity denotes DAG runs.
    ``details`` is accepted but not consulted.
    """
    allowed = self._is_authorized(
        method=method,
        user=user,
        access_entity=access_entity,
        write_scope=None,
    )
    return allowed
|
|
547
|
+
|
|
548
|
+
@override
def batch_is_authorized_dag(
    self,
    requests: Iterable[dict[str, Any]] | None = None,
    *,
    method: ResourceMethod | None = None,
    user: LdapUser,
    access_entity: rd.DagAccessEntity | None = None,
    details: rd.DagDetails | None = None,
) -> bool:
    """Authorize a batch of DAG requests, or a single ``method``.

    Airflow's batch API passes a sequence of request mappings, each with a
    ``"method"`` key and an optional ``"access_entity"`` key; the batch is
    allowed only if every request is. The keyword-only
    ``method``/``access_entity`` form this method originally offered is
    still supported.

    Args:
        requests: Request mappings to check; ``None`` selects the ``method`` form.
        method: Single method to check when ``requests`` is not given.
        user: The user being authorized.
        access_entity: DAG sub-entity for the single-``method`` form.
        details: Accepted but not consulted.

    Raises:
        TypeError: If neither ``requests`` nor ``method`` is provided.
    """
    if requests is not None:
        return all(
            self._is_authorized(
                method=request["method"],
                user=user,
                access_entity=request.get("access_entity"),
                write_scope=None,
            )
            for request in requests
        )
    if method is None:
        raise TypeError("batch_is_authorized_dag() requires 'requests' or 'method'")
    return self._is_authorized(method=method, user=user, access_entity=access_entity, write_scope=None)
|
|
560
|
+
|
|
561
|
+
@override
def is_authorized_view(self, *, access_view: rd.AccessView, user: LdapUser) -> bool:
    """Allow any view for users with viewer role or higher.

    ``access_view`` is not consulted; every view gets the same check.
    """
    minimum_role = Role.VIEWER
    return self._policy.at_least(user.groups, minimum_role)
|
|
566
|
+
|
|
567
|
+
@override
def is_authorized_custom_view(
    self,
    *,
    method: ResourceMethod | str,
    resource_name: str,
    user: LdapUser,
) -> bool:
    """Decide whether ``user`` may perform ``method`` on a custom view.

    Delegates to ``_is_authorized`` with no DAG-run scope; ``resource_name``
    is accepted but not consulted.
    """
    allowed = self._is_authorized(method=method, user=user)
    return allowed
|
|
571
|
+
|
|
572
|
+
# -----------------------------
|
|
573
|
+
# FastAPI extension for login/token/logout
|
|
574
|
+
# -----------------------------
|
|
575
|
+
def get_fastapi_app(self) -> Optional[FastAPI]:
    """Build the FastAPI sub-application serving login, token and logout routes.

    Routes (the redirects built here assume the app is reachable under
    ``/auth``):

    * ``GET /login``  -- render the HTML login form.
    * ``POST /token`` -- authenticate against LDAP and mint a JWT. Callers
      that accept ``application/json`` get the token as JSON; other callers
      get it in a cookie plus a ``303`` redirect to a sanitised ``next``.
    * ``GET /logout`` -- clear the JWT cookie and redirect.
    * ``/static``     -- static assets packaged next to this module.

    Every redirect target derived from request data passes through
    ``_sanitize_next`` so it cannot leave the current origin.

    Returns:
        The configured :class:`fastapi.FastAPI` application.
    """
    from urllib.parse import parse_qs, quote, urlencode, urlparse

    router = APIRouter()

    base_dir = Path(__file__).parent
    template_dir = base_dir / "templates"
    static_dir = base_dir / "static"
    # Autoescaping stays on: ``next`` and ``error`` are reflected into the page.
    jinja_env = Environment(loader=FileSystemLoader(template_dir), autoescape=True)

    instance_name = conf.get("api", "instance_name", fallback="Airflow")
    login_tip = conf.get("ldap_auth_manager", "login_tip", fallback="")
    jinja_env.globals.update(instance_name=instance_name, login_tip=login_tip)

    def render(name: str, **ctx) -> HTMLResponse:
        """Render template ``name`` with ``ctx`` as an HTML response."""
        tpl = jinja_env.get_template(name)
        return HTMLResponse(tpl.render(**ctx))

    def _is_local_path(candidate: str) -> bool:
        """Return ``True`` for a path that cannot be read as another origin."""
        if not candidate.startswith("/") or candidate.startswith("//"):
            return False
        # Browsers treat "\" like "/" and drop tabs/newlines when parsing a
        # URL, so "/\host" or "/<TAB>/host" would end up as "//host".
        if "\\" in candidate:
            return False
        return not any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in candidate)

    def _sanitize_next(raw_next: Optional[str], request: Request) -> str:
        """Return a safe, same-origin path for a redirect.

        - unwrap one level of nested ``next`` (e.g. ``/?next=/graph``)
        - accept local absolute paths as-is
        - reduce same-host http(s) URLs to their path
        - fall back to ``/`` for anything else, including non-string input
        """
        if not isinstance(raw_next, str):
            return "/"
        target = raw_next.strip() or "/"
        try:
            nested = parse_qs(urlparse(target).query).get("next", [])
            if nested:
                target = nested[0].strip()

            if _is_local_path(target):
                return target

            parsed = urlparse(target)
            own = urlparse(str(request.base_url))
            if parsed.scheme in ("http", "https") and parsed.netloc and parsed.netloc == own.netloc:
                path = parsed.path or "/"
                # The path of "http://host//evil" is "//evil": re-check it.
                if _is_local_path(path):
                    return path
        except ValueError:
            # Malformed URL (e.g. a broken IPv6 literal): use the default.
            pass
        return "/"

    def _login_error_redirect(next_param: Any, message: str) -> RedirectResponse:
        """Redirect back to the login form carrying ``next`` and an error text."""
        next_value = next_param if isinstance(next_param, str) and next_param else "/"
        # Percent-encode both values so they cannot inject extra parameters.
        query = urlencode({"next": next_value, "error": message}, quote_via=quote)
        return RedirectResponse(url=f"/auth/login?{query}", status_code=303)

    @router.get("/login", response_class=HTMLResponse)
    def login_form(next: str = "/", error: str | None = None):
        """Serve the HTML login form."""
        return render("ldap_login.html", next=next, error=error)

    @router.post("/token")
    async def create_token(request: Request):
        """Authenticate the user and return/issue a JWT token."""
        # --- parse input: form body or JSON ---
        username = password = next_param = None
        content_type = (request.headers.get("content-type") or "").lower()

        if content_type.startswith(("application/x-www-form-urlencoded", "multipart/form-data")):
            form = await request.form()
            username = form.get("username")
            password = form.get("password")
            next_param = form.get("next")
        else:
            # Try JSON regardless of Content-Type (some clients mislabel it).
            try:
                payload = await request.json()
            except Exception:
                payload = None
            if isinstance(payload, dict):
                username = payload.get("username")
                password = payload.get("password")
                next_param = payload.get("next")

        # Only non-empty strings are valid credentials; JSON numbers or
        # uploaded files are rejected here rather than reaching LDAP.
        if not (isinstance(username, str) and isinstance(password, str) and username and password):
            return JSONResponse({"detail": "username and password are required"}, status_code=422)

        wants_json = "application/json" in (request.headers.get("accept") or "")

        # --- authenticate via LDAP ---
        info = self._ldap.authenticate(username=username, password=password)
        if not info:
            if wants_json:
                return JSONResponse({"detail": "Invalid credentials"}, status_code=401)
            return _login_error_redirect(next_param, "Invalid credentials")

        user = LdapUser(
            user_id=info["dn"],  # type: ignore
            username=info.get("username"),
            email=info.get("email"),
            groups=info.get("groups", []),
        )

        # --- role check: valid credentials alone do not grant access ---
        role = self._policy.role_for(user.groups)
        if self._debug_logging:
            log.info(f"LDAP login: user={user.username} groups={user.groups} role={role.name}")
        if role is Role.NONE:
            msg = "You are not a member of any Airflow access group"
            if wants_json:
                return JSONResponse({"detail": msg}, status_code=403)
            return _login_error_redirect(next_param, msg)

        # --- mint JWT ---
        from datetime import datetime, timedelta, timezone

        import jwt

        secret = conf.get("api_auth", "jwt_secret")
        if not secret:
            return JSONResponse({"detail": "api_auth.jwt_secret is not set"}, status_code=500)

        alg = conf.get("api_auth", "jwt_algorithm", fallback="HS512") or "HS512"
        audience = conf.get("api_auth", "jwt_audience", fallback="urn:airflow.apache.org:api")
        # With several audiences configured, the token is issued for the first.
        aud = (audience.split(",")[0]).strip() if audience and "," in audience else audience
        issuer = conf.get("api_auth", "jwt_issuer", fallback=None)
        kid = conf.get("api_auth", "jwt_kid", fallback=None)
        exp_secs = conf.getint("api_auth", "jwt_expiration_time", fallback=36000)

        now = datetime.now(timezone.utc)
        issued_at = int(now.timestamp())
        claims = {
            "sub": user.user_id,
            "iat": issued_at,
            "nbf": issued_at - 5,  # tolerate small clock skew
            "exp": int((now + timedelta(seconds=exp_secs)).timestamp()),
            "aud": aud,
            "user": self.serialize_user(user),
        }
        if issuer:
            claims["iss"] = issuer
        headers = {"kid": kid} if kid else None
        token = jwt.encode(claims, secret, algorithm=alg, headers=headers)

        # --- respond: JSON for API callers, cookie + 303 for browsers ---
        if wants_json:
            return JSONResponse({
                "access_token": token,
                "token_type": "Bearer",
                "expires_in": exp_secs,
            })

        target = _sanitize_next(next_param, request)  # type: ignore[arg-type]
        resp = RedirectResponse(url=target or "/", status_code=303)
        secure = (request.base_url.scheme == "https") or bool(conf.get("api", "ssl_cert", fallback=""))
        # Not httponly, as described in the module notes on the JWT cookie handoff.
        resp.set_cookie(COOKIE_NAME_JWT_TOKEN, token, secure=secure, httponly=False, samesite="lax", path="/")
        return resp

    @router.get("/logout")
    def logout(request: Request, next: str = "/"):
        """Clear the JWT cookie and redirect to the configured URL or ``next``.

        A configured ``logout_redirect`` is trusted as-is. The ``next`` query
        parameter comes from the request, so it is sanitised to a same-origin
        path before use.
        """
        configured = conf.get("ldap_auth_manager", "logout_redirect", fallback=None)
        target = configured or _sanitize_next(next, request)
        resp = RedirectResponse(url=target)
        resp.delete_cookie(COOKIE_NAME_JWT_TOKEN, path="/")
        return resp

    app = FastAPI()
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="auth-static")
    app.include_router(router)
    return app
|