dq-made-easy-utils 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: dq-made-easy-utils
3
+ Version: 0.1.0
4
+ Summary: Shared utilities for dq-made-easy Python services
5
+ Requires-Python: >=3.11
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: jsonschema>=4.25.1
8
+ Requires-Dist: requests>=2.32.0
9
+
10
+ # dq-made-easy-utils
11
+
12
+ Shared Python utilities used across dq-made-easy services.
13
+
14
+ Import package name: `dq_utils`.
@@ -0,0 +1,5 @@
1
+ # dq-made-easy-utils
2
+
3
+ Shared Python utilities used across dq-made-easy services.
4
+
5
+ Import package name: `dq_utils`.
@@ -0,0 +1,20 @@
1
+ [build-system]
2
+ requires = ["setuptools>=69", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "dq-made-easy-utils"
7
+ version = "0.1.0"
8
+ description = "Shared utilities for dq-made-easy Python services"
9
+ readme = "README.md"
10
+ requires-python = ">=3.11"
11
+ dependencies = [
12
+ "jsonschema>=4.25.1",
13
+ "requests>=2.32.0",
14
+ ]
15
+
16
+ [tool.setuptools]
17
+ package-dir = {"" = "src"}
18
+
19
+ [tool.setuptools.packages.find]
20
+ where = ["src"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: dq-made-easy-utils
3
+ Version: 0.1.0
4
+ Summary: Shared utilities for dq-made-easy Python services
5
+ Requires-Python: >=3.11
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: jsonschema>=4.25.1
8
+ Requires-Dist: requests>=2.32.0
9
+
10
+ # dq-made-easy-utils
11
+
12
+ Shared Python utilities used across dq-made-easy services.
13
+
14
+ Import package name: `dq_utils`.
@@ -0,0 +1,15 @@
1
+ README.md
2
+ pyproject.toml
3
+ src/dq_made_easy_utils.egg-info/PKG-INFO
4
+ src/dq_made_easy_utils.egg-info/SOURCES.txt
5
+ src/dq_made_easy_utils.egg-info/dependency_links.txt
6
+ src/dq_made_easy_utils.egg-info/requires.txt
7
+ src/dq_made_easy_utils.egg-info/top_level.txt
8
+ src/dq_utils/__init__.py
9
+ src/dq_utils/auth_utils.py
10
+ src/dq_utils/internal_api_contracts.py
11
+ src/dq_utils/logging_utils.py
12
+ src/dq_utils/spark_jars.py
13
+ src/dq_utils/spark_runtime.py
14
+ tests/test_auth_utils.py
15
+ tests/test_logging_utils.py
@@ -0,0 +1,2 @@
1
+ jsonschema>=4.25.1
2
+ requests>=2.32.0
@@ -0,0 +1,11 @@
1
+ """Shared utilities for dq-made-easy Python services."""
2
+
3
+ from dq_utils.internal_api_contracts import InternalApiContractLookupError
4
+ from dq_utils.internal_api_contracts import InternalApiContractRegistry
5
+ from dq_utils.internal_api_contracts import InternalApiContractValidationError
6
+
7
+ __all__ = [
8
+ "InternalApiContractLookupError",
9
+ "InternalApiContractRegistry",
10
+ "InternalApiContractValidationError",
11
+ ]
@@ -0,0 +1,303 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import time
5
+ from dataclasses import dataclass
6
+ from typing import Protocol
7
+
8
+ import requests
9
+
10
+
11
class AuthConfigError(RuntimeError):
    """Raised when auth is not configured or an access token cannot be obtained."""
13
+
14
+
15
class TokenProvider(Protocol):
    """Structural interface for anything that can hand out an access token."""

    def get_token(self, *, correlation_id: str) -> str:
        """Return an access token; ``correlation_id`` must be passed by keyword."""
        ...
17
+
18
+
19
@dataclass
class TokenBundle:
    """An access token paired with the moment it stops being valid."""

    # Token string as returned by the token endpoint.
    access_token: str
    # Expiry expressed in seconds since the Unix epoch.
    expires_at_epoch_seconds: float
23
+
24
+
25
class StaticTokenProvider:
    """Token provider that always returns one fixed, pre-configured token."""

    def __init__(self, token: str) -> None:
        """Store ``token`` with surrounding whitespace removed.

        Raises:
            AuthConfigError: if ``token`` is falsy or contains only whitespace.
        """
        cleaned = str(token or "").strip()
        if cleaned == "":
            raise AuthConfigError("Static token is empty")
        self._token = cleaned

    def get_token(self, *, correlation_id: str) -> str:
        """Return the stored token; ``correlation_id`` is accepted but not used."""
        del correlation_id  # only present to satisfy the TokenProvider signature
        return self._token
35
+
36
+
37
class OidcClientCredentialsTokenProvider:
    """Token provider backed by the OAuth2/OIDC ``client_credentials`` grant.

    A fetched token is cached on the instance and reused until it is within
    ``refresh_skew_seconds`` of its expiry; the next ``get_token`` call after
    that point requests a fresh token from ``token_url``.
    """

    def __init__(
        self,
        *,
        token_url: str,
        client_id: str,
        client_secret: str,
        scope: str | None = None,
        refresh_skew_seconds: int = 60,
        timeout_seconds: int = 10,
    ) -> None:
        """Validate and store the client configuration.

        String arguments are stripped; an empty ``scope`` becomes ``None``.

        Raises:
            AuthConfigError: if ``token_url``, ``client_id`` or
                ``client_secret`` is empty after stripping.
        """
        token_url = str(token_url or "").strip()
        client_id = str(client_id or "").strip()
        client_secret = str(client_secret or "").strip()
        scope = str(scope or "").strip() or None

        if not token_url:
            raise AuthConfigError("OIDC token_url is required")
        if not client_id:
            raise AuthConfigError("OIDC client_id is required")
        if not client_secret:
            raise AuthConfigError("OIDC client_secret is required")

        self._token_url = token_url
        self._client_id = client_id
        self._client_secret = client_secret
        self._scope = scope
        self._refresh_skew_seconds = int(refresh_skew_seconds)
        self._timeout_seconds = int(timeout_seconds)
        self._cached: TokenBundle | None = None

    def get_token(self, *, correlation_id: str) -> str:
        """Return a cached token, or fetch and cache a new one.

        ``correlation_id`` is forwarded to the token endpoint in the
        ``X-Correlation-ID`` header.

        Raises:
            AuthConfigError: if the endpoint cannot be reached, answers with a
                status >= 400, returns something other than a JSON object, or
                omits a usable ``access_token`` / positive ``expires_in``.
        """
        now = time.time()
        cached = self._cached
        if cached is not None and (cached.expires_at_epoch_seconds - self._refresh_skew_seconds) > now:
            return cached.access_token

        data: dict[str, str] = {
            "grant_type": "client_credentials",
            "client_id": self._client_id,
            "client_secret": self._client_secret,
        }
        if self._scope:
            data["scope"] = self._scope

        try:
            response = requests.post(
                self._token_url,
                data=data,
                headers={"X-Correlation-ID": correlation_id},
                timeout=self._timeout_seconds,
            )
        except Exception as exc:
            # Deliberately broad: any failure to complete the call is reported
            # to callers as a single AuthConfigError.
            raise AuthConfigError(
                f"Unable to obtain OIDC access token (token endpoint unreachable at '{self._token_url}')"
            ) from exc

        if response.status_code >= 400:
            raise AuthConfigError(
                f"Unable to obtain OIDC access token (token endpoint returned {response.status_code})"
            )

        try:
            payload = response.json()
        except Exception as exc:
            raise AuthConfigError("OIDC token endpoint returned non-JSON response") from exc

        # Valid JSON is not necessarily an object; without this guard a list,
        # string or null body would surface as AttributeError from `.get`.
        if not isinstance(payload, dict):
            raise AuthConfigError("OIDC token endpoint returned a JSON response that is not an object")

        token = str(payload.get("access_token") or "").strip()
        try:
            expires_in_seconds = int(payload.get("expires_in"))
        except (TypeError, ValueError, OverflowError):
            # Missing or non-numeric expires_in is rejected below.
            expires_in_seconds = 0

        if not token:
            raise AuthConfigError("OIDC token endpoint response missing access_token")
        if expires_in_seconds <= 0:
            raise AuthConfigError("OIDC token endpoint response missing/invalid expires_in")

        # Expiry is measured from the time the request started (`now`).
        self._cached = TokenBundle(
            access_token=token,
            expires_at_epoch_seconds=now + float(expires_in_seconds),
        )
        return token
120
+
121
+
122
class OidcPasswordTokenProvider:
    """Token provider backed by the OAuth2/OIDC ``password`` grant.

    A fetched token is cached on the instance and reused until it is within
    ``refresh_skew_seconds`` of its expiry; the next ``get_token`` call after
    that point requests a fresh token from ``token_url``.
    """

    def __init__(
        self,
        *,
        token_url: str,
        client_id: str,
        username: str,
        password: str,
        client_secret: str | None = None,
        scope: str | None = None,
        refresh_skew_seconds: int = 60,
        timeout_seconds: int = 10,
    ) -> None:
        """Validate and store the user and client configuration.

        String arguments are stripped; empty ``client_secret`` and ``scope``
        become ``None``.

        Raises:
            AuthConfigError: if ``token_url``, ``client_id``, ``username`` or
                ``password`` is empty after stripping.
        """
        token_url = str(token_url or "").strip()
        client_id = str(client_id or "").strip()
        username = str(username or "").strip()
        # NOTE(review): stripping also removes intentional leading/trailing
        # whitespace from a password — confirm this is acceptable.
        password = str(password or "").strip()
        client_secret = str(client_secret or "").strip() or None
        scope = str(scope or "").strip() or None

        if not token_url:
            raise AuthConfigError("OIDC token_url is required")
        if not client_id:
            raise AuthConfigError("OIDC client_id is required")
        if not username:
            raise AuthConfigError("OIDC username is required")
        if not password:
            raise AuthConfigError("OIDC password is required")

        self._token_url = token_url
        self._client_id = client_id
        self._username = username
        self._password = password
        self._client_secret = client_secret
        self._scope = scope
        self._refresh_skew_seconds = int(refresh_skew_seconds)
        self._timeout_seconds = int(timeout_seconds)
        self._cached: TokenBundle | None = None

    def get_token(self, *, correlation_id: str) -> str:
        """Return a cached token, or fetch and cache a new one.

        ``correlation_id`` is forwarded to the token endpoint in the
        ``X-Correlation-ID`` header.

        Raises:
            AuthConfigError: if the endpoint cannot be reached, answers with a
                status >= 400, returns something other than a JSON object, or
                omits a usable ``access_token`` / positive ``expires_in``.
        """
        now = time.time()
        cached = self._cached
        if cached is not None and (cached.expires_at_epoch_seconds - self._refresh_skew_seconds) > now:
            return cached.access_token

        data: dict[str, str] = {
            "grant_type": "password",
            "client_id": self._client_id,
            "username": self._username,
            "password": self._password,
        }
        if self._client_secret:
            data["client_secret"] = self._client_secret
        if self._scope:
            data["scope"] = self._scope

        try:
            response = requests.post(
                self._token_url,
                data=data,
                headers={"X-Correlation-ID": correlation_id},
                timeout=self._timeout_seconds,
            )
        except Exception as exc:
            # Deliberately broad: any failure to complete the call is reported
            # to callers as a single AuthConfigError.
            raise AuthConfigError(
                f"Unable to obtain OIDC access token (token endpoint unreachable at '{self._token_url}')"
            ) from exc

        if response.status_code >= 400:
            raise AuthConfigError(
                f"Unable to obtain OIDC access token (token endpoint returned {response.status_code})"
            )

        try:
            payload = response.json()
        except Exception as exc:
            raise AuthConfigError("OIDC token endpoint returned non-JSON response") from exc

        # Valid JSON is not necessarily an object; without this guard a list,
        # string or null body would surface as AttributeError from `.get`.
        if not isinstance(payload, dict):
            raise AuthConfigError("OIDC token endpoint returned a JSON response that is not an object")

        token = str(payload.get("access_token") or "").strip()
        try:
            expires_in_seconds = int(payload.get("expires_in"))
        except (TypeError, ValueError, OverflowError):
            # Missing or non-numeric expires_in is rejected below.
            expires_in_seconds = 0

        if not token:
            raise AuthConfigError("OIDC token endpoint response missing access_token")
        if expires_in_seconds <= 0:
            raise AuthConfigError("OIDC token endpoint response missing/invalid expires_in")

        # Expiry is measured from the time the request started (`now`).
        self._cached = TokenBundle(
            access_token=token,
            expires_at_epoch_seconds=now + float(expires_in_seconds),
        )
        return token
216
+
217
+
218
def resolve_oidc_token_url(*, issuer: str | None, token_url: str | None) -> str | None:
    """Pick the token endpoint URL from an explicit URL or an issuer base.

    A non-blank ``token_url`` wins. Otherwise a non-blank ``issuer`` (with
    trailing slashes removed) gets ``/protocol/openid-connect/token`` appended.
    Returns ``None`` when neither value is usable.
    """
    explicit = str(token_url or "").strip()
    if explicit:
        return explicit

    base = str(issuer or "").strip().rstrip("/")
    return f"{base}/protocol/openid-connect/token" if base else None
228
+
229
+
230
def build_token_provider_from_env(
    *,
    static_token_env_var: str,
    issuer_env_var: str,
    token_url_env_var: str,
    client_id_env_var: str,
    client_secret_env_var: str,
    scope_env_var: str,
    refresh_skew_seconds: int = 60,
) -> TokenProvider:
    """Build a token provider from environment variables.

    Every argument except ``refresh_skew_seconds`` is the *name* of an
    environment variable. A non-blank static token takes precedence; otherwise
    an OIDC client-credentials provider is built when a token URL (explicit or
    derived from the issuer), client id and client secret are all present.

    Raises:
        AuthConfigError: if neither a static token nor complete OIDC client
            credentials are configured.
    """

    def _env(name: str) -> str:
        # Unset and blank variables are both treated as "".
        return str(os.getenv(name) or "").strip()

    static_token = _env(static_token_env_var)
    if static_token:
        return StaticTokenProvider(static_token)

    token_url = resolve_oidc_token_url(
        issuer=os.getenv(issuer_env_var),
        token_url=os.getenv(token_url_env_var),
    )
    client_id = _env(client_id_env_var)
    client_secret = _env(client_secret_env_var)
    scope = _env(scope_env_var) or None

    if not (token_url and client_id and client_secret):
        raise AuthConfigError(
            "Auth is not configured. Set a static bearer token in "
            f"{static_token_env_var}, or configure OIDC client credentials using "
            f"({issuer_env_var} or {token_url_env_var}) plus {client_id_env_var} and {client_secret_env_var}."
        )

    return OidcClientCredentialsTokenProvider(
        token_url=token_url,
        client_id=client_id,
        client_secret=client_secret,
        scope=scope,
        refresh_skew_seconds=refresh_skew_seconds,
    )
266
+
267
+
268
def build_oidc_token_provider_from_env(
    *,
    issuer_env_var: str,
    token_url_env_var: str,
    client_id_env_var: str,
    client_secret_env_var: str,
    scope_env_var: str,
    refresh_skew_seconds: int = 60,
) -> TokenProvider:
    """Build an OIDC client-credentials token provider from env.

    Static bearer tokens are intentionally not supported here; callers that
    need fail-fast token rotation should use this helper. Every argument
    except ``refresh_skew_seconds`` is the *name* of an environment variable.

    Raises:
        AuthConfigError: if the token URL (explicit or derived from the
            issuer), client id or client secret is missing.
    """

    def _env(name: str) -> str:
        # Unset and blank variables are both treated as "".
        return str(os.getenv(name) or "").strip()

    token_url = resolve_oidc_token_url(
        issuer=os.getenv(issuer_env_var),
        token_url=os.getenv(token_url_env_var),
    )
    client_id = _env(client_id_env_var)
    client_secret = _env(client_secret_env_var)
    scope = _env(scope_env_var) or None

    if not (token_url and client_id and client_secret):
        raise AuthConfigError(
            "OIDC auth is not configured. Configure OIDC client credentials using "
            f"({issuer_env_var} or {token_url_env_var}) plus {client_id_env_var} and {client_secret_env_var}."
        )

    return OidcClientCredentialsTokenProvider(
        token_url=token_url,
        client_id=client_id,
        client_secret=client_secret,
        scope=scope,
        refresh_skew_seconds=refresh_skew_seconds,
    )
@@ -0,0 +1,185 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ from jsonschema import Draft202012Validator
9
+
10
+
11
def _format_json_path(segments: tuple[Any, ...]) -> str:
    """Render path segments as a ``$``-rooted, JSONPath-like string.

    Integers become ``[n]``, identifier-like strings become ``.name`` and
    anything else becomes a JSON-quoted ``["..."]`` subscript.
    """

    def _render(segment: Any) -> str:
        if isinstance(segment, int):
            return f"[{segment}]"
        text = str(segment)
        return f".{text}" if text.isidentifier() else f"[{json.dumps(text)}]"

    return "$" + "".join(_render(segment) for segment in segments)
23
+
24
+
25
@dataclass(frozen=True)
class ContractValidationIssue:
    """One schema violation found while validating a request payload."""

    # Location of the offending value inside the payload.
    json_path: str
    # Location of the failing rule inside the schema.
    schema_path: str
    # Validator-produced description of the problem.
    message: str
    # Name of the schema keyword that failed.
    validator: str

    def as_dict(self) -> dict[str, str]:
        """Return the issue as a plain dict with the four field names as keys."""
        field_names = ("json_path", "schema_path", "message", "validator")
        return {name: getattr(self, name) for name in field_names}
39
+
40
+
41
@dataclass(frozen=True)
class OperationContract:
    """Immutable description of one internal API operation."""

    # Contract bundle version the operation was loaded from.
    version: str
    # HTTP method and path identifying the operation.
    method: str
    path: str
    operation_id: str
    # Whether the operations manifest marks the request body as required.
    request_body_required: bool
    # Schema reference for the application/json body, or None if there is none.
    request_body_schema_ref: str | None
    # Declared request media types.
    request_content_types: tuple[str, ...]
50
+
51
+
52
class InternalApiContractLookupError(RuntimeError):
    """Raised when no contract or schema bundle matches a lookup."""
54
+
55
+
56
class InternalApiContractValidationError(RuntimeError):
    """Raised when a request payload violates its operation contract.

    Carries the matched ``operation`` and the individual ``issues`` (stored as
    a tuple) so callers can build a structured error response.
    """

    def __init__(self, operation: OperationContract, issues: list[ContractValidationIssue]) -> None:
        summary = (
            f"Request payload does not match contract for {operation.method} {operation.path} ({operation.operation_id})"
        )
        self.operation = operation
        self.issues = tuple(issues)
        super().__init__(summary)

    def as_dict(self) -> dict[str, Any]:
        """Return the operation identity plus every issue as plain dicts."""
        op = self.operation
        return {
            "operation_id": op.operation_id,
            "path": op.path,
            "method": op.method,
            "validation_errors": [issue.as_dict() for issue in self.issues],
        }
71
+
72
+
73
class InternalApiContractRegistry:
    """Index of internal API operation contracts loaded from a directory.

    The directory must contain ``index.json`` listing contract bundles. Every
    bundle of kind ``aggregate`` names a schema file and an operations
    manifest; both are read eagerly in the constructor. Validators are built
    lazily and cached per ``(version, schema_ref)``.
    """

    def __init__(self, contracts_root: str | Path) -> None:
        """Load all aggregate bundles found under ``contracts_root``.

        Raises:
            RuntimeError: if the index, a schema bundle or an operations
                manifest is missing or malformed.
        """
        self._contracts_root = Path(contracts_root)
        # (UPPERCASE method, path) -> contract
        self._operations: dict[tuple[str, str], OperationContract] = {}
        # version -> parsed schema bundle
        self._schema_bundles: dict[str, dict[str, Any]] = {}
        # (version, schema_ref) -> validator
        self._validators: dict[tuple[str, str], Draft202012Validator] = {}
        self._load()

    @property
    def contracts_root(self) -> Path:
        """Directory the contracts were loaded from."""
        return self._contracts_root

    def get_operation(self, method: str, path: str) -> OperationContract | None:
        """Return the contract for ``method`` (case-insensitive) and ``path``, or ``None``."""
        return self._operations.get((str(method or "").upper(), str(path or "")))

    def validate_request_payload(self, method: str, path: str, payload: Any) -> OperationContract:
        """Validate ``payload`` against the operation's JSON request schema.

        Returns the matched contract when the payload is valid or when the
        operation declares no JSON body schema.

        Raises:
            InternalApiContractLookupError: if no contract matches.
            InternalApiContractValidationError: if the payload has schema errors.
        """
        operation = self.get_operation(method, path)
        if operation is None:
            raise InternalApiContractLookupError(f"No internal API contract found for {method} {path}")
        if operation.request_body_schema_ref is None:
            return operation

        validator = self._get_validator(operation.version, operation.request_body_schema_ref)
        # Sort so the reported issue order is stable between runs.
        errors = sorted(validator.iter_errors(payload), key=lambda err: (list(err.path), list(err.schema_path)))
        if not errors:
            return operation

        issues = [
            ContractValidationIssue(
                json_path=_format_json_path(tuple(error.path)),
                schema_path=_format_json_path(tuple(error.schema_path)),
                message=error.message,
                validator=str(error.validator),
            )
            for error in errors
        ]
        raise InternalApiContractValidationError(operation, issues)

    def _get_validator(self, version: str, schema_ref: str) -> Draft202012Validator:
        """Return the cached validator for ``schema_ref``, building it on first use."""
        cache_key = (version, schema_ref)
        cached = self._validators.get(cache_key)
        if cached is not None:
            return cached

        schema_bundle = self._schema_bundles.get(version)
        if schema_bundle is None:
            raise InternalApiContractLookupError(f"No schema bundle loaded for internal API version {version}")

        # Wrap the reference in a schema that carries the bundle's $defs so
        # the $ref can be resolved within the same document.
        validation_schema = {
            "$schema": schema_bundle.get("$schema", "https://json-schema.org/draft/2020-12/schema"),
            "$defs": schema_bundle.get("$defs", {}),
            "allOf": [{"$ref": schema_ref}],
        }
        validator = Draft202012Validator(validation_schema)
        self._validators[cache_key] = validator
        return validator

    @staticmethod
    def _read_json(path: Path) -> Any:
        """Parse ``path`` as UTF-8 JSON, independent of the platform default encoding."""
        return json.loads(path.read_text(encoding="utf-8"))

    def _load(self) -> None:
        """Read ``index.json`` and load every aggregate bundle it lists."""
        index_path = self._contracts_root / "index.json"
        if not index_path.is_file():
            raise RuntimeError(f"Internal API contract index is missing: {index_path}")

        index_payload = self._read_json(index_path)
        contracts = index_payload.get("contracts") if isinstance(index_payload, dict) else None
        if not isinstance(contracts, list):
            raise RuntimeError(f"Internal API contract index is invalid: {index_path}")

        aggregate_contracts = [
            contract for contract in contracts if isinstance(contract, dict) and contract.get("kind") == "aggregate"
        ]
        if not aggregate_contracts:
            raise RuntimeError(f"Internal API contract index has no aggregate bundle entries: {index_path}")

        for contract in aggregate_contracts:
            self._load_bundle(contract)

    def _load_bundle(self, contract: dict[str, Any]) -> None:
        """Load the schema bundle and operations manifest of one index entry."""
        version = str(contract.get("version") or "").strip()
        files = contract.get("files")
        if not isinstance(files, dict):
            files = {}
        schema_path = self._contracts_root / str(files.get("schema") or "")
        operations_path = self._contracts_root / str(files.get("operations") or "")
        # is_file() rather than exists(): a missing file name resolves to the
        # contracts root itself, which exists but is a directory.
        if not version or not schema_path.is_file() or not operations_path.is_file():
            raise RuntimeError(
                f"Internal API aggregate contract bundle is incomplete for version {version or '<unknown>'}: {contract}"
            )

        schema_bundle = self._read_json(schema_path)
        if not isinstance(schema_bundle, dict):
            raise RuntimeError(f"Internal API schema bundle is invalid: {schema_path}")

        operations_manifest = self._read_json(operations_path)
        operations = operations_manifest.get("operations") if isinstance(operations_manifest, dict) else None
        if not isinstance(operations, list):
            raise RuntimeError(f"Internal API operations manifest is invalid: {operations_path}")

        self._schema_bundles[version] = schema_bundle
        for operation in operations:
            if not isinstance(operation, dict):
                continue
            self._register_operation(version, operation)

    def _register_operation(self, version: str, operation: dict[str, Any]) -> None:
        """Convert one manifest entry into an ``OperationContract`` and index it."""
        method = str(operation.get("method") or "").upper()
        path = str(operation.get("path") or "")
        operation_id = str(operation.get("operation_id") or "").strip()

        # Malformed (non-object) request_body/content values are treated as absent.
        request_body = operation.get("request_body")
        if not isinstance(request_body, dict):
            request_body = {}
        content = request_body.get("content")
        if not isinstance(content, dict):
            content = {}

        request_content_types = tuple(sorted(str(media_type) for media_type in content))
        application_json = content.get("application/json")
        schema_ref = application_json.get("schema_ref") if isinstance(application_json, dict) else None

        self._operations[(method, path)] = OperationContract(
            version=version,
            method=method,
            path=path,
            operation_id=operation_id,
            request_body_required=bool(request_body.get("required", False)),
            request_body_schema_ref=str(schema_ref) if schema_ref else None,
            request_content_types=request_content_types,
        )
@@ -0,0 +1,76 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import time
6
+ from typing import Any
7
+
8
# LogRecord attribute names that must never be treated as custom context.
# "asctime" and "message" are only set once a formatter has run, but the
# logging module still refuses them in `extra`, so they are listed as well.
_STD_KEYS = frozenset(
    {
        "name",
        "msg",
        "args",
        "created",
        "relativeCreated",
        "levelname",
        "levelno",
        "pathname",
        "filename",
        "module",
        "funcName",
        "lineno",
        "thread",
        "threadName",
        "processName",
        "process",
        "msecs",
        "exc_info",
        "exc_text",
        "stack_info",
        "taskName",
        "message",
        "asctime",
    }
)
34
+
35
+
36
class _JsonFormatter(logging.Formatter):
    """Formatter that renders each log record as a single JSON object."""

    def format(self, record: logging.LogRecord) -> str:
        """Serialize ``record`` to JSON.

        The output has ``ts`` (UTC, second precision), ``level``, ``logger``
        and ``msg``, followed by every record attribute that is neither
        underscore-prefixed nor listed in ``_STD_KEYS``. An ``exception`` entry
        is added when the record carries exception info. Values that are not
        JSON-serializable are converted with ``str``.
        """
        timestamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(record.created))
        extras = {
            key: value
            for key, value in record.__dict__.items()
            if not (key.startswith("_") or key in _STD_KEYS)
        }
        payload: dict[str, Any] = {
            "ts": timestamp,
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
            **extras,
        }
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        return json.dumps(payload, default=str)
51
+
52
+
53
def configure_logging(level: str = "INFO") -> None:
    """Route root-logger output through one JSON-formatting stream handler.

    Existing root handlers are removed. ``level`` is looked up by upper-cased
    name on the ``logging`` module and falls back to ``INFO`` when no such
    attribute exists.
    """
    json_handler = logging.StreamHandler()
    json_handler.setFormatter(_JsonFormatter())

    resolved_level = getattr(logging, level.upper(), logging.INFO)

    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.addHandler(json_handler)
    root_logger.setLevel(resolved_level)
61
+
62
+
63
def log_event(logger: logging.Logger, event: str, level: str = "info", **context: Any) -> None:
    """Log ``event`` with structured context attached as LogRecord extras.

    ``event`` is used as the log message and is also attached under the
    ``event`` key. ``level`` names the logger method to call (``info``,
    ``warning``, ...), case-insensitively. Context keys that collide with
    reserved LogRecord attributes are renamed to ``ctx_<key>``.
    """
    safe_extra: dict[str, Any] = {}
    for key, value in {"event": event, **context}.items():
        # Never allow callers to overwrite reserved LogRecord attributes.
        # Python's logging raises KeyError (and can crash workers) if `extra`
        # contains a standard LogRecord key. "asctime" is checked explicitly
        # because logging rejects it alongside "message".
        reserved = key in _STD_KEYS or key == "asctime"
        safe_extra[f"ctx_{key}" if reserved else key] = value

    getattr(logger, level.lower())(event, extra=safe_extra)
@@ -0,0 +1,99 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+
9
# Jar directory used when DQ_SPARK_JAR_DIR is unset or empty.
DEFAULT_SPARK_JAR_DIR = Path.home().joinpath(".dq-spark-jars")

# Artifact names checked for more than one version in the jar directory.
DIRECT_SPARK_PACKAGE_ARTIFACTS = (
    "spark-avro_2.13",
    "hadoop-aws",
    "delta-spark_2.13",
    "delta-storage",
    "iceberg-spark-runtime-4.0_2.13",
)
17
+
18
+
19
def _artifact_versions(jar_paths: list[Path], artifact_name: str) -> dict[str, list[str]]:
    """Group jar file names by the version found for ``artifact_name``.

    A jar matches when its file name is ``<artifact_name>-<version>.jar``,
    optionally preceded by a ``<prefix>_`` segment. Returns a mapping of
    version string to the matching file names.
    """
    versions: dict[str, list[str]] = {}
    # The version must start with a digit. Otherwise a longer artifact name
    # sharing the prefix (for example "delta-storage-s3-dynamodb-4.0.0.jar"
    # for "delta-storage") would be reported as a second version.
    pattern = re.compile(rf"(?:^|_){re.escape(artifact_name)}-(?P<version>\d[^/]*)\.jar$")
    for path in jar_paths:
        match = pattern.search(path.name)
        if match is None:
            continue
        versions.setdefault(match.group("version"), []).append(path.name)
    return versions
28
+
29
+
30
def _reject_duplicate_direct_artifacts(jar_paths: list[Path]) -> None:
    """Abort when a direct Spark package artifact appears in several versions.

    Raises:
        SystemExit: listing every artifact from
            ``DIRECT_SPARK_PACKAGE_ARTIFACTS`` found with two or more versions.
    """
    conflicts: list[str] = []
    for artifact_name in DIRECT_SPARK_PACKAGE_ARTIFACTS:
        found = _artifact_versions(jar_paths, artifact_name)
        if len(found) >= 2:
            described = [f"{version} ({', '.join(names)})" for version, names in sorted(found.items())]
            conflicts.append(f"{artifact_name}: " + ", ".join(described))

    if not conflicts:
        return

    raise SystemExit(
        "Conflicting Spark package jar versions found in the shared Spark jar directory: "
        + "; ".join(conflicts)
        + ". Re-run dq-engine-warmup or clear the spark-jars volume so only the canonical package versions remain."
    )
45
+
46
+
47
def spark_jar_paths() -> list[Path]:
    """Return the sorted ``*.jar`` files to hand to Spark.

    The directory comes from ``DQ_SPARK_JAR_DIR`` (default
    ``DEFAULT_SPARK_JAR_DIR``). Jars larger than ``DQ_SPARK_MAX_JAR_SIZE_MB``
    (default 200; unparsable values fall back to 200) are dropped unless
    ``DQ_SPARK_INCLUDE_LARGE_JARS`` is ``1``, ``true`` or ``yes``. A warning
    naming up to ten dropped jars is printed.

    Raises:
        SystemExit: if the directory is missing, holds no jars, nothing
            survives the size filter, or conflicting artifact versions remain.
    """
    jar_dir = Path(os.getenv("DQ_SPARK_JAR_DIR") or DEFAULT_SPARK_JAR_DIR)
    if not jar_dir.is_dir():
        raise SystemExit(
            f"Spark jar directory not found: {jar_dir}. The dq-engine image must bake the required Spark jars during the build phase."
        )

    all_jars = sorted(candidate for candidate in jar_dir.glob("*.jar") if candidate.is_file())
    if not all_jars:
        raise SystemExit(
            f"No Spark jars were found in {jar_dir}. The dq-engine image must copy the build-time Spark cache into that directory."
        )

    raw_limit = os.getenv("DQ_SPARK_MAX_JAR_SIZE_MB")
    try:
        max_mb = int(raw_limit) if raw_limit else 200
    except Exception:
        max_mb = 200

    include_large = os.getenv("DQ_SPARK_INCLUDE_LARGE_JARS", "").strip().lower() in ("1", "true", "yes")

    def _size_in_mb(jar: Path) -> float:
        # A jar whose size cannot be read counts as 0 MB and is therefore kept.
        try:
            return jar.stat().st_size / (1024 * 1024)
        except Exception:
            return 0.0

    filtered: list[Path] = []
    excluded: list[tuple[str, float]] = []
    for jar in all_jars:
        size_mb = _size_in_mb(jar)
        if include_large or size_mb <= max_mb:
            filtered.append(jar)
        else:
            excluded.append((jar.name, size_mb))

    if not filtered:
        raise SystemExit(
            f"No Spark jars remain after applying size filter (max {max_mb}MB)."
            " Set DQ_SPARK_INCLUDE_LARGE_JARS=1 to include large jars or increase DQ_SPARK_MAX_JAR_SIZE_MB."
        )

    _reject_duplicate_direct_artifacts(filtered)

    if excluded:
        names = ", ".join(name for name, _ in excluded[:10])
        print(
            f"warning: excluded {len(excluded)} large jar(s) >{max_mb}MB: {names}{'...' if len(excluded)>10 else ''}"
        )

    return filtered
95
+
96
+
97
def configure_spark_builder_with_local_jars(builder: Any) -> Any:
    """Set ``spark.jars`` on ``builder`` to the comma-joined local jar paths.

    Returns whatever ``builder.config`` returns.
    """
    joined_paths = ",".join(map(str, spark_jar_paths()))
    return builder.config("spark.jars", joined_paths)
@@ -0,0 +1,66 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Any
5
+
6
+
7
# Fallback Spark master used by resolve_spark_master().
DEFAULT_SPARK_MASTER = "local[*]"
# Fallback Spark UI port used when DQ_SPARK_UI_PORT is unset.
DEFAULT_SPARK_UI_PORT = 4044
# Default session time zone name.
DEFAULT_SPARK_SESSION_TIMEZONE = "UTC"
10
+
11
+
12
def resolve_spark_master(default: str = DEFAULT_SPARK_MASTER) -> str:
    """Return the stripped ``DQ_SPARK_MASTER`` value, or ``default`` when it is unset or blank."""
    candidate = os.getenv("DQ_SPARK_MASTER") or default
    cleaned = str(candidate).strip()
    return cleaned if cleaned else default
14
+
15
+
16
def resolve_spark_ui_port(raw_value: str | int | None = None) -> int:
    """Parse and validate the Spark UI port.

    When ``raw_value`` is ``None`` the port is read from ``DQ_SPARK_UI_PORT``,
    falling back to ``DEFAULT_SPARK_UI_PORT``. Surrounding whitespace is
    ignored.

    Raises:
        ValueError: if the value is not an integer in the TCP port range
            1-65535.
    """
    if raw_value is None:
        raw_value = os.getenv("DQ_SPARK_UI_PORT") or str(DEFAULT_SPARK_UI_PORT)
    normalized = str(raw_value).strip()
    try:
        parsed = int(normalized)
    except (TypeError, ValueError) as exc:
        raise ValueError("DQ_SPARK_UI_PORT must be a positive integer") from exc
    if parsed < 1:
        raise ValueError("DQ_SPARK_UI_PORT must be a positive integer")
    # Fail fast here: a number above 65535 is not a TCP port.
    if parsed > 65535:
        raise ValueError("DQ_SPARK_UI_PORT must be a valid TCP port (1-65535)")
    return parsed
27
+
28
+
29
def configure_spark_builder(
    builder: Any,
    *,
    spark_ui_port: str | int | None = None,
    session_timezone: str | None = None,
) -> Any:
    """Apply UI port, optional time zone and optional memory settings to ``builder``.

    Driver and executor memory come from the environment. The ``DQ_``-prefixed
    variable is consulted first, then the Spark-standard name; a setting is
    skipped when neither is set. Returns the result of the last
    ``.config(...)`` call.
    """
    ui_port = resolve_spark_ui_port(spark_ui_port)
    configured = builder.config("spark.ui.port", str(ui_port))
    if session_timezone:
        configured = configured.config("spark.sql.session.timeZone", str(session_timezone))

    # (Spark conf key, environment variables in priority order)
    memory_settings = (
        ("spark.driver.memory", ("DQ_SPARK_DRIVER_MEMORY", "SPARK_DRIVER_MEMORY")),
        ("spark.executor.memory", ("DQ_SPARK_EXECUTOR_MEMORY", "SPARK_EXECUTOR_MEMORY")),
    )
    for conf_key, env_names in memory_settings:
        value = next((found for found in map(os.getenv, env_names) if found), None)
        if value:
            configured = configured.config(conf_key, str(value))

    return configured
49
+
50
+
51
def build_spark_session_builder(
    *,
    SparkSession: Any,
    app_name: str,
    master: str | None = None,
    spark_ui_port: str | int | None = None,
    session_timezone: str | None = None,
) -> Any:
    """Create a session builder named ``app_name`` and apply the shared settings.

    ``SparkSession`` is passed in by the caller, so this module does not
    import it. ``master`` is applied only when it is not ``None``; the rest is
    delegated to ``configure_spark_builder``.
    """
    named_builder = SparkSession.builder.appName(app_name)
    base_builder = named_builder if master is None else named_builder.master(master)
    return configure_spark_builder(
        base_builder,
        spark_ui_port=spark_ui_port,
        session_timezone=session_timezone,
    )
@@ -0,0 +1,143 @@
1
+ import os
2
+ import sys
3
+ import importlib.util
4
+ import types
5
+ import logging
6
+ import pytest
7
+
8
+
9
# Put the local source tree at the front of sys.path.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
SRC_DIR = os.path.join(ROOT_DIR, "dq-utils", "src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# The module is executed straight from its file so that the real package
# __init__ (and its side effects) never runs.
mod_path = os.path.join(SRC_DIR, "dq_utils", "auth_utils.py")

# Register a stand-in ``dq_utils`` package first: dataclasses and module-level
# references look the package up in sys.modules while the module loads.
pkg = types.ModuleType("dq_utils")
pkg.__path__ = [os.path.join(SRC_DIR, "dq_utils")]
sys.modules["dq_utils"] = pkg

spec = importlib.util.spec_from_file_location("dq_utils.auth_utils", mod_path)
auth_utils = importlib.util.module_from_spec(spec)
# The module must be registered under its intended name before it executes,
# otherwise decorators (dataclasses) cannot resolve it during class creation.
sys.modules[spec.name] = auth_utils
assert spec.loader is not None
spec.loader.exec_module(auth_utils)
30
+
31
+
32
def test_static_token_provider_accepts_and_returns_token():
    """An empty token is rejected; a padded token is returned stripped."""
    with pytest.raises(auth_utils.AuthConfigError):
        auth_utils.StaticTokenProvider("")

    provider = auth_utils.StaticTokenProvider(" secret ")
    returned = provider.get_token(correlation_id="cid")
    assert returned == "secret"
38
+
39
+
40
def test_resolve_oidc_token_url_behaviour():
    """The issuer is expanded, an explicit URL is kept, and nothing yields None."""
    resolve = auth_utils.resolve_oidc_token_url

    from_issuer = resolve(issuer="https://issuer", token_url=None)
    assert from_issuer == "https://issuer/protocol/openid-connect/token"

    explicit = resolve(issuer=None, token_url="https://t")
    assert explicit == "https://t"

    assert resolve(issuer=None, token_url=None) is None
50
+
51
+
52
class DummyResponse:
    """Minimal stand-in for an HTTP response: a status code plus ``json()``."""

    def __init__(self, status_code=200, payload=None, json_raises=False):
        self.status_code = status_code
        self._payload = payload or {}
        self._json_raises = json_raises

    def json(self):
        """Return the payload, or raise ``ValueError`` when configured to fail."""
        if not self._json_raises:
            return self._payload
        raise ValueError("not json")
62
+
63
+
64
def test_oidc_client_credentials_get_token_success_and_errors(monkeypatch):
    """Cover the success path and each failure mode of the client-credentials provider.

    The success path also checks the exact request sent to the token endpoint;
    previously the request was captured but never asserted.
    """
    calls = {}

    def fake_post_success(url, data=None, headers=None, timeout=None):
        calls['last'] = dict(url=url, data=data, headers=headers, timeout=timeout)
        return DummyResponse(status_code=200, payload={"access_token": "abc", "expires_in": 3600})

    provider = auth_utils.OidcClientCredentialsTokenProvider(
        token_url="https://tok",
        client_id="cid",
        client_secret="cs",
        scope=None,
        refresh_skew_seconds=60,
        timeout_seconds=1,
    )

    # network success
    monkeypatch.setattr(auth_utils.requests, "post", fake_post_success)
    token = provider.get_token(correlation_id="cid")
    assert token == "abc"
    # No scope was configured, so the form must not contain a "scope" entry.
    assert calls["last"] == {
        "url": "https://tok",
        "data": {"grant_type": "client_credentials", "client_id": "cid", "client_secret": "cs"},
        "headers": {"X-Correlation-ID": "cid"},
        "timeout": 1,
    }

    # Clear cache so subsequent calls actually invoke the token endpoint
    provider._cached = None

    # response with error status
    def fake_post_400(*a, **k):
        return DummyResponse(status_code=400, payload={})

    monkeypatch.setattr(auth_utils.requests, "post", fake_post_400)
    with pytest.raises(auth_utils.AuthConfigError):
        provider.get_token(correlation_id="cid")

    # response with non-json
    # Clear cache again for next scenario
    provider._cached = None

    def fake_post_nonjson(*a, **k):
        return DummyResponse(status_code=200, json_raises=True)

    monkeypatch.setattr(auth_utils.requests, "post", fake_post_nonjson)
    with pytest.raises(auth_utils.AuthConfigError):
        provider.get_token(correlation_id="cid")

    # network exception
    def fake_post_exc(*a, **k):
        raise RuntimeError("boom")

    monkeypatch.setattr(auth_utils.requests, "post", fake_post_exc)
    with pytest.raises(auth_utils.AuthConfigError):
        provider.get_token(correlation_id="cid")
114
+
115
+
116
def test_build_token_provider_from_env_prefers_static(monkeypatch):
    """A configured static token yields a StaticTokenProvider."""
    monkeypatch.setenv("MY_STATIC_TOKEN", "s1")

    env_names = dict(
        static_token_env_var="MY_STATIC_TOKEN",
        issuer_env_var="ISS",
        token_url_env_var="T",
        client_id_env_var="CID",
        client_secret_env_var="CS",
        scope_env_var="S",
    )
    provider = auth_utils.build_token_provider_from_env(**env_names)

    assert isinstance(provider, auth_utils.StaticTokenProvider)
    assert provider.get_token(correlation_id="c") == "s1"
128
+
129
+
130
def test_build_oidc_token_provider_from_env_raises_when_missing(monkeypatch):
    """With no OIDC variables set the builder raises AuthConfigError."""
    for env_name in ("ISS", "T", "CID", "CS"):
        monkeypatch.delenv(env_name, raising=False)

    env_names = dict(
        issuer_env_var="ISS",
        token_url_env_var="T",
        client_id_env_var="CID",
        client_secret_env_var="CS",
        scope_env_var="S",
    )
    with pytest.raises(auth_utils.AuthConfigError):
        auth_utils.build_oidc_token_provider_from_env(**env_names)
@@ -0,0 +1,83 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import logging
5
+
6
+
7
# Put the local source tree at the front of sys.path so this file can also be
# run directly.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
SRC_DIR = os.path.join(ROOT_DIR, "dq-utils", "src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

import importlib.util

# Execute the module from its file path instead of importing the package:
# dq_utils.__init__ pulls heavy optional deps during import.
mod_path = os.path.join(SRC_DIR, "dq_utils", "logging_utils.py")
spec = importlib.util.spec_from_file_location("dq_utils_logging_utils", mod_path)
logging_utils = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(logging_utils)

# Short aliases for the objects under test.
_JsonFormatter = logging_utils._JsonFormatter
configure_logging = logging_utils.configure_logging
log_event = logging_utils.log_event
26
+
27
+
28
def test_json_formatter_includes_custom_fields():
    """Non-standard record attributes appear in the JSON next to the core fields."""
    record = logging.LogRecord(
        name="mylogger",
        level=logging.INFO,
        pathname=__file__,
        lineno=10,
        msg="hello",
        args=(),
        exc_info=None,
    )
    # add a non-standard attribute which should be included in the JSON
    record.__dict__["custom_key"] = "custom_value"

    data = json.loads(_JsonFormatter().format(record))

    assert data["logger"] == "mylogger"
    assert data["msg"] == "hello"
    assert data["custom_key"] == "custom_value"
    assert "ts" in data and "level" in data
49
+
50
+
51
def test_configure_logging_sets_handler_and_level():
    """configure_logging installs a StreamHandler and sets the root level.

    The root logger's handlers and level are restored afterwards so this
    test's global logging changes do not leak into other tests.
    """
    root = logging.getLogger()
    saved_handlers = list(root.handlers)
    saved_level = root.level
    try:
        configure_logging("WARNING")
        assert any(isinstance(h, logging.StreamHandler) for h in root.handlers)
        assert root.level == logging.WARNING
    finally:
        root.handlers[:] = saved_handlers
        root.setLevel(saved_level)
57
+
58
+
59
def test_log_event_safe_extra_and_reserved_prefix():
    """Reserved context keys are renamed to ``ctx_<key>``; other keys pass through."""
    captured = []

    class ListHandler(logging.Handler):
        def emit(self, rec: logging.LogRecord) -> None:  # type: ignore[override]
            # keep a shallow copy of the record dict for the assertions below
            captured.append(dict(rec.__dict__))

    logger = logging.getLogger("test_logger_for_log_event")
    # start from a clean handler set on the logger used in this test
    logger.handlers.clear()
    logger.addHandler(ListHandler())
    logger.setLevel(logging.DEBUG)

    # 'message' is reserved by logging, 'user' is an ordinary key
    log_event(logger, "evt", level="info", message="danger", user="alice")

    assert captured, "expected a log record to be emitted"
    last_record = captured[-1]

    # the reserved key must arrive prefixed, not overwrite LogRecord internals
    assert last_record.get("ctx_message") == "danger"
    assert last_record.get("user") == "alice"
    assert last_record.get("msg") == "evt"
+ assert rec.get("msg") == "evt"