pytest-dag 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytest_dag/__init__.py +3 -0
- pytest_dag/_build_config.py +4 -0
- pytest_dag/_pytest_dag_core.py +683 -0
- pytest_dag/config.py +136 -0
- pytest_dag/graph.py +349 -0
- pytest_dag/migrate.py +584 -0
- pytest_dag/plugin.py +525 -0
- pytest_dag-3.0.0.dist-info/METADATA +224 -0
- pytest_dag-3.0.0.dist-info/RECORD +13 -0
- pytest_dag-3.0.0.dist-info/WHEEL +5 -0
- pytest_dag-3.0.0.dist-info/entry_points.txt +2 -0
- pytest_dag-3.0.0.dist-info/licenses/LICENSE +44 -0
- pytest_dag-3.0.0.dist-info/top_level.txt +1 -0
pytest_dag/__init__.py
ADDED
|
@@ -0,0 +1,683 @@
|
|
|
1
|
+
"""Pure-Python core for pytest-dag.
|
|
2
|
+
|
|
3
|
+
This module is a drop-in replacement for the previous Rust extension module
|
|
4
|
+
`pytest_dag._pytest_dag_core`.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import base64
|
|
10
|
+
import importlib
|
|
11
|
+
import ipaddress
|
|
12
|
+
import json
|
|
13
|
+
import os
|
|
14
|
+
import secrets
|
|
15
|
+
import urllib.error
|
|
16
|
+
import urllib.request
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
from urllib.parse import urlparse
|
|
21
|
+
|
|
22
|
+
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
|
|
23
|
+
from cryptography.hazmat.primitives.serialization import load_pem_public_key
|
|
24
|
+
|
|
25
|
+
from ._build_config import OFFICIAL_LICENSE_ENDPOINT, OFFICIAL_SIGNING_PUBLIC_KEY_PEM
|
|
26
|
+
from .graph import _detect_cycles, _load_yaml_deps
|
|
27
|
+
|
|
28
|
+
# Version tag of the plan/start client<->server protocol; the server must echo it.
PLAN_PROTOCOL_VERSION = "1"
# Tolerated clock drift (seconds) when checking response signature timestamps.
SIG_CLOCK_SKEW_SECS = 30
# Per-request socket timeout (seconds) for all license-server HTTP calls.
HTTP_TIMEOUT_SECS = 5
# Browser-like default User-Agent; can be overridden via PYTEST_DAG_USER_AGENT.
DEFAULT_USER_AGENT = (
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/133.0.0.0 Safari/537.36"
)
# Recognised test outcomes. NOTE(review): despite the name this is a set, not
# an ordering; DagSession.record() coerces any value outside it to "FAILED".
_OUTCOME_ORDER = {
    "PASSED",
    "FAILED",
    "SKIPPED",
    "DAG_SKIPPED",
    "XFAILED",
    "XPASSED",
    "ERROR",
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _debug(msg: str) -> None:
|
|
48
|
+
if os.getenv("PYTEST_DAG_DEBUG") == "1":
|
|
49
|
+
print(f"pytest-dag [DEBUG] {msg}", file=os.sys.stderr)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _license_error(message: str) -> RuntimeError:
|
|
53
|
+
return RuntimeError(message)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _internal_error(message: str) -> RuntimeError:
|
|
57
|
+
return RuntimeError(message)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _env_true(name: str) -> bool:
|
|
61
|
+
return os.getenv(name, "").lower() in {"1", "true", "yes", "on"}
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _normalize_endpoint(raw: str) -> str:
|
|
65
|
+
return raw.strip().rstrip("/")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _is_local_or_private_host(host: str) -> bool:
|
|
69
|
+
h = host.lower()
|
|
70
|
+
if h in {"localhost", "ip6-localhost"}:
|
|
71
|
+
return True
|
|
72
|
+
try:
|
|
73
|
+
ip = ipaddress.ip_address(h)
|
|
74
|
+
except ValueError:
|
|
75
|
+
return False
|
|
76
|
+
if isinstance(ip, ipaddress.IPv4Address):
|
|
77
|
+
return ip.is_loopback or ip.is_private or ip.is_link_local
|
|
78
|
+
return ip.is_loopback
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _resolve_endpoint() -> str:
    """Return the validated, normalised build-time license endpoint URL.

    Raises RuntimeError when the build config is unset, when the URL is not
    HTTPS, or when it has no host component.
    """
    configured = OFFICIAL_LICENSE_ENDPOINT.strip()
    if configured.startswith("__UNSET_") or not configured:
        raise _internal_error(
            "Official license endpoint is not configured in this build. "
            "Set repository environment and regenerate build config."
        )
    endpoint = _normalize_endpoint(configured)
    parts = urlparse(endpoint)
    if parts.scheme != "https":
        raise _license_error("Official license endpoint must use HTTPS")
    if not parts.hostname:
        raise _license_error("Invalid default license endpoint URL: missing host")
    return endpoint
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _builtin_public_key() -> Ed25519PublicKey:
    """Load and return the Ed25519 key baked into this build's config.

    Raises RuntimeError when the key is unset, unparsable, or not Ed25519.
    """
    pem_text = OFFICIAL_SIGNING_PUBLIC_KEY_PEM.strip()
    if pem_text.startswith("__UNSET_") or not pem_text:
        raise _internal_error(
            "Official signing public key is not configured in this build. "
            "Set repository environment and regenerate build config."
        )
    try:
        loaded = load_pem_public_key(pem_text.encode("utf-8"))
    except Exception as exc:
        raise _internal_error(f"Failed to load built-in public key: {exc}") from None
    if isinstance(loaded, Ed25519PublicKey):
        return loaded
    raise _internal_error("Failed to load built-in public key: not Ed25519")
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _load_pem_public_key(pem: str) -> Ed25519PublicKey:
    """Parse a caller-supplied PEM public key, requiring Ed25519.

    Raises RuntimeError when the PEM is unparsable or not an Ed25519 key.
    """
    try:
        loaded = load_pem_public_key(pem.encode("utf-8"))
    except Exception as exc:
        raise _internal_error(f"Failed to load custom public key: {exc}") from None
    if isinstance(loaded, Ed25519PublicKey):
        return loaded
    raise _internal_error("Failed to load custom public key: not Ed25519")
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _verifying_key_for_endpoint(endpoint: str) -> Ed25519PublicKey:
    """Return the signature-verification key for *endpoint*.

    Every endpoint currently verifies against the single built-in key; the
    parameter exists only for interface stability.
    """
    del endpoint  # deliberately unused for now
    return _builtin_public_key()
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def _generate_nonce() -> str:
|
|
129
|
+
return base64.urlsafe_b64encode(secrets.token_bytes(24)).rstrip(b"=").decode("ascii")
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _canonical_json(data: dict[str, Any]) -> bytes:
|
|
133
|
+
try:
|
|
134
|
+
return json.dumps(
|
|
135
|
+
{k: data[k] for k in sorted(data)},
|
|
136
|
+
separators=(",", ":"),
|
|
137
|
+
ensure_ascii=False,
|
|
138
|
+
).encode("utf-8")
|
|
139
|
+
except Exception as exc:
|
|
140
|
+
raise _internal_error(f"canonical_json serialisation failed: {exc}") from None
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _b64url_decode(value: str) -> bytes:
|
|
144
|
+
stripped = value.strip().rstrip("=")
|
|
145
|
+
padded = stripped + ("=" * ((4 - len(stripped) % 4) % 4))
|
|
146
|
+
try:
|
|
147
|
+
return base64.urlsafe_b64decode(padded)
|
|
148
|
+
except Exception as exc:
|
|
149
|
+
raise _license_error(f"Base64 decode failed: {exc}") from None
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _verify_ed25519(
    verifying_key: Ed25519PublicKey, message: bytes, signature_bytes: bytes
) -> None:
    """Verify *signature_bytes* over *message* with the given Ed25519 key.

    Raises:
        RuntimeError: when the signature has the wrong length or fails to
            verify. The crypto-level error detail is deliberately suppressed
            (``from None``) so nothing sensitive leaks into the message.
    """
    # Ed25519 signatures are always exactly 64 bytes.
    if len(signature_bytes) != 64:
        raise _license_error("Invalid Ed25519 signature length")
    try:
        verifying_key.verify(signature_bytes, message)
    except Exception:
        raise _license_error("Ed25519 signature verification failed") from None
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _user_agent() -> str:
    """Return the HTTP User-Agent, honouring the PYTEST_DAG_USER_AGENT override."""
    override = os.getenv("PYTEST_DAG_USER_AGENT")
    return DEFAULT_USER_AGENT if override is None else override
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _http_json(
    method: str,
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any] | None = None,
    timeout_secs: int = HTTP_TIMEOUT_SECS,
) -> dict[str, Any]:
    """Issue an HTTP request and return the response body parsed as a JSON object.

    Args:
        method: HTTP verb ("GET" or "POST" in this module).
        url: Absolute request URL.
        headers: Extra headers; they override the browser-like defaults below.
        payload: Optional JSON request body (also sets Content-Type).
        timeout_secs: Per-request socket timeout in seconds.

    Raises:
        RuntimeError: when the server is unreachable, the body is not valid
            JSON, or the top-level JSON value is not an object.
    """
    # Browser-like default headers; caller-provided entries win on conflict.
    req_headers = {
        "User-Agent": _user_agent(),
        "Accept": "application/json, text/plain, */*",
        "Accept-Language": "en-US,en;q=0.9",
        "Cache-Control": "no-cache",
        "Pragma": "no-cache",
    }
    req_headers.update(headers)

    data: bytes | None = None
    if payload is not None:
        req_headers["Content-Type"] = "application/json"
        data = json.dumps(payload).encode("utf-8")

    req = urllib.request.Request(url=url, data=data, headers=req_headers, method=method)
    try:
        with urllib.request.urlopen(req, timeout=timeout_secs) as resp:
            body = resp.read().decode("utf-8")
    except urllib.error.URLError as exc:
        # Distinct user-facing messages for the probe (GET) vs validation (POST).
        if method == "GET":
            raise _license_error(f"pytest-dag: License server unreachable. ({exc})")
        raise _license_error(
            f"pytest-dag: License server unreachable during validation. ({exc})"
        )
    except Exception as exc:
        raise _license_error(f"HTTP {method} failed: {exc}") from None

    _debug(f"HTTP {method} response body (first 300): {body[:300]}")
    try:
        value = json.loads(body)
    except Exception as exc:
        raise _license_error(f"HTTP {method} JSON parse failed: {exc}") from None
    # Top-level arrays/scalars are rejected: callers index by string keys.
    if not isinstance(value, dict):
        raise _license_error("License response is not a JSON object")
    return value
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _required_str(data: dict[str, Any], key: str) -> str:
|
|
212
|
+
v = data.get(key)
|
|
213
|
+
if isinstance(v, str) and v:
|
|
214
|
+
return v
|
|
215
|
+
raise _license_error(f"Missing or invalid signed field: {key}")
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _required_i64(data: dict[str, Any], key: str) -> int:
|
|
219
|
+
v = data.get(key)
|
|
220
|
+
if isinstance(v, bool):
|
|
221
|
+
raise _license_error(f"Missing or invalid signed field: {key}")
|
|
222
|
+
if isinstance(v, int):
|
|
223
|
+
return v
|
|
224
|
+
raise _license_error(f"Missing or invalid signed field: {key}")
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _build_signed_payload(ctx: str, data: dict[str, Any]) -> dict[str, Any]:
    """Rebuild the exact field set the server signed for context *ctx*.

    The result is canonicalised by the caller and checked against the
    response's Ed25519 signature, so field presence and types are validated
    strictly here.

    Args:
        ctx: Signature context: "feature-flag", "validate" or "plan-start".
        data: Raw (untrusted) JSON response from the license server.

    Raises:
        RuntimeError: on a missing/invalid signed field or an unknown context.
    """
    # Envelope fields common to every signed response.
    payload: dict[str, Any] = {
        "ctx": ctx,
        "sig_v": _required_i64(data, "sig_v"),
        "sig_alg": _required_str(data, "sig_alg"),
        "sig_kid": _required_str(data, "sig_kid"),
        "sig_ts": _required_i64(data, "sig_ts"),
        "sig_exp": _required_i64(data, "sig_exp"),
        "sig_nonce": _required_str(data, "sig_nonce"),
    }
    if ctx == "feature-flag":
        value = data.get("paywall_enabled")
        if not isinstance(value, bool):
            raise _license_error("Missing or invalid signed field: paywall_enabled")
        payload["paywall_enabled"] = value
    elif ctx == "validate":
        value = data.get("valid")
        if not isinstance(value, bool):
            raise _license_error("Missing or invalid signed field: valid")
        payload["valid"] = value
        # Optional in the response, but still covered by the signature.
        payload["reason"] = data.get("reason")
        payload["expiry"] = data.get("expiry")
        payload["user_count"] = data.get("user_count")
        payload["license_type"] = data.get("license_type")
    elif ctx == "plan-start":
        payload["plan_session_id"] = _required_str(data, "plan_session_id")
        order = data.get("execution_order")
        if not isinstance(order, list) or not all(isinstance(v, str) for v in order):
            raise _license_error("Missing or invalid signed field: execution_order")
        payload["execution_order"] = order
        payload["protocol_version"] = _required_str(data, "protocol_version")
        payload["server_version"] = _required_str(data, "server_version")
    else:
        raise _internal_error(f"Unknown signature context: {ctx}")
    return payload
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _now_unix() -> int:
|
|
265
|
+
import time
|
|
266
|
+
|
|
267
|
+
return int(time.time())
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _verify_signed_response(
    ctx: str, data: dict[str, Any], expected_nonce: str, endpoint: str
) -> None:
    """Authenticate a signed license-server response.

    Checks the signature envelope (algorithm, version, nonce echo, freshness)
    and then verifies the Ed25519 signature over the canonicalised payload.

    Args:
        ctx: Signature context ("feature-flag", "validate" or "plan-start").
        data: Raw JSON response body.
        expected_nonce: Nonce sent with the request; must be echoed back.
        endpoint: Endpoint URL, used to select the verification key.

    Raises:
        RuntimeError: on any envelope or signature failure.
    """
    sig = _required_str(data, "sig")
    sig_alg = _required_str(data, "sig_alg")
    sig_v = _required_i64(data, "sig_v")
    sig_nonce = _required_str(data, "sig_nonce")
    sig_ts = _required_i64(data, "sig_ts")
    sig_exp = _required_i64(data, "sig_exp")

    if sig_alg != "Ed25519":
        raise _license_error("Unsupported license response signature algorithm")
    if sig_v != 1:
        raise _license_error("Unsupported license response signature version")
    # The echoed nonce ties this response to our request (anti-replay).
    if sig_nonce != expected_nonce:
        raise _license_error("License response nonce mismatch")

    # Freshness window, with tolerance for modest clock skew on either side.
    now = _now_unix()
    if sig_ts > now + SIG_CLOCK_SKEW_SECS:
        raise _license_error("License response signature timestamp is in the future")
    if sig_exp < now - SIG_CLOCK_SKEW_SECS:
        raise _license_error("License response signature has expired")

    # Rebuild exactly what the server signed, canonicalise, and verify.
    payload = _build_signed_payload(ctx, data)
    message = _canonical_json(payload)
    signature = _b64url_decode(sig)
    key = _verifying_key_for_endpoint(endpoint)
    _verify_ed25519(key, message, signature)

    _debug(
        f"{ctx} response signature verified "
        f"(kid={data.get('sig_kid', '?')}, exp={sig_exp})"
    )
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
@dataclass
class DagSession:
    """In-process DAG state; mirrors the old Rust DagSession surface.

    Tracks declared dependencies (in both directions), recorded outcomes,
    and the resolved execution order for one pytest run.
    """

    _token: bytes  # opaque per-run token
    license_checked: bool  # whether license enforcement already ran
    paywall_was_active: bool  # True when the server required a license
    deps: dict[str, set[str]] = field(default_factory=dict)
    reverse_deps: dict[str, set[str]] = field(default_factory=dict)
    results: dict[str, str] = field(default_factory=dict)
    ordered: list[str] = field(default_factory=list)
    license_key: str | None = None
    endpoint: str = ""
    plan_session_id: str | None = None

    @property
    def test_count(self) -> int:
        """Number of tests in the computed execution order."""
        return len(self.ordered)

    def __repr__(self) -> str:
        return (
            f"<DagSession license_checked={self.license_checked} "
            f"paywall={self.paywall_was_active} tests={len(self.ordered)}>"
        )

    def add_dependency(self, dependent: str, dependency: str) -> None:
        """Register the edge *dependent* -> *dependency* in both maps."""
        self.deps.setdefault(dependent, set()).add(dependency)
        self.reverse_deps.setdefault(dependency, set()).add(dependent)
        # Ensure both node ids exist in both maps, even when edge-less.
        self.deps.setdefault(dependency, set())
        self.reverse_deps.setdefault(dependent, set())

    def blocking_deps(self, nodeid: str, block_on_str: str) -> list[str]:
        """Return dependencies of *nodeid* whose recorded outcome blocks it.

        *block_on_str* is a comma-separated, case-insensitive outcome list;
        "DAG_SKIPPED" always blocks. Unrecorded dependencies never block.
        """
        blocking_set = {part.strip().upper() for part in block_on_str.split(",")}
        hits: list[str] = []
        for dep in self.deps.get(nodeid, set()):
            recorded = self.results.get(dep)
            if recorded is None:
                continue
            if recorded in blocking_set or recorded == "DAG_SKIPPED":
                hits.append(dep)
        return hits

    def record(self, nodeid: str, outcome: str, overwrite: bool) -> None:
        """Store *outcome* for *nodeid*; unknown outcomes collapse to FAILED."""
        normalised = outcome if outcome in _OUTCOME_ORDER else "FAILED"
        if overwrite or nodeid not in self.results:
            self.results[nodeid] = normalised
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def _enforce_license(license_key: str | None) -> bool:
    """Run the two-step license check against the official endpoint.

    Step 1 fetches the signed feature flag; when the paywall is disabled no
    key is needed. Step 2 validates *license_key* against the server.

    Returns:
        True when the paywall is active and the key validated; False when no
        license is required.

    Raises:
        RuntimeError: on unreachable server, unverifiable responses, missing
            key, or failed validation.
    """
    endpoint = _resolve_endpoint()
    _debug(f"license endpoint = {endpoint}")

    # Step 1: signed feature flag — is the paywall even enabled?
    ff_nonce = _generate_nonce()
    ff_url = f"{endpoint}/api/pytest-dag/v1/feature-flag"
    _debug(f"GET {ff_url}")
    flag = _http_json("GET", ff_url, {"X-Pytest-Dag-Nonce": ff_nonce})
    try:
        _verify_signed_response("feature-flag", flag, ff_nonce, endpoint)
    except RuntimeError as exc:
        raise _license_error(
            f"pytest-dag: License server returned an invalid response from {ff_url}.\n"
            "This usually means the endpoint is not a pytest-dag license server.\n"
            f"Detail: {exc}\n"
            "Hint: verify your backend deployment and build-time endpoint configuration.\n\n"
            "Purchase at:\n"
            " https://pytest-dag.slrsoft.ca/licenses/purchase\n"
            "Support: support@slrsoft.ca"
        ) from None

    paywall_enabled = bool(flag.get("paywall_enabled", False))
    _debug(f"paywall_enabled={paywall_enabled}")
    if not paywall_enabled:
        _debug("paywall_enabled=false -> no license required, proceeding")
        return False

    _debug("paywall_enabled=true -> license key required")
    if license_key is None:
        raise _license_error(
            "pytest-dag: License key not provided.\n"
            "Set PYTEST_DAG_LICENSE_KEY=pd-XXXX-XXXX-XXXX-XXXX or use\n"
            "--pytest-dag-license-key / --pytest-dag-license-key-file\n\n"
            "Purchase at:\n"
            " https://pytest-dag.slrsoft.ca/licenses/purchase\n"
            "Support: support@slrsoft.ca"
        )

    # Step 2: signed key validation.
    val_nonce = _generate_nonce()
    val_url = f"{endpoint}/api/pytest-dag/v1/validate"
    _debug(f"POST {val_url} (key={license_key[:8]}...)")
    result = _http_json(
        "POST",
        val_url,
        {"X-Pytest-Dag-Nonce": val_nonce},
        {"license_key": license_key},
    )
    try:
        _verify_signed_response("validate", result, val_nonce, endpoint)
    except RuntimeError as exc:
        raise _license_error(
            "pytest-dag: License server returned an invalid validation response from "
            f"{val_url}.\n"
            f"Detail: {exc}\n\n"
            "Hint: verify the deployed license service is returning Ed25519-signed "
            "responses for /api/pytest-dag/v1/validate.\n\n"
            "Purchase at:\n"
            " https://pytest-dag.slrsoft.ca/licenses/purchase\n"
            "Support: support@slrsoft.ca"
        ) from None

    valid = bool(result.get("valid", False))
    if not valid:
        reason = result.get("reason") or "unknown error"
        expiry = result.get("expiry") or ""
        msg = f"License validation failed.\n Reason: {reason}"
        if expiry:
            msg += f"\n Expiry: {expiry}"
        msg += "\n\nRenew at: https://pytest-dag.slrsoft.ca/licenses/purchase"
        raise _license_error(msg)

    _debug("license valid -> proceeding")
    return True
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _load_local_test_mode_adapter():
    """Import and return the local test-mode session factory.

    Only usable when the optional module
    'test_support.pytest_dag_local_test_mode' is importable; raises
    RuntimeError otherwise, or when it lacks a callable create_dag_session.
    """
    module_name = "test_support.pytest_dag_local_test_mode"
    try:
        adapter = importlib.import_module(module_name)
    except Exception as exc:
        raise _internal_error(
            "PYTEST_DAG_TEST_MODE is only supported for local test runs with a local "
            "test adapter module present. This build/runtime does not provide "
            "'test_support.pytest_dag_local_test_mode'. Disable PYTEST_DAG_TEST_MODE."
        ) from exc

    factory = getattr(adapter, "create_dag_session", None)
    if callable(factory):
        return factory
    raise _internal_error(
        "Invalid local test adapter: expected callable "
        "test_support.pytest_dag_local_test_mode.create_dag_session(...)"
    )
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
def enforce_and_init(license_key: str | None = None) -> DagSession:
    """Entry point: run license enforcement and return a ready DagSession.

    With PYTEST_DAG_TEST_MODE=1 the session is built by a local test adapter
    instead of contacting the license server.

    Raises:
        RuntimeError: when enforcement fails or the test adapter is missing.
    """
    # Opaque per-run token, mirroring the old Rust session's private token.
    token = secrets.token_bytes(32)
    if os.getenv("PYTEST_DAG_TEST_MODE") == "1":
        # The adapter receives everything it needs to build a DagSession
        # without network access; keyword names are part of its contract.
        create_session = _load_local_test_mode_adapter()
        return create_session(
            token=token,
            license_key=license_key,
            session_cls=DagSession,
            enforce_license=_enforce_license,
            resolve_endpoint=_resolve_endpoint,
        )
    paywall_active = _enforce_license(license_key)
    return DagSession(
        token,
        license_checked=True,
        paywall_was_active=paywall_active,
        license_key=license_key,
        endpoint=_resolve_endpoint(),
    )
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def _start_remote_plan(session: DagSession, items: list[dict[str, Any]]) -> list[str]:
    """Ask the license server to plan the execution order (paywall path).

    Sends the collected node ids and their in-set dependencies, verifies the
    signed response, stores the plan session id on *session*, and returns the
    server-provided execution order.

    Raises:
        RuntimeError: on transport/signature failure, an order that does not
            match the collected tests, or a protocol version mismatch.
    """
    nonce = _generate_nonce()
    nodes = [item["nodeid"] for item in items]
    collected = set(nodes)
    deps: dict[str, list[str]] = {}
    for node in nodes:
        # Only report dependencies on tests that were actually collected.
        deps[node] = sorted(d for d in session.deps.get(node, set()) if d in collected)

    payload = {
        "protocol_version": PLAN_PROTOCOL_VERSION,
        "client_version": get_version(),
        "license_key": session.license_key,
        "nodes": nodes,
        "deps": deps,
    }
    url = f"{session.endpoint}/api/pytest-dag/v1/plan/start"
    result = _http_json("POST", url, {"X-Pytest-Dag-Nonce": nonce}, payload)
    _verify_signed_response("plan-start", result, nonce, session.endpoint)
    session.plan_session_id = _required_str(result, "plan_session_id")

    order = result.get("execution_order")
    if not isinstance(order, list) or not all(isinstance(v, str) for v in order):
        raise _license_error("Invalid execution_order in plan/start response")
    # The server may reorder but must return exactly the collected set.
    if set(order) != set(nodes):
        raise _license_error("plan/start execution_order does not match collected tests")

    protocol = _required_str(result, "protocol_version")
    if protocol != PLAN_PROTOCOL_VERSION:
        raise _license_error(
            f"Protocol mismatch: server={protocol}, client={PLAN_PROTOCOL_VERSION}"
        )
    return order
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
def _report_remote_outcome(session: DagSession, nodeid: str, outcome: str) -> None:
    """POST a single test outcome to the server for the active plan session.

    No-op when no remote plan session is active.
    """
    if not session.plan_session_id:
        return
    report = {
        "plan_session_id": session.plan_session_id,
        "node_id": nodeid,
        "outcome": outcome,
    }
    url = f"{session.endpoint}/api/pytest-dag/v1/plan/report"
    _http_json("POST", url, {}, report)
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
def finalize_session(session: DagSession) -> None:
    """Tell the server the remote plan is finished and clear the local id.

    The local plan_session_id is cleared even when the HTTP call fails;
    transport errors still propagate to the caller as RuntimeError.
    """
    plan_id = session.plan_session_id
    if not plan_id:
        return
    try:
        _http_json(
            "POST",
            f"{session.endpoint}/api/pytest-dag/v1/plan/finish",
            {},
            {"plan_session_id": plan_id},
        )
    finally:
        session.plan_session_id = None
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def build_dag(
    session: DagSession,
    items: list[dict[str, Any]],
    dag_file: str | None = None,
    rootdir: str = ".",
    strict: bool = True,
) -> None:
    """Populate *session*'s dependency graph from markers and an optional YAML file.

    Args:
        session: Session whose deps/reverse_deps maps are filled in place.
        items: Collected tests; each dict has "nodeid" and optional "raw_deps".
        dag_file: Optional YAML dependency file, resolved against *rootdir*.
        rootdir: Base directory for resolving *dag_file*.
        strict: When True, unresolved dependencies raise; otherwise each one
            becomes a synthetic always-FAILED node so its dependents block.

    Raises:
        RuntimeError: on unresolved dependencies (strict mode) or a cycle.
    """
    nodeid_map = {item["nodeid"]: item["nodeid"] for item in items}
    # (module path, short test name) -> nodeids, for resolving unqualified deps.
    short_name_map: dict[tuple[str, str], list[str]] = {}
    for item in items:
        nodeid = item["nodeid"]
        module = nodeid.split("::")[0]
        short = nodeid.split("::")[-1]
        short_name_map.setdefault((module, short), []).append(nodeid)

    def _resolve(dep_str: str, context_nodeid: str) -> str | None:
        # Fully-qualified ids resolve directly; short names resolve only when
        # unambiguous within the referring test's module.
        if "::" in dep_str:
            return nodeid_map.get(dep_str)
        module = context_nodeid.split("::")[0]
        matches = short_name_map.get((module, dep_str), [])
        if len(matches) == 1:
            return matches[0]
        return None

    errors: list[str] = []

    def _resolve_add(dependent: str, dep_str: str, context_nodeid: str) -> None:
        # Resolve one dependency string and record the edge (or the failure).
        resolved = _resolve(dep_str, context_nodeid)
        if resolved is not None:
            session.add_dependency(dependent, resolved)
            return
        msg = (
            f"pytest-dag: dependency not found: {dep_str!r} "
            f"(required by {dependent!r})"
        )
        if strict:
            errors.append(msg)
        else:
            # Non-strict: synthesise a permanently-FAILED node so the
            # dependent is DAG-skipped instead of erroring the whole run.
            virtual = f"<missing:{dep_str}>"
            session.results[virtual] = "FAILED"
            session.add_dependency(dependent, virtual)

    # Marker-declared dependencies.
    for item in items:
        for dep_str in list(item.get("raw_deps", []) or []):
            _resolve_add(item["nodeid"], dep_str, item["nodeid"])

    # YAML-file dependencies; unknown dependents borrow the first collected
    # test's module as resolution context for short names.
    if dag_file:
        yaml_deps = _load_yaml_deps(dag_file, Path(rootdir))
        first_nodeid = items[0]["nodeid"] if items else ""
        for dependent, dep_list in yaml_deps.items():
            context = dependent if dependent in nodeid_map else first_nodeid
            for dep_str in dep_list:
                _resolve_add(dependent, dep_str, context)

    if errors:
        raise _license_error("\n".join(errors))

    try:
        _detect_cycles(_StateView(session))
    except Exception as exc:
        raise _license_error(str(exc)) from None
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def topo_sort(session: DagSession, items: list[dict[str, Any]]) -> list[str]:
    """Return a dependency-respecting execution order for *items*.

    Locally this is Kahn's algorithm driven by a min-heap keyed on original
    collection order, so independent tests keep their relative order. When
    the paywall is active, ordering is delegated to the license server.

    Raises:
        RuntimeError: when the dependency graph contains a cycle.
    """
    if not items:
        return []
    if session.paywall_was_active:
        # Server-driven planning path.
        return _start_remote_plan(session, items)

    import heapq

    position = {entry["nodeid"]: idx for idx, entry in enumerate(items)}
    node_set = set(position)
    pending = dict.fromkeys(node_set, 0)  # unmet-dependency counts
    dependents: dict[str, list[str]] = {nid: [] for nid in node_set}

    for entry in items:
        nid = entry["nodeid"]
        for dep in session.deps.get(nid, set()):
            # Dependencies outside the collected set are ignored here.
            if dep in node_set:
                pending[nid] += 1
                dependents[dep].append(nid)

    heap: list[tuple[int, str]] = [
        (position.get(nid, 10**9), nid)
        for nid, count in pending.items()
        if count == 0
    ]
    heapq.heapify(heap)

    order: list[str] = []
    while heap:
        _, nid = heapq.heappop(heap)
        order.append(nid)
        for follower in dependents.get(nid, []):
            pending[follower] -= 1
            if pending[follower] == 0:
                heapq.heappush(heap, (position.get(follower, 10**9), follower))

    if len(order) != len(items):
        raise _license_error(
            "pytest-dag: topological sort failed – remaining cycle in graph"
        )
    return order
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def check_block(session: DagSession, nodeid: str, block_on: str) -> tuple[bool, str]:
    """Return ``(blocked?, human-readable reason)`` for *nodeid*.

    A test is blocked when any of its dependencies finished with an outcome
    listed in *block_on* (comma-separated) or was itself DAG-skipped.
    """
    offenders = session.blocking_deps(nodeid, block_on)
    if not offenders:
        return False, ""
    details = ", ".join(
        f"{dep} ({session.results.get(dep, 'FAILED')})" for dep in offenders
    )
    return True, f"pytest-dag: blocked by {details}"
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
def record_outcome(
    session: DagSession, nodeid: str, outcome: str, overwrite: bool = False
) -> None:
    """Record *outcome* locally, then best-effort report it to the server.

    Reporting failures are surfaced only in debug output; they never fail
    the test run.
    """
    session.record(nodeid, outcome, overwrite)
    try:
        _report_remote_outcome(session, nodeid, outcome)
    except RuntimeError as exc:
        _debug(f"report outcome failed: {exc}")
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
def get_ordered(session: DagSession) -> list[str]:
    """Return a defensive copy of the session's computed execution order."""
    return [*session.ordered]
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def set_ordered(session: DagSession, ordered: list[str]) -> None:
    """Store a defensive copy of *ordered* as the session's execution order."""
    session.ordered = [*ordered]
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def get_deps(session: DagSession) -> dict[str, list[str]]:
    """Return the dependency map with each dependency set converted to a list."""
    return {node: list(edges) for node, edges in session.deps.items()}
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def get_outcome(session: DagSession, nodeid: str) -> str | None:
    """Return the recorded outcome for *nodeid*, or None when unrecorded."""
    return session.results.get(nodeid)
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
def get_version() -> str:
    """Return the core client version string reported to the license server.

    NOTE(review): the wheel ships as pytest-dag 3.0.0 while this reports
    "2.0.7" — confirm whether this constant was missed in a version bump
    before changing it, since it is sent as client_version in plan/start.
    """
    return "2.0.7"
|
|
676
|
+
|
|
677
|
+
|
|
678
|
+
class _StateView:
|
|
679
|
+
"""Adapter to reuse existing cycle detection from graph.py."""
|
|
680
|
+
|
|
681
|
+
def __init__(self, session: DagSession) -> None:
|
|
682
|
+
self.deps = session.deps
|
|
683
|
+
self.reverse_deps = session.reverse_deps
|