dq-made-easy-utils 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dq_made_easy_utils-0.1.0/PKG-INFO +14 -0
- dq_made_easy_utils-0.1.0/README.md +5 -0
- dq_made_easy_utils-0.1.0/pyproject.toml +20 -0
- dq_made_easy_utils-0.1.0/setup.cfg +4 -0
- dq_made_easy_utils-0.1.0/src/dq_made_easy_utils.egg-info/PKG-INFO +14 -0
- dq_made_easy_utils-0.1.0/src/dq_made_easy_utils.egg-info/SOURCES.txt +15 -0
- dq_made_easy_utils-0.1.0/src/dq_made_easy_utils.egg-info/dependency_links.txt +1 -0
- dq_made_easy_utils-0.1.0/src/dq_made_easy_utils.egg-info/requires.txt +2 -0
- dq_made_easy_utils-0.1.0/src/dq_made_easy_utils.egg-info/top_level.txt +1 -0
- dq_made_easy_utils-0.1.0/src/dq_utils/__init__.py +11 -0
- dq_made_easy_utils-0.1.0/src/dq_utils/auth_utils.py +303 -0
- dq_made_easy_utils-0.1.0/src/dq_utils/internal_api_contracts.py +185 -0
- dq_made_easy_utils-0.1.0/src/dq_utils/logging_utils.py +76 -0
- dq_made_easy_utils-0.1.0/src/dq_utils/spark_jars.py +99 -0
- dq_made_easy_utils-0.1.0/src/dq_utils/spark_runtime.py +66 -0
- dq_made_easy_utils-0.1.0/tests/test_auth_utils.py +143 -0
- dq_made_easy_utils-0.1.0/tests/test_logging_utils.py +83 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: dq-made-easy-utils
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Shared utilities for dq-made-easy Python services
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: jsonschema>=4.25.1
|
|
8
|
+
Requires-Dist: requests>=2.32.0
|
|
9
|
+
|
|
10
|
+
# dq-made-easy-utils
|
|
11
|
+
|
|
12
|
+
Shared Python utilities used across dq-made-easy services.
|
|
13
|
+
|
|
14
|
+
Import package name: `dq_utils`.
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=69", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "dq-made-easy-utils"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Shared utilities for dq-made-easy Python services"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"jsonschema>=4.25.1",
|
|
13
|
+
"requests>=2.32.0",
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
[tool.setuptools]
|
|
17
|
+
package-dir = {"" = "src"}
|
|
18
|
+
|
|
19
|
+
[tool.setuptools.packages.find]
|
|
20
|
+
where = ["src"]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: dq-made-easy-utils
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Shared utilities for dq-made-easy Python services
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: jsonschema>=4.25.1
|
|
8
|
+
Requires-Dist: requests>=2.32.0
|
|
9
|
+
|
|
10
|
+
# dq-made-easy-utils
|
|
11
|
+
|
|
12
|
+
Shared Python utilities used across dq-made-easy services.
|
|
13
|
+
|
|
14
|
+
Import package name: `dq_utils`.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
pyproject.toml
|
|
3
|
+
src/dq_made_easy_utils.egg-info/PKG-INFO
|
|
4
|
+
src/dq_made_easy_utils.egg-info/SOURCES.txt
|
|
5
|
+
src/dq_made_easy_utils.egg-info/dependency_links.txt
|
|
6
|
+
src/dq_made_easy_utils.egg-info/requires.txt
|
|
7
|
+
src/dq_made_easy_utils.egg-info/top_level.txt
|
|
8
|
+
src/dq_utils/__init__.py
|
|
9
|
+
src/dq_utils/auth_utils.py
|
|
10
|
+
src/dq_utils/internal_api_contracts.py
|
|
11
|
+
src/dq_utils/logging_utils.py
|
|
12
|
+
src/dq_utils/spark_jars.py
|
|
13
|
+
src/dq_utils/spark_runtime.py
|
|
14
|
+
tests/test_auth_utils.py
|
|
15
|
+
tests/test_logging_utils.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
dq_utils
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Shared utilities for dq-made-easy Python services."""
|
|
2
|
+
|
|
3
|
+
from dq_utils.internal_api_contracts import InternalApiContractLookupError
|
|
4
|
+
from dq_utils.internal_api_contracts import InternalApiContractRegistry
|
|
5
|
+
from dq_utils.internal_api_contracts import InternalApiContractValidationError
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"InternalApiContractLookupError",
|
|
9
|
+
"InternalApiContractRegistry",
|
|
10
|
+
"InternalApiContractValidationError",
|
|
11
|
+
]
|
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Protocol
|
|
7
|
+
|
|
8
|
+
import requests
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AuthConfigError(RuntimeError):
    """Raised when auth settings are missing/invalid or a token cannot be obtained."""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TokenProvider(Protocol):
    """Structural interface for anything that can hand out a bearer access token.

    ``correlation_id`` is keyword-only; implementations may forward it (for
    example as a request header) or ignore it.
    """

    def get_token(self, *, correlation_id: str) -> str:
        """Return an access token string."""
        ...
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class TokenBundle:
    """An access token paired with the absolute time at which it expires."""

    # Bearer token string returned by the token endpoint.
    access_token: str
    # Expiry as seconds since the Unix epoch (compared against time.time()).
    expires_at_epoch_seconds: float
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class StaticTokenProvider:
    """Token provider that always returns one fixed, pre-issued bearer token."""

    def __init__(self, token: str) -> None:
        """Store *token* with surrounding whitespace removed.

        Raises:
            AuthConfigError: if the token is missing or blank.
        """
        cleaned = str(token or "").strip()
        if cleaned:
            self._token = cleaned
            return
        raise AuthConfigError("Static token is empty")

    def get_token(self, *, correlation_id: str) -> str:
        """Return the stored token; *correlation_id* is accepted but unused."""
        del correlation_id  # only present to satisfy the TokenProvider signature
        return self._token
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class OidcClientCredentialsTokenProvider:
    """Obtain and cache access tokens using the OAuth 2.0 client-credentials grant.

    A fetched token is reused until ``refresh_skew_seconds`` before the expiry
    computed from the endpoint's ``expires_in``; the first call after that point
    requests a fresh token.
    """

    def __init__(
        self,
        *,
        token_url: str,
        client_id: str,
        client_secret: str,
        scope: str | None = None,
        refresh_skew_seconds: int = 60,
        timeout_seconds: int = 10,
    ) -> None:
        """Validate and store the client configuration.

        All string arguments are whitespace-trimmed; a blank *scope* means
        "send no scope".

        Args:
            token_url: Token endpoint the grant is POSTed to.
            client_id: OAuth client identifier.
            client_secret: OAuth client secret.
            scope: Optional value for the ``scope`` form field.
            refresh_skew_seconds: How long before expiry a cached token is
                treated as stale.
            timeout_seconds: Timeout passed to ``requests.post``.

        Raises:
            AuthConfigError: if *token_url*, *client_id* or *client_secret* is blank.
        """
        token_url = str(token_url or "").strip()
        client_id = str(client_id or "").strip()
        client_secret = str(client_secret or "").strip()
        scope = str(scope or "").strip() or None

        if not token_url:
            raise AuthConfigError("OIDC token_url is required")
        if not client_id:
            raise AuthConfigError("OIDC client_id is required")
        if not client_secret:
            raise AuthConfigError("OIDC client_secret is required")

        self._token_url = token_url
        self._client_id = client_id
        self._client_secret = client_secret
        self._scope = scope
        self._refresh_skew_seconds = int(refresh_skew_seconds)
        self._timeout_seconds = int(timeout_seconds)
        self._cached: TokenBundle | None = None

    def get_token(self, *, correlation_id: str) -> str:
        """Return a usable access token, fetching a new one when the cache is stale.

        Args:
            correlation_id: Sent to the token endpoint as ``X-Correlation-ID``.

        Raises:
            AuthConfigError: if the endpoint cannot be reached, answers with an
                HTTP status >= 400, or returns a body without a usable
                ``access_token`` / ``expires_in``.
        """
        # Captured before the request, so the computed expiry errs on the early side.
        now = time.time()
        cached = self._cached
        if cached is not None and (cached.expires_at_epoch_seconds - self._refresh_skew_seconds) > now:
            return cached.access_token

        data: dict[str, str] = {
            "grant_type": "client_credentials",
            "client_id": self._client_id,
            "client_secret": self._client_secret,
        }
        if self._scope:
            data["scope"] = self._scope

        payload = self._request_token(data, correlation_id=correlation_id)
        token, expires_in_seconds = self._parse_token_response(payload)
        self._cached = TokenBundle(
            access_token=token,
            expires_at_epoch_seconds=now + float(expires_in_seconds),
        )
        return token

    def _request_token(self, data: dict[str, str], *, correlation_id: str) -> object:
        """POST the grant form *data* to the token endpoint and return the decoded JSON."""
        try:
            response = requests.post(
                self._token_url,
                data=data,
                headers={"X-Correlation-ID": correlation_id},
                timeout=self._timeout_seconds,
            )
        except Exception as exc:
            # Deliberately broad: every transport-level failure reaches callers
            # as the single AuthConfigError type.
            raise AuthConfigError(
                f"Unable to obtain OIDC access token (token endpoint unreachable at '{self._token_url}')"
            ) from exc

        if response.status_code >= 400:
            raise AuthConfigError(
                f"Unable to obtain OIDC access token (token endpoint returned {response.status_code})"
            )

        try:
            return response.json()
        except ValueError as exc:  # requests' JSONDecodeError is a ValueError subclass
            raise AuthConfigError("OIDC token endpoint returned non-JSON response") from exc

    @staticmethod
    def _parse_token_response(payload: object) -> tuple[str, int]:
        """Extract ``(access_token, expires_in)`` from a decoded token response.

        Raises:
            AuthConfigError: if the payload is not a JSON object, lacks a token,
                or has a missing/non-positive ``expires_in``.
        """
        # A successful token response is a JSON object (RFC 6749 section 5.1).
        # Any other JSON value would otherwise escape as AttributeError on .get().
        if not isinstance(payload, dict):
            raise AuthConfigError("OIDC token endpoint returned a JSON response that is not an object")

        token = str(payload.get("access_token") or "").strip()
        try:
            expires_in_seconds = int(payload.get("expires_in"))
        except (TypeError, ValueError, OverflowError):
            expires_in_seconds = 0

        if not token:
            raise AuthConfigError("OIDC token endpoint response missing access_token")
        if expires_in_seconds <= 0:
            raise AuthConfigError("OIDC token endpoint response missing/invalid expires_in")
        return token, expires_in_seconds
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class OidcPasswordTokenProvider:
    """Obtain and cache access tokens using the OAuth 2.0 resource-owner password grant.

    A fetched token is reused until ``refresh_skew_seconds`` before the expiry
    computed from the endpoint's ``expires_in``; the first call after that point
    requests a fresh token.
    """

    def __init__(
        self,
        *,
        token_url: str,
        client_id: str,
        username: str,
        password: str,
        client_secret: str | None = None,
        scope: str | None = None,
        refresh_skew_seconds: int = 60,
        timeout_seconds: int = 10,
    ) -> None:
        """Validate and store the client and user configuration.

        All string arguments are whitespace-trimmed; a blank *client_secret* or
        *scope* means "do not send it".

        Args:
            token_url: Token endpoint the grant is POSTed to.
            client_id: OAuth client identifier.
            username: Resource-owner user name.
            password: Resource-owner password.
            client_secret: Optional client secret (for confidential clients).
            scope: Optional value for the ``scope`` form field.
            refresh_skew_seconds: How long before expiry a cached token is
                treated as stale.
            timeout_seconds: Timeout passed to ``requests.post``.

        Raises:
            AuthConfigError: if *token_url*, *client_id*, *username* or *password*
                is blank.
        """
        token_url = str(token_url or "").strip()
        client_id = str(client_id or "").strip()
        username = str(username or "").strip()
        # NOTE(review): stripping alters passwords that legitimately start or end
        # with whitespace - confirm this is intended.
        password = str(password or "").strip()
        client_secret = str(client_secret or "").strip() or None
        scope = str(scope or "").strip() or None

        if not token_url:
            raise AuthConfigError("OIDC token_url is required")
        if not client_id:
            raise AuthConfigError("OIDC client_id is required")
        if not username:
            raise AuthConfigError("OIDC username is required")
        if not password:
            raise AuthConfigError("OIDC password is required")

        self._token_url = token_url
        self._client_id = client_id
        self._username = username
        self._password = password
        self._client_secret = client_secret
        self._scope = scope
        self._refresh_skew_seconds = int(refresh_skew_seconds)
        self._timeout_seconds = int(timeout_seconds)
        self._cached: TokenBundle | None = None

    def get_token(self, *, correlation_id: str) -> str:
        """Return a usable access token, fetching a new one when the cache is stale.

        Args:
            correlation_id: Sent to the token endpoint as ``X-Correlation-ID``.

        Raises:
            AuthConfigError: if the endpoint cannot be reached, answers with an
                HTTP status >= 400, or returns a body without a usable
                ``access_token`` / ``expires_in``.
        """
        # Captured before the request, so the computed expiry errs on the early side.
        now = time.time()
        cached = self._cached
        if cached is not None and (cached.expires_at_epoch_seconds - self._refresh_skew_seconds) > now:
            return cached.access_token

        data: dict[str, str] = {
            "grant_type": "password",
            "client_id": self._client_id,
            "username": self._username,
            "password": self._password,
        }
        if self._client_secret:
            data["client_secret"] = self._client_secret
        if self._scope:
            data["scope"] = self._scope

        payload = self._request_token(data, correlation_id=correlation_id)
        token, expires_in_seconds = self._parse_token_response(payload)
        self._cached = TokenBundle(
            access_token=token,
            expires_at_epoch_seconds=now + float(expires_in_seconds),
        )
        return token

    def _request_token(self, data: dict[str, str], *, correlation_id: str) -> object:
        """POST the grant form *data* to the token endpoint and return the decoded JSON."""
        try:
            response = requests.post(
                self._token_url,
                data=data,
                headers={"X-Correlation-ID": correlation_id},
                timeout=self._timeout_seconds,
            )
        except Exception as exc:
            # Deliberately broad: every transport-level failure reaches callers
            # as the single AuthConfigError type.
            raise AuthConfigError(
                f"Unable to obtain OIDC access token (token endpoint unreachable at '{self._token_url}')"
            ) from exc

        if response.status_code >= 400:
            raise AuthConfigError(
                f"Unable to obtain OIDC access token (token endpoint returned {response.status_code})"
            )

        try:
            return response.json()
        except ValueError as exc:  # requests' JSONDecodeError is a ValueError subclass
            raise AuthConfigError("OIDC token endpoint returned non-JSON response") from exc

    @staticmethod
    def _parse_token_response(payload: object) -> tuple[str, int]:
        """Extract ``(access_token, expires_in)`` from a decoded token response.

        Raises:
            AuthConfigError: if the payload is not a JSON object, lacks a token,
                or has a missing/non-positive ``expires_in``.
        """
        # A successful token response is a JSON object (RFC 6749 section 5.1).
        # Any other JSON value would otherwise escape as AttributeError on .get().
        if not isinstance(payload, dict):
            raise AuthConfigError("OIDC token endpoint returned a JSON response that is not an object")

        token = str(payload.get("access_token") or "").strip()
        try:
            expires_in_seconds = int(payload.get("expires_in"))
        except (TypeError, ValueError, OverflowError):
            expires_in_seconds = 0

        if not token:
            raise AuthConfigError("OIDC token endpoint response missing access_token")
        if expires_in_seconds <= 0:
            raise AuthConfigError("OIDC token endpoint response missing/invalid expires_in")
        return token, expires_in_seconds
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def resolve_oidc_token_url(*, issuer: str | None, token_url: str | None) -> str | None:
    """Work out which OIDC token endpoint to call.

    A non-blank *token_url* always wins (returned stripped). Otherwise, when
    *issuer* is non-blank, its trailing slashes are dropped and the Keycloak
    realm token path ``/protocol/openid-connect/token`` is appended. Returns
    None when neither value is usable.
    """
    explicit_url = str(token_url or "").strip()
    if explicit_url:
        return explicit_url

    issuer_base = str(issuer or "").strip().rstrip("/")
    return f"{issuer_base}/protocol/openid-connect/token" if issuer_base else None
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def build_token_provider_from_env(
    *,
    static_token_env_var: str,
    issuer_env_var: str,
    token_url_env_var: str,
    client_id_env_var: str,
    client_secret_env_var: str,
    scope_env_var: str,
    refresh_skew_seconds: int = 60,
) -> TokenProvider:
    """Choose a token provider based on which environment variables are set.

    Each ``*_env_var`` argument is the *name* of a variable to read.
    Resolution order:

    1. A non-blank static token -> StaticTokenProvider.
    2. A token URL (explicit, or derived from the issuer) plus a non-blank
       client id and secret -> OidcClientCredentialsTokenProvider; the scope
       variable is optional.

    Raises:
        AuthConfigError: if neither configuration is complete.
    """

    def read(name: str) -> str:
        return str(os.getenv(name) or "").strip()

    static_token = read(static_token_env_var)
    if static_token:
        return StaticTokenProvider(static_token)

    token_url = resolve_oidc_token_url(
        issuer=os.getenv(issuer_env_var),
        token_url=os.getenv(token_url_env_var),
    )
    client_id = read(client_id_env_var)
    client_secret = read(client_secret_env_var)
    scope = read(scope_env_var) or None

    if not (token_url and client_id and client_secret):
        raise AuthConfigError(
            "Auth is not configured. Set a static bearer token in "
            f"{static_token_env_var}, or configure OIDC client credentials using "
            f"({issuer_env_var} or {token_url_env_var}) plus {client_id_env_var} and {client_secret_env_var}."
        )

    return OidcClientCredentialsTokenProvider(
        token_url=token_url,
        client_id=client_id,
        client_secret=client_secret,
        scope=scope,
        refresh_skew_seconds=refresh_skew_seconds,
    )
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def build_oidc_token_provider_from_env(
    *,
    issuer_env_var: str,
    token_url_env_var: str,
    client_id_env_var: str,
    client_secret_env_var: str,
    scope_env_var: str,
    refresh_skew_seconds: int = 60,
) -> TokenProvider:
    """Build an OIDC client-credentials token provider from environment variables.

    Unlike ``build_token_provider_from_env``, no static bearer token is
    consulted. This helper is for callers that need tokens fetched (and
    therefore rotated) from the identity provider, and it fails fast otherwise.

    Raises:
        AuthConfigError: if the token URL (explicit or issuer-derived), client
            id or client secret is missing.
    """

    def read(name: str) -> str:
        return str(os.getenv(name) or "").strip()

    token_url = resolve_oidc_token_url(
        issuer=os.getenv(issuer_env_var),
        token_url=os.getenv(token_url_env_var),
    )
    client_id = read(client_id_env_var)
    client_secret = read(client_secret_env_var)
    scope = read(scope_env_var) or None

    if not (token_url and client_id and client_secret):
        raise AuthConfigError(
            "OIDC auth is not configured. Configure OIDC client credentials using "
            f"({issuer_env_var} or {token_url_env_var}) plus {client_id_env_var} and {client_secret_env_var}."
        )

    return OidcClientCredentialsTokenProvider(
        token_url=token_url,
        client_id=client_id,
        client_secret=client_secret,
        scope=scope,
        refresh_skew_seconds=refresh_skew_seconds,
    )
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
import json
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from jsonschema import Draft202012Validator
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _format_json_path(segments: tuple[Any, ...]) -> str:
|
|
12
|
+
path = "$"
|
|
13
|
+
for segment in segments:
|
|
14
|
+
if isinstance(segment, int):
|
|
15
|
+
path += f"[{segment}]"
|
|
16
|
+
continue
|
|
17
|
+
text = str(segment)
|
|
18
|
+
if text.isidentifier():
|
|
19
|
+
path += f".{text}"
|
|
20
|
+
continue
|
|
21
|
+
path += f"[{json.dumps(text)}]"
|
|
22
|
+
return path
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
class ContractValidationIssue:
    """One JSON Schema violation found while validating a request payload."""

    # JSONPath-style location of the offending value in the payload.
    json_path: str
    # JSONPath-style location of the failing keyword in the schema.
    schema_path: str
    # Human-readable error text from the validator.
    message: str
    # Name of the failing schema keyword (e.g. "type", "required").
    validator: str

    def as_dict(self) -> dict[str, str]:
        """Return the issue as a plain, JSON-serialisable dict."""
        field_names = ("json_path", "schema_path", "message", "validator")
        return {name: getattr(self, name) for name in field_names}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass(frozen=True)
class OperationContract:
    """Request-side contract of one internal API operation (method + path)."""

    # Version of the aggregate contract bundle the operation was loaded from.
    version: str
    # HTTP method, upper-cased by the registry.
    method: str
    # Path exactly as listed in the operations manifest.
    path: str
    # Identifier of the operation from the manifest.
    operation_id: str
    # Whether the manifest marks the request body as required.
    request_body_required: bool
    # schema_ref of the application/json body, or None when there is none.
    request_body_schema_ref: str | None
    # Sorted media types declared for the request body.
    request_content_types: tuple[str, ...]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class InternalApiContractLookupError(RuntimeError):
    """Raised when no contract (or schema bundle) exists for a requested operation."""
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class InternalApiContractValidationError(RuntimeError):
    """Raised when a request payload does not satisfy its operation's schema.

    Attributes:
        operation: The contract that was validated against.
        issues: Immutable tuple of the individual violations.
    """

    def __init__(self, operation: OperationContract, issues: list[ContractValidationIssue]) -> None:
        self.operation = operation
        self.issues = tuple(issues)
        summary = (
            "Request payload does not match contract for "
            f"{operation.method} {operation.path} ({operation.operation_id})"
        )
        super().__init__(summary)

    def as_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable description of the failure."""
        op = self.operation
        error_dicts = [issue.as_dict() for issue in self.issues]
        return {
            "operation_id": op.operation_id,
            "path": op.path,
            "method": op.method,
            "validation_errors": error_dicts,
        }
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class InternalApiContractRegistry:
    """Registry of internal API operation contracts loaded from a directory.

    ``<contracts_root>/index.json`` must be a JSON object whose ``contracts``
    list contains at least one entry with ``"kind": "aggregate"``. Each such
    entry gives a ``version`` plus ``files.schema`` / ``files.operations``
    paths (relative to *contracts_root*) naming a JSON Schema bundle and an
    operations manifest. Operations are indexed by ``(METHOD, path)``; JSON
    request bodies are validated with Draft 2020-12 validators that are built
    on first use and cached per ``(version, schema_ref)``.

    NOTE(review): when two bundles declare the same (METHOD, path), the bundle
    listed later in the index silently wins - confirm that is intended.
    """

    def __init__(self, contracts_root: str | Path) -> None:
        """Eagerly load every aggregate bundle under *contracts_root*.

        Raises:
            RuntimeError: if the index, a bundle file or a manifest is missing
                or structurally invalid.
        """
        self._contracts_root = Path(contracts_root)
        self._operations: dict[tuple[str, str], OperationContract] = {}
        self._schema_bundles: dict[str, dict[str, Any]] = {}
        self._validators: dict[tuple[str, str], Draft202012Validator] = {}
        self._load()

    @property
    def contracts_root(self) -> Path:
        """Directory the contracts were loaded from."""
        return self._contracts_root

    def get_operation(self, method: str, path: str) -> OperationContract | None:
        """Return the contract for *method* (case-insensitive) and exact *path*, if any."""
        return self._operations.get((str(method or "").upper(), str(path or "")))

    def validate_request_payload(self, method: str, path: str, payload: Any) -> OperationContract:
        """Validate *payload* against the request-body schema of ``method path``.

        Operations without an ``application/json`` ``schema_ref`` are accepted
        without validation.

        Returns:
            The matching OperationContract when the payload is valid.

        Raises:
            InternalApiContractLookupError: if no contract or schema bundle exists.
            InternalApiContractValidationError: with one issue per violation,
                ordered by instance path and then schema path.
        """
        operation = self.get_operation(method, path)
        if operation is None:
            raise InternalApiContractLookupError(f"No internal API contract found for {method} {path}")
        if operation.request_body_schema_ref is None:
            return operation

        validator = self._get_validator(operation.version, operation.request_body_schema_ref)
        errors = sorted(validator.iter_errors(payload), key=lambda err: (list(err.path), list(err.schema_path)))
        if not errors:
            return operation

        issues = [
            ContractValidationIssue(
                json_path=_format_json_path(tuple(error.path)),
                schema_path=_format_json_path(tuple(error.schema_path)),
                message=error.message,
                validator=str(error.validator),
            )
            for error in errors
        ]
        raise InternalApiContractValidationError(operation, issues)

    def _get_validator(self, version: str, schema_ref: str) -> Draft202012Validator:
        """Return the validator for *schema_ref* in bundle *version*, building it once."""
        cache_key = (version, schema_ref)
        cached = self._validators.get(cache_key)
        if cached is not None:
            return cached

        schema_bundle = self._schema_bundles.get(version)
        if schema_bundle is None:
            raise InternalApiContractLookupError(f"No schema bundle loaded for internal API version {version}")

        # Wrap the reference in a root schema that carries the bundle's $defs so
        # refs into them (presumably of the form "#/$defs/<Name>") resolve here.
        validation_schema = {
            "$schema": schema_bundle.get("$schema", "https://json-schema.org/draft/2020-12/schema"),
            "$defs": schema_bundle.get("$defs", {}),
            "allOf": [{"$ref": schema_ref}],
        }
        validator = Draft202012Validator(validation_schema)
        self._validators[cache_key] = validator
        return validator

    @staticmethod
    def _read_json(path: Path) -> Any:
        """Parse *path* as JSON from raw bytes.

        json.loads detects UTF-8/16/32 (including a UTF-8 BOM) itself, so the
        result does not depend on the process locale the way read_text() does.
        """
        return json.loads(path.read_bytes())

    def _load(self) -> None:
        """Read index.json and load every aggregate bundle it lists."""
        index_path = self._contracts_root / "index.json"
        if not index_path.is_file():
            raise RuntimeError(f"Internal API contract index is missing: {index_path}")

        index_payload = self._read_json(index_path)
        contracts = index_payload.get("contracts") if isinstance(index_payload, dict) else None
        if not isinstance(contracts, list):
            raise RuntimeError(f"Internal API contract index is invalid: {index_path}")

        aggregate_contracts = [
            contract for contract in contracts if isinstance(contract, dict) and contract.get("kind") == "aggregate"
        ]
        if not aggregate_contracts:
            raise RuntimeError(f"Internal API contract index has no aggregate bundle entries: {index_path}")

        for contract in aggregate_contracts:
            self._load_bundle(contract)

    def _load_bundle(self, contract: dict[str, Any]) -> None:
        """Load one aggregate bundle's schema and operations into the registry."""
        version = str(contract.get("version") or "").strip()
        files = contract.get("files")
        if not isinstance(files, dict):
            files = {}
        schema_path = self._contracts_root / str(files.get("schema") or "")
        operations_path = self._contracts_root / str(files.get("operations") or "")
        # is_file() rather than exists(): a missing/empty file name resolves to
        # contracts_root itself, which exists but is a directory.
        if not version or not schema_path.is_file() or not operations_path.is_file():
            raise RuntimeError(
                f"Internal API aggregate contract bundle is incomplete for version {version or '<unknown>'}: {contract}"
            )

        schema_bundle = self._read_json(schema_path)
        if not isinstance(schema_bundle, dict):
            raise RuntimeError(f"Internal API schema bundle is invalid: {schema_path}")

        operations_manifest = self._read_json(operations_path)
        operations = operations_manifest.get("operations") if isinstance(operations_manifest, dict) else None
        if not isinstance(operations, list):
            raise RuntimeError(f"Internal API operations manifest is invalid: {operations_path}")

        self._schema_bundles[version] = schema_bundle
        for operation in operations:
            if not isinstance(operation, dict):
                continue  # skip stray non-object entries
            parsed = self._parse_operation(version, operation, operations_path)
            self._operations[(parsed.method, parsed.path)] = parsed

    @staticmethod
    def _parse_operation(version: str, operation: dict[str, Any], manifest_path: Path) -> OperationContract:
        """Build an OperationContract from one operations-manifest entry.

        Raises:
            RuntimeError: if ``request_body`` or its ``content`` is present but
                not a JSON object.
        """
        method = str(operation.get("method") or "").upper()
        path = str(operation.get("path") or "")
        operation_id = str(operation.get("operation_id") or "").strip()

        request_body = operation.get("request_body") or {}
        if not isinstance(request_body, dict):
            raise RuntimeError(
                f"Internal API operation {method} {path} has an invalid request_body in {manifest_path}"
            )
        content = request_body.get("content") or {}
        if not isinstance(content, dict):
            raise RuntimeError(
                f"Internal API operation {method} {path} has an invalid request_body.content in {manifest_path}"
            )

        request_content_types = tuple(sorted(str(media_type) for media_type in content))
        application_json = content.get("application/json")
        schema_ref = application_json.get("schema_ref") if isinstance(application_json, dict) else None

        return OperationContract(
            version=version,
            method=method,
            path=path,
            operation_id=operation_id,
            request_body_required=bool(request_body.get("required", False)),
            request_body_schema_ref=str(schema_ref) if schema_ref else None,
            request_content_types=request_content_types,
        )
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import time
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
_STD_KEYS = frozenset(
|
|
9
|
+
{
|
|
10
|
+
"name",
|
|
11
|
+
"msg",
|
|
12
|
+
"args",
|
|
13
|
+
"created",
|
|
14
|
+
"relativeCreated",
|
|
15
|
+
"levelname",
|
|
16
|
+
"levelno",
|
|
17
|
+
"pathname",
|
|
18
|
+
"filename",
|
|
19
|
+
"module",
|
|
20
|
+
"funcName",
|
|
21
|
+
"lineno",
|
|
22
|
+
"thread",
|
|
23
|
+
"threadName",
|
|
24
|
+
"processName",
|
|
25
|
+
"process",
|
|
26
|
+
"msecs",
|
|
27
|
+
"exc_info",
|
|
28
|
+
"exc_text",
|
|
29
|
+
"stack_info",
|
|
30
|
+
"taskName",
|
|
31
|
+
"message",
|
|
32
|
+
}
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class _JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object."""

    def format(self, record: logging.LogRecord) -> str:
        """Serialise *record* to a JSON string.

        The object contains:

        - ``ts``: UTC timestamp at second precision.
        - ``level``, ``logger`` and ``msg``.
        - Every record attribute that is neither private nor listed in
          ``_STD_KEYS`` (typically values passed via ``extra=``).
        - ``exception`` and ``stack`` text, when present.

        Values that ``json`` cannot serialise are converted with ``str``.
        """
        payload: dict[str, Any] = {
            "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(record.created)),
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
        }
        for key, value in record.__dict__.items():
            if key.startswith("_") or key in _STD_KEYS:
                continue
            payload[key] = value
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        if record.stack_info:
            # "stack_info" is in _STD_KEYS, so without this the stack a caller
            # requested via ``logger.x(..., stack_info=True)`` was silently lost.
            payload["stack"] = self.formatStack(record.stack_info)
        return json.dumps(payload, default=str)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def configure_logging(level: str = "INFO") -> None:
    """Replace the root logger's handlers with one JSON-formatting stream handler.

    Args:
        level: Level name such as ``"debug"`` or ``"WARNING"``; case and
            surrounding whitespace are ignored. Names that do not resolve to a
            numeric ``logging`` level constant fall back to ``INFO``.
    """
    handler = logging.StreamHandler()
    handler.setFormatter(_JsonFormatter())

    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(handler)

    # Accept only integer level constants. For example,
    # getattr(logging, "BASIC_FORMAT") is a *string*, which setLevel() would
    # reject with ValueError.
    resolved = getattr(logging, str(level).strip().upper(), None)
    root.setLevel(resolved if isinstance(resolved, int) else logging.INFO)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def log_event(logger: logging.Logger, event: str, level: str = "info", **context: Any) -> None:
    """Log *event* as the message, attaching it and *context* as record extras.

    Args:
        logger: Logger to emit on.
        event: Event name; used both as the log message and as the ``event`` extra.
        level: Name of the Logger method to call (lower-cased first), e.g.
            ``"info"`` or ``"error"``.
        **context: Extra fields for the record. Any key that is a reserved
            LogRecord attribute is renamed to ``ctx_<key>``.
    """
    # Reserved LogRecord names in `extra` make Python's logging raise KeyError
    # (which can crash workers), so they are renamed rather than passed through.
    merged = {"event": event, **context}
    safe_extra: dict[str, Any] = {
        (f"ctx_{key}" if key in _STD_KEYS else key): value for key, value in merged.items()
    }
    emit = getattr(logger, level.lower())
    emit(event, extra=safe_extra)
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from __future__ import annotations

import os
import re
import sys
from pathlib import Path
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Jar directory used when DQ_SPARK_JAR_DIR is unset or empty.
DEFAULT_SPARK_JAR_DIR = Path.home().joinpath(".dq-spark-jars")

# Artifacts that may appear in at most one version among the selected jars
# (enforced by _reject_duplicate_direct_artifacts, checked in this order).
DIRECT_SPARK_PACKAGE_ARTIFACTS = (
    "spark-avro_2.13",
    "hadoop-aws",
    "delta-spark_2.13",
    "delta-storage",
    "iceberg-spark-runtime-4.0_2.13",
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _artifact_versions(jar_paths: list[Path], artifact_name: str) -> dict[str, list[str]]:
|
|
20
|
+
versions: dict[str, list[str]] = {}
|
|
21
|
+
pattern = re.compile(rf"(?:^|_){re.escape(artifact_name)}-(?P<version>[^/]+)\.jar$")
|
|
22
|
+
for path in jar_paths:
|
|
23
|
+
match = pattern.search(path.name)
|
|
24
|
+
if match is None:
|
|
25
|
+
continue
|
|
26
|
+
versions.setdefault(match.group("version"), []).append(path.name)
|
|
27
|
+
return versions
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _reject_duplicate_direct_artifacts(jar_paths: list[Path]) -> None:
    """Abort when any direct Spark package artifact is present in several versions.

    Raises:
        SystemExit: naming every conflicting artifact with its versions and jars.
    """
    conflicts: list[str] = []
    for artifact_name in DIRECT_SPARK_PACKAGE_ARTIFACTS:
        versions = _artifact_versions(jar_paths, artifact_name)
        if len(versions) > 1:
            described = [f"{version} ({', '.join(names)})" for version, names in sorted(versions.items())]
            conflicts.append(f"{artifact_name}: {', '.join(described)}")

    if not conflicts:
        return
    raise SystemExit(
        "Conflicting Spark package jar versions found in the shared Spark jar directory: "
        + "; ".join(conflicts)
        + ". Re-run dq-engine-warmup or clear the spark-jars volume so only the canonical package versions remain."
    )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def spark_jar_paths() -> list[Path]:
    """Return the local Spark jars to place on ``spark.jars``, sorted by path.

    Jars are read from ``$DQ_SPARK_JAR_DIR`` (default ``DEFAULT_SPARK_JAR_DIR``).
    Jars larger than ``$DQ_SPARK_MAX_JAR_SIZE_MB`` are dropped unless
    ``$DQ_SPARK_INCLUDE_LARGE_JARS`` is ``1``/``true``/``yes``. The size limit
    defaults to 200, and unparsable values also fall back to 200. When jars are
    dropped, a warning naming up to ten of them is written to stderr.

    Raises:
        SystemExit: if the directory is missing, holds no jars, nothing
            survives the size filter, or a direct package artifact appears in
            more than one version.
    """
    jar_dir = Path(os.getenv("DQ_SPARK_JAR_DIR") or DEFAULT_SPARK_JAR_DIR)
    if not jar_dir.is_dir():
        raise SystemExit(
            f"Spark jar directory not found: {jar_dir}. The dq-engine image must bake the required Spark jars during the build phase."
        )

    all_jars = sorted(path for path in jar_dir.glob("*.jar") if path.is_file())
    if not all_jars:
        raise SystemExit(
            f"No Spark jars were found in {jar_dir}. The dq-engine image must copy the build-time Spark cache into that directory."
        )

    max_mb_env = os.getenv("DQ_SPARK_MAX_JAR_SIZE_MB")
    try:
        max_mb = int(max_mb_env) if max_mb_env else 200
    except ValueError:
        # Unparsable limit: keep the 200 MB default instead of failing startup.
        max_mb = 200

    include_large = os.getenv("DQ_SPARK_INCLUDE_LARGE_JARS", "").strip().lower() in ("1", "true", "yes")

    filtered: list[Path] = []
    excluded: list[tuple[str, float]] = []
    for p in all_jars:
        try:
            size_mb = p.stat().st_size / (1024 * 1024)
        except OSError:
            # e.g. the file disappeared after globbing; treat it as small and keep it.
            size_mb = 0.0
        if size_mb > max_mb and not include_large:
            excluded.append((p.name, size_mb))
            continue
        filtered.append(p)

    if not filtered:
        raise SystemExit(
            f"No Spark jars remain after applying size filter (max {max_mb}MB)."
            " Set DQ_SPARK_INCLUDE_LARGE_JARS=1 to include large jars or increase DQ_SPARK_MAX_JAR_SIZE_MB."
        )

    _reject_duplicate_direct_artifacts(filtered)

    if excluded:
        names = ", ".join(name for name, _ in excluded[:10])
        # Diagnostics go to stderr so they never mix with a job's stdout output.
        print(
            f"warning: excluded {len(excluded)} large jar(s) >{max_mb}MB: {names}{'...' if len(excluded)>10 else ''}",
            file=sys.stderr,
        )

    return filtered
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def configure_spark_builder_with_local_jars(builder: Any) -> Any:
    """Set ``spark.jars`` on *builder* to the comma-joined paths from spark_jar_paths()."""
    jar_list = ",".join(map(str, spark_jar_paths()))
    return builder.config("spark.jars", jar_list)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Default master for resolve_spark_master() when DQ_SPARK_MASTER is unset or blank.
DEFAULT_SPARK_MASTER = "local[*]"
# Spark UI port used by resolve_spark_ui_port() when DQ_SPARK_UI_PORT is unset or empty.
DEFAULT_SPARK_UI_PORT = 4044
# Default session timezone value. Note that configure_spark_builder only sets
# spark.sql.session.timeZone when a timezone is passed to it explicitly.
DEFAULT_SPARK_SESSION_TIMEZONE = "UTC"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def resolve_spark_master(default: str = DEFAULT_SPARK_MASTER) -> str:
    """Return ``$DQ_SPARK_MASTER`` stripped, or *default* when it is unset or blank."""
    raw_master = os.getenv("DQ_SPARK_MASTER") or default
    cleaned = str(raw_master).strip()
    return cleaned if cleaned else default
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def resolve_spark_ui_port(raw_value: str | int | None = None) -> int:
    """Parse the Spark UI port from *raw_value*, or from ``$DQ_SPARK_UI_PORT``.

    When *raw_value* is None, the environment variable is read, falling back to
    ``DEFAULT_SPARK_UI_PORT`` when it is unset or empty. Surrounding whitespace
    is ignored.

    Raises:
        ValueError: if the value is not an integer, is below 1, or exceeds
            65535 (the largest TCP port number).
    """
    if raw_value is None:
        raw_value = os.getenv("DQ_SPARK_UI_PORT") or str(DEFAULT_SPARK_UI_PORT)
    normalized = str(raw_value).strip()
    try:
        parsed = int(normalized)
    except ValueError as exc:
        raise ValueError("DQ_SPARK_UI_PORT must be a positive integer") from exc
    if parsed < 1:
        raise ValueError("DQ_SPARK_UI_PORT must be a positive integer")
    if parsed > 65535:
        # Previously accepted; Spark would only fail later when binding the UI.
        raise ValueError("DQ_SPARK_UI_PORT must be between 1 and 65535")
    return parsed
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _first_env_value(*names: str) -> str | None:
    """Return the stripped value of the first env var in *names* that is non-blank."""
    for name in names:
        value = (os.getenv(name) or "").strip()
        if value:
            return value
    return None


def configure_spark_builder(
    builder: Any,
    *,
    spark_ui_port: str | int | None = None,
    session_timezone: str | None = None,
) -> Any:
    """Apply common configuration to a Spark session builder.

    Always sets ``spark.ui.port`` (validated by ``resolve_spark_ui_port``).
    Sets ``spark.sql.session.timeZone`` only when *session_timezone* is
    truthy. Driver and executor memory are taken from the environment when
    present.

    Args:
        builder: A Spark session builder; only ``config(key, value)`` is used.
        spark_ui_port: Explicit UI port; ``None`` defers to the environment
            or the module default.
        session_timezone: Optional session time zone identifier.

    Returns:
        The builder returned by the last ``config`` call.

    Raises:
        ValueError: Propagated from ``resolve_spark_ui_port`` for an invalid port.
    """
    configured = builder.config("spark.ui.port", str(resolve_spark_ui_port(spark_ui_port)))
    if session_timezone:
        configured = configured.config("spark.sql.session.timeZone", str(session_timezone))

    # Allow overriding driver/executor memory from environment variables.
    # DQ-prefixed vars win over the Spark-standard names. Blank/whitespace-only
    # values count as unset, so they neither mask the fallback nor reach Spark.
    driver_mem = _first_env_value("DQ_SPARK_DRIVER_MEMORY", "SPARK_DRIVER_MEMORY")
    executor_mem = _first_env_value("DQ_SPARK_EXECUTOR_MEMORY", "SPARK_EXECUTOR_MEMORY")
    if driver_mem:
        configured = configured.config("spark.driver.memory", driver_mem)
    if executor_mem:
        configured = configured.config("spark.executor.memory", executor_mem)

    return configured
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def build_spark_session_builder(
    *,
    SparkSession: Any,
    app_name: str,
    master: str | None = None,
    spark_ui_port: str | int | None = None,
    session_timezone: str | None = None,
) -> Any:
    """Create a Spark session builder with the app name and common settings applied.

    The ``SparkSession`` class is supplied by the caller (presumably
    ``pyspark.sql.SparkSession``); this module itself imports nothing from
    PySpark.

    Args:
        SparkSession: Object exposing ``builder.appName(...)``.
        app_name: Application name passed to ``appName``.
        master: Master URL. When ``None``, no master is set on the builder
            (``resolve_spark_master`` is not consulted here).
        spark_ui_port: Forwarded to ``configure_spark_builder``.
        session_timezone: Forwarded to ``configure_spark_builder``.

    Returns:
        The configured builder (``getOrCreate`` is not called).
    """
    named = SparkSession.builder.appName(app_name)
    with_master = named if master is None else named.master(master)
    return configure_spark_builder(
        with_master,
        spark_ui_port=spark_ui_port,
        session_timezone=session_timezone,
    )
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import importlib.util
|
|
4
|
+
import types
|
|
5
|
+
import logging
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Make local source importable.
# Prefer the ``src`` directory that sits next to this ``tests`` directory
# (the published sdist layout). Fall back to the monorepo-relative
# ``<root>/dq-utils/src`` path used previously.
_TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.abspath(os.path.join(_TESTS_DIR, "..", ".."))
SRC_DIR = os.path.abspath(os.path.join(_TESTS_DIR, "..", "src"))
if not os.path.isdir(os.path.join(SRC_DIR, "dq_utils")):
    SRC_DIR = os.path.join(ROOT_DIR, "dq-utils", "src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# Load module directly from file to avoid package-level side-effects
mod_path = os.path.join(SRC_DIR, "dq_utils", "auth_utils.py")
# Ensure a dq_utils package exists in sys.modules so dataclasses and relative
# module-level references resolve correctly when loading the module.
# Only install a stub when none is present: replacing a real, already-imported
# package would break later ``from dq_utils import ...`` in other tests.
pkg = sys.modules.get("dq_utils")
if pkg is None:
    pkg = types.ModuleType("dq_utils")
    pkg.__path__ = [os.path.join(SRC_DIR, "dq_utils")]
    sys.modules["dq_utils"] = pkg

spec = importlib.util.spec_from_file_location("dq_utils.auth_utils", mod_path)
auth_utils = importlib.util.module_from_spec(spec)
# Register under the intended dotted name before execution so decorators
# (dataclasses) can resolve the module during class creation.
sys.modules[spec.name] = auth_utils
assert spec.loader is not None
spec.loader.exec_module(auth_utils)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def test_static_token_provider_accepts_and_returns_token():
    """An empty token is rejected; a whitespace-padded one comes back stripped."""
    with pytest.raises(auth_utils.AuthConfigError):
        auth_utils.StaticTokenProvider("")

    provider = auth_utils.StaticTokenProvider(" secret ")
    token = provider.get_token(correlation_id="cid")
    assert token == "secret"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def test_resolve_oidc_token_url_behaviour():
    """Check the three ways resolve_oidc_token_url can resolve a token URL.

    An issuer alone yields ``<issuer>/protocol/openid-connect/token``. An
    explicit token URL is returned unchanged. With neither, the result is None.
    """
    resolve = auth_utils.resolve_oidc_token_url

    derived = resolve(issuer="https://issuer", token_url=None)
    assert derived == "https://issuer/protocol/openid-connect/token"

    explicit = resolve(issuer=None, token_url="https://t")
    assert explicit == "https://t"

    assert resolve(issuer=None, token_url=None) is None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class DummyResponse:
    """Minimal stand-in for an HTTP response returned by the fake ``post`` helpers.

    Only ``status_code`` and ``json()`` are provided. ``json()`` either returns
    the configured payload (an empty dict when none or a falsy one is given) or
    raises ``ValueError`` to simulate an unparseable body.
    """

    def __init__(self, status_code=200, payload=None, json_raises=False):
        self.status_code = status_code
        self._payload = {} if not payload else payload
        self._json_raises = json_raises

    def json(self):
        if not self._json_raises:
            return self._payload
        raise ValueError("not json")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def test_oidc_client_credentials_get_token_success_and_errors(monkeypatch):
    """Drive the OIDC client-credentials provider through one success and three failures.

    ``requests.post`` as seen by ``auth_utils`` is monkeypatched per scenario.
    A 400 status, a non-JSON body and a transport exception must each surface
    as ``AuthConfigError``.
    """
    calls = {}

    def fake_post_success(url, data=None, headers=None, timeout=None):
        calls["last"] = {"url": url, "data": data, "headers": headers, "timeout": timeout}
        return DummyResponse(status_code=200, payload={"access_token": "abc", "expires_in": 3600})

    def fake_post_400(*a, **k):
        return DummyResponse(status_code=400, payload={})

    def fake_post_nonjson(*a, **k):
        return DummyResponse(status_code=200, json_raises=True)

    def fake_post_exc(*a, **k):
        raise RuntimeError("boom")

    provider = auth_utils.OidcClientCredentialsTokenProvider(
        token_url="https://tok",
        client_id="cid",
        client_secret="cs",
        scope=None,
        refresh_skew_seconds=60,
        timeout_seconds=1,
    )

    # Happy path: the access token from the JSON body is returned.
    monkeypatch.setattr(auth_utils.requests, "post", fake_post_success)
    assert provider.get_token(correlation_id="cid") == "abc"

    # Drop the cached token so the next call actually reaches the (fake) endpoint.
    provider._cached = None
    monkeypatch.setattr(auth_utils.requests, "post", fake_post_400)
    with pytest.raises(auth_utils.AuthConfigError):
        provider.get_token(correlation_id="cid")

    provider._cached = None
    monkeypatch.setattr(auth_utils.requests, "post", fake_post_nonjson)
    with pytest.raises(auth_utils.AuthConfigError):
        provider.get_token(correlation_id="cid")

    monkeypatch.setattr(auth_utils.requests, "post", fake_post_exc)
    with pytest.raises(auth_utils.AuthConfigError):
        provider.get_token(correlation_id="cid")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def test_build_token_provider_from_env_prefers_static(monkeypatch):
    """When the static-token variable is set, a StaticTokenProvider for it is built."""
    monkeypatch.setenv("MY_STATIC_TOKEN", "s1")
    env_var_names = {
        "static_token_env_var": "MY_STATIC_TOKEN",
        "issuer_env_var": "ISS",
        "token_url_env_var": "T",
        "client_id_env_var": "CID",
        "client_secret_env_var": "CS",
        "scope_env_var": "S",
    }

    provider = auth_utils.build_token_provider_from_env(**env_var_names)

    assert isinstance(provider, auth_utils.StaticTokenProvider)
    assert provider.get_token(correlation_id="c") == "s1"
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def test_build_oidc_token_provider_from_env_raises_when_missing(monkeypatch):
    """With issuer, token URL, client id and secret unset, building raises AuthConfigError."""
    for env_name in ("ISS", "T", "CID", "CS"):
        monkeypatch.delenv(env_name, raising=False)

    with pytest.raises(auth_utils.AuthConfigError):
        auth_utils.build_oidc_token_provider_from_env(
            issuer_env_var="ISS",
            token_url_env_var="T",
            client_id_env_var="CID",
            client_secret_env_var="CS",
            scope_env_var="S",
        )
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Ensure local package source is importable when running this test file directly.
# Prefer the ``src`` directory that sits next to this ``tests`` directory
# (the published sdist layout). Fall back to the monorepo-relative
# ``<root>/dq-utils/src`` path used previously.
_TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.abspath(os.path.join(_TESTS_DIR, "..", ".."))
SRC_DIR = os.path.abspath(os.path.join(_TESTS_DIR, "..", "src"))
if not os.path.isdir(os.path.join(SRC_DIR, "dq_utils")):
    SRC_DIR = os.path.join(ROOT_DIR, "dq-utils", "src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

import importlib.util

# Load the module directly from its source file path to avoid importing
# dq_utils.__init__ (which pulls heavy optional deps during import).
mod_path = os.path.join(SRC_DIR, "dq_utils", "logging_utils.py")
spec = importlib.util.spec_from_file_location("dq_utils_logging_utils", mod_path)
logging_utils = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(logging_utils)

_JsonFormatter = logging_utils._JsonFormatter
configure_logging = logging_utils.configure_logging
log_event = logging_utils.log_event
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_json_formatter_includes_custom_fields():
    """Non-standard LogRecord attributes appear as top-level keys in the JSON output."""
    record = logging.LogRecord(
        name="mylogger",
        level=logging.INFO,
        pathname=__file__,
        lineno=10,
        msg="hello",
        args=(),
        exc_info=None,
    )
    # Extra attribute that the formatter is expected to copy into the payload.
    record.custom_key = "custom_value"

    data = json.loads(_JsonFormatter().format(record))

    assert data["logger"] == "mylogger"
    assert data["msg"] == "hello"
    assert data["custom_key"] == "custom_value"
    assert "ts" in data
    assert "level" in data
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def test_configure_logging_sets_handler_and_level():
    """configure_logging gives the root logger a StreamHandler and the requested level.

    The root logger is process-global, so its handler list and level are
    restored afterwards. Otherwise this test would change logging behaviour
    (level, output format) for every test that runs after it in the session.
    """
    root = logging.getLogger()
    saved_handlers = root.handlers[:]
    saved_level = root.level
    try:
        configure_logging("WARNING")
        assert any(isinstance(h, logging.StreamHandler) for h in root.handlers)
        assert root.level == logging.WARNING
    finally:
        root.handlers[:] = saved_handlers
        root.setLevel(saved_level)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def test_log_event_safe_extra_and_reserved_prefix():
    """Context keys that clash with LogRecord attributes are stored with a ``ctx_`` prefix.

    ``message`` is reserved on LogRecord, so log_event is expected to emit it
    as ``ctx_message``. An ordinary key such as ``user`` is kept unchanged.
    """
    records = []

    class ListHandler(logging.Handler):
        def emit(self, rec: logging.LogRecord) -> None:  # type: ignore[override]
            # Snapshot the attribute dict (shallow copy) for later assertions.
            records.append(dict(rec.__dict__))

    logger = logging.getLogger("test_logger_for_log_event")
    # Start from a clean handler list for this dedicated logger.
    logger.handlers.clear()
    logger.addHandler(ListHandler())
    logger.setLevel(logging.DEBUG)

    log_event(logger, "evt", level="info", message="danger", user="alice")

    assert records, "expected a log record to be emitted"
    last = records[-1]
    assert last.get("ctx_message") == "danger"
    assert last.get("user") == "alice"
    assert last.get("msg") == "evt"
|