nullrun 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nullrun/__init__.py +282 -0
- nullrun/__version__.py +4 -0
- nullrun/actions.py +455 -0
- nullrun/breaker/__init__.py +27 -0
- nullrun/breaker/circuit_breaker.py +402 -0
- nullrun/breaker/exceptions.py +319 -0
- nullrun/context.py +208 -0
- nullrun/decorators.py +649 -0
- nullrun/instrumentation/__init__.py +23 -0
- nullrun/instrumentation/_safe_patch.py +99 -0
- nullrun/instrumentation/auto.py +1095 -0
- nullrun/instrumentation/auto_requests.py +257 -0
- nullrun/instrumentation/autogen.py +163 -0
- nullrun/instrumentation/crewai.py +140 -0
- nullrun/instrumentation/langgraph.py +412 -0
- nullrun/instrumentation/llama_index.py +110 -0
- nullrun/observability.py +160 -0
- nullrun/py.typed +0 -0
- nullrun/runtime.py +1806 -0
- nullrun/toolbox/__init__.py +20 -0
- nullrun/toolbox/langgraph.py +94 -0
- nullrun/tracing.py +155 -0
- nullrun/transport.py +1509 -0
- nullrun/transport_websocket.py +627 -0
- nullrun-0.4.0.dist-info/METADATA +194 -0
- nullrun-0.4.0.dist-info/RECORD +28 -0
- nullrun-0.4.0.dist-info/WHEEL +4 -0
- nullrun-0.4.0.dist-info/licenses/LICENSE +201 -0
nullrun/transport.py
ADDED
|
@@ -0,0 +1,1509 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Transport layer for NullRun SDK.
|
|
3
|
+
|
|
4
|
+
Handles HTTP communication with batching and background flush.
|
|
5
|
+
Includes fallback modes for Gateway unavailability.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import hashlib
|
|
9
|
+
import hmac
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import os
|
|
13
|
+
import random
|
|
14
|
+
import threading
|
|
15
|
+
import time
|
|
16
|
+
import uuid
|
|
17
|
+
import weakref
|
|
18
|
+
from collections import OrderedDict
|
|
19
|
+
from collections.abc import Callable
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
import httpx
|
|
24
|
+
|
|
25
|
+
from nullrun.actions import handle_action
|
|
26
|
+
from nullrun.breaker.circuit_breaker import CircuitBreaker
|
|
27
|
+
from nullrun.breaker.exceptions import (
|
|
28
|
+
BreakerTransportError,
|
|
29
|
+
InsecureTransportError,
|
|
30
|
+
NullRunAuthenticationError,
|
|
31
|
+
NullRunTransportError,
|
|
32
|
+
RateLimitError,
|
|
33
|
+
TransportErrorSource,
|
|
34
|
+
)
|
|
35
|
+
from nullrun.observability import metrics
|
|
36
|
+
|
|
37
|
+
# OpenTelemetry imports (lazy-loaded to support optional dependency)
|
|
38
|
+
try:
|
|
39
|
+
from opentelemetry import trace
|
|
40
|
+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
|
41
|
+
_OTEL_AVAILABLE = True
|
|
42
|
+
except ImportError:
|
|
43
|
+
_OTEL_AVAILABLE = False
|
|
44
|
+
trace = None # type: ignore[assignment]
|
|
45
|
+
TraceContextTextMapPropagator = None # type: ignore[assignment]
|
|
46
|
+
|
|
47
|
+
# Module logger; child of the package logger so applications can tune it
# via logging.getLogger("nullrun").
logger = logging.getLogger(__name__)

# Module-level API version marker.
__api_version__ = "1.0"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# =============================================================================
|
|
57
|
+
# HMAC Request Signing (Task 11)
|
|
58
|
+
# =============================================================================
|
|
59
|
+
|
|
60
|
+
def generate_hmac_signature(
    api_key: str,
    secret_key: str,
    timestamp: int,
    body: str,
) -> str:
    """
    Compute the HMAC-SHA256 request signature.

    The signed message is ``"<timestamp>:<api_key>:<sha256_hex(body)>"`` and
    the MAC key is ``secret_key``; all strings are UTF-8 encoded first.

    Binding these three values gives:
    - Authentication: the API key identifies the client
    - Integrity: the body hash detects tampering with the request body
    - Freshness: the timestamp lets the verifier reject replays

    Args:
        api_key: Client's API key (identifier)
        secret_key: Client's secret key (HMAC key)
        timestamp: Unix timestamp in seconds
        body: Request body as a JSON string

    Returns:
        Lower-case hex HMAC-SHA256 digest (64 characters).
    """
    digest_of_body = hashlib.sha256(body.encode("utf-8")).hexdigest()
    signed_message = f"{timestamp}:{api_key}:{digest_of_body}".encode("utf-8")
    mac = hmac.new(secret_key.encode("utf-8"), signed_message, hashlib.sha256)
    return mac.hexdigest()
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def verify_hmac_signature(
    api_key: str,
    secret_key: str,
    timestamp: int,
    body: str,
    signature: str,
    max_age_seconds: int = 300,
) -> bool:
    """
    Verify an HMAC signature produced by ``generate_hmac_signature``.

    Args:
        api_key: Client's API key
        secret_key: Client's secret key
        timestamp: Unix timestamp from request
        body: Request body as JSON string
        signature: HMAC signature to verify
        max_age_seconds: Maximum allowed distance between ``timestamp`` and
            the current time, in either direction (default 5 min)

    Returns:
        True if the signature matches and the timestamp is fresh. False
        otherwise, including when ``signature`` is not a string or contains
        non-ASCII characters (such input cannot be a valid hex digest).
    """
    # Check timestamp freshness (abs(): future timestamps are rejected too).
    current_time = int(time.time())
    if abs(current_time - timestamp) > max_age_seconds:
        logger.warning(
            "Request timestamp too old: %s vs current %s", timestamp, current_time
        )
        return False

    if not isinstance(signature, str):
        return False

    # Recompute expected signature
    expected = generate_hmac_signature(api_key, secret_key, timestamp, body)

    # Constant-time comparison to prevent timing attacks. compare_digest()
    # raises TypeError for str arguments containing non-ASCII characters, so
    # compare UTF-8 bytes: a malformed signature then simply fails to match.
    return hmac.compare_digest(expected.encode("utf-8"), signature.encode("utf-8"))
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
# =============================================================================
|
|
134
|
+
# Policy Cache for CACHED fallback mode
|
|
135
|
+
# =============================================================================
|
|
136
|
+
|
|
137
|
+
class CachedDecision:
|
|
138
|
+
"""Represents a cached execute decision."""
|
|
139
|
+
|
|
140
|
+
def __init__(
|
|
141
|
+
self,
|
|
142
|
+
decision: str,
|
|
143
|
+
policy_id: str | None = None,
|
|
144
|
+
ttl_seconds: float = 300.0,
|
|
145
|
+
policy_version: int | None = None,
|
|
146
|
+
):
|
|
147
|
+
self.decision = decision
|
|
148
|
+
self.policy_id = policy_id
|
|
149
|
+
self.cached_at = time.monotonic()
|
|
150
|
+
self.ttl_seconds = ttl_seconds
|
|
151
|
+
# Phase 5 #5.2: dedicated field, not a `ttl_seconds` repurpose.
|
|
152
|
+
self.policy_version = policy_version
|
|
153
|
+
|
|
154
|
+
def is_expired(self) -> bool:
|
|
155
|
+
return time.monotonic() - self.cached_at > self.ttl_seconds
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class PolicyCache:
|
|
159
|
+
"""
|
|
160
|
+
LRU cache for execute decisions. Used in CACHED fallback mode.
|
|
161
|
+
|
|
162
|
+
Cache key is (organization_id, policy_version) to prevent cache thrashing.
|
|
163
|
+
At 1000+ users with unique workflow_ids, keying by tool caused constant eviction.
|
|
164
|
+
Now we key by organization + policy version, so all tools in an organization share
|
|
165
|
+
the same policy cached entry until the policy version changes.
|
|
166
|
+
"""
|
|
167
|
+
|
|
168
|
+
def __init__(self, maxsize: int = 1000, ttl_seconds: float = 300.0):
|
|
169
|
+
self._cache: OrderedDict[str, CachedDecision] = OrderedDict()
|
|
170
|
+
self._maxsize = maxsize
|
|
171
|
+
self._ttl = ttl_seconds
|
|
172
|
+
self._hits = 0
|
|
173
|
+
self._misses = 0
|
|
174
|
+
|
|
175
|
+
def get(self, key: str) -> CachedDecision | None:
|
|
176
|
+
decision = self._cache.get(key)
|
|
177
|
+
if decision is None:
|
|
178
|
+
self._misses += 1
|
|
179
|
+
return None
|
|
180
|
+
if decision.is_expired():
|
|
181
|
+
del self._cache[key]
|
|
182
|
+
self._misses += 1
|
|
183
|
+
return None
|
|
184
|
+
self._cache.move_to_end(key)
|
|
185
|
+
self._hits += 1
|
|
186
|
+
return decision
|
|
187
|
+
|
|
188
|
+
def set(self, key: str, decision: str, policy_id: str = None, policy_version: int = None) -> None:
|
|
189
|
+
if key in self._cache:
|
|
190
|
+
self._cache.move_to_end(key)
|
|
191
|
+
elif len(self._cache) >= self._maxsize:
|
|
192
|
+
self._cache.popitem(last=False)
|
|
193
|
+
# Phase 5 #5.2: pass policy_version as a dedicated field.
|
|
194
|
+
# The previous implementation wrote it into ttl_seconds, which
|
|
195
|
+
# corrupted the cache-lifetime check (see plan #5.2).
|
|
196
|
+
self._cache[key] = CachedDecision(
|
|
197
|
+
decision=decision,
|
|
198
|
+
policy_id=policy_id,
|
|
199
|
+
ttl_seconds=self._ttl,
|
|
200
|
+
policy_version=policy_version,
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
def make_key(self, organization_id: str, policy_version: int = None) -> str:
|
|
204
|
+
"""Generate cache key from organization_id and policy_version."""
|
|
205
|
+
if policy_version is not None:
|
|
206
|
+
return f"{organization_id}:{policy_version}"
|
|
207
|
+
return f"{organization_id}:0" # Default to version 0 if not provided
|
|
208
|
+
|
|
209
|
+
def get_stats(self) -> dict:
|
|
210
|
+
"""Get cache statistics for observability."""
|
|
211
|
+
total = self._hits + self._misses
|
|
212
|
+
hit_rate = self._hits / total if total > 0 else 0.0
|
|
213
|
+
return {
|
|
214
|
+
"size": len(self._cache),
|
|
215
|
+
"hits": self._hits,
|
|
216
|
+
"misses": self._misses,
|
|
217
|
+
"hit_rate": hit_rate,
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
def __len__(self) -> int:
|
|
221
|
+
return len(self._cache)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _signed_request_body(payload: dict[str, Any]) -> bytes:
|
|
227
|
+
"""Serialise a JSON payload to the canonical bytes the HMAC
|
|
228
|
+
signature is computed over.
|
|
229
|
+
|
|
230
|
+
All three signed POST call sites (``_send_batch_with_retry_info``,
|
|
231
|
+
``Transport.execute``, ``Transport.check``) MUST serialise via this
|
|
232
|
+
helper and pass the result with ``content=body`` to
|
|
233
|
+
``httpx.Client.post``. Sending via ``json=...`` lets httpx
|
|
234
|
+
re-serialise with its default compact separators, which produces
|
|
235
|
+
a body that does NOT match the body the HMAC signature was
|
|
236
|
+
computed over. The Rust server at
|
|
237
|
+
``backend/src/auth/hmac.rs:466-518`` is strict -- it recomputes
|
|
238
|
+
``sha256(body)`` from the raw wire bytes and rejects with 401
|
|
239
|
+
on mismatch.
|
|
240
|
+
"""
|
|
241
|
+
return json.dumps(payload, separators=(",", ":")).encode("utf-8")
|
|
242
|
+
|
|
243
|
+
# =============================================================================
|
|
244
|
+
# Retry with exponential backoff + jitter
|
|
245
|
+
# =============================================================================
|
|
246
|
+
|
|
247
|
+
"""
|
|
248
|
+
Retry with exponential backoff + jitter + Retry-After header support
|
|
249
|
+
"""
|
|
250
|
+
|
|
251
|
+
def _retry_with_backoff(
|
|
252
|
+
func: Callable[[], Any],
|
|
253
|
+
max_retries: int = 3,
|
|
254
|
+
base_delay: float = 0.5,
|
|
255
|
+
max_delay: float = 30.0,
|
|
256
|
+
backoff_factor: float = 2.0,
|
|
257
|
+
jitter: float = 0.1,
|
|
258
|
+
last_retry_after_seconds: float = 0.0,
|
|
259
|
+
on_transport_error: str | Callable[[Exception], dict[str, Any]] | None = None,
|
|
260
|
+
) -> Any:
|
|
261
|
+
"""
|
|
262
|
+
Retry with exponential backoff and jitter, honoring Retry-After header.
|
|
263
|
+
|
|
264
|
+
When Retry-After is provided (from backend 429 response), use it directly
|
|
265
|
+
instead of exponential backoff to prevent retry storms.
|
|
266
|
+
|
|
267
|
+
Formula (without Retry-After): delay = min(base_delay * backoff_factor^attempt, max_delay)
|
|
268
|
+
delay += random.uniform(-jitter * delay, jitter * delay)
|
|
269
|
+
Formula (with Retry-After): actual_delay = min(last_retry_after_seconds, max_delay)
|
|
270
|
+
"""
|
|
271
|
+
last_exc: Exception | None = None
|
|
272
|
+
|
|
273
|
+
for attempt in range(max_retries + 1):
|
|
274
|
+
try:
|
|
275
|
+
result = func()
|
|
276
|
+
|
|
277
|
+
if hasattr(result, "status_code"):
|
|
278
|
+
if result.status_code == 401:
|
|
279
|
+
raise NullRunAuthenticationError("Invalid API key")
|
|
280
|
+
if result.status_code >= 500 and on_transport_error == "raise":
|
|
281
|
+
# Round 3 (Phase 0.4.0): 5xx is a classified
|
|
282
|
+
# GATEWAY_ERROR. Don't retry -- this is a server
|
|
283
|
+
# bug, not a network blip. Only raise when the
|
|
284
|
+
# caller has opted into the typed-error contract
|
|
285
|
+
# via on_transport_error="raise".
|
|
286
|
+
raise NullRunTransportError(
|
|
287
|
+
f"Gateway returned {result.status_code}",
|
|
288
|
+
source=TransportErrorSource.GATEWAY_ERROR,
|
|
289
|
+
endpoint="execute",
|
|
290
|
+
status_code=result.status_code,
|
|
291
|
+
)
|
|
292
|
+
if result.status_code >= 400:
|
|
293
|
+
result.raise_for_status()
|
|
294
|
+
|
|
295
|
+
return result
|
|
296
|
+
|
|
297
|
+
except (BreakerTransportError, NullRunAuthenticationError, NullRunTransportError):
|
|
298
|
+
raise
|
|
299
|
+
|
|
300
|
+
except Exception as exc:
|
|
301
|
+
last_exc = exc
|
|
302
|
+
# Sprint 3 follow-up (B24): bump ``last_error`` so the
|
|
303
|
+
# operator can read the most recent failure type without
|
|
304
|
+
# grepping logs. The string is the exception class
|
|
305
|
+
# name plus the message — short, searchable, and
|
|
306
|
+
# doesn't leak request bodies.
|
|
307
|
+
metrics.set_transport("last_error", f"{type(exc).__name__}: {exc}")
|
|
308
|
+
# ``timeouts`` is a specific subcategory of retry
|
|
309
|
+
# trigger — distinguished so an SRE can alert on
|
|
310
|
+
# ``timeouts > N per minute`` separately from
|
|
311
|
+
# generic 5xx retries.
|
|
312
|
+
if isinstance(exc, (httpx.TimeoutException, httpx.ConnectTimeout, httpx.ReadTimeout)):
|
|
313
|
+
metrics.inc_transport("timeouts")
|
|
314
|
+
|
|
315
|
+
if attempt >= max_retries:
|
|
316
|
+
break
|
|
317
|
+
|
|
318
|
+
# Bump ``retries_total`` for every retry attempt
|
|
319
|
+
# (not for the final failure). The counter is
|
|
320
|
+
# distinct from the final BreakerTransportError —
|
|
321
|
+
# it measures how often the SDK had to retry
|
|
322
|
+
# because the backend was flaky.
|
|
323
|
+
metrics.inc_transport("retries_total")
|
|
324
|
+
|
|
325
|
+
# Honor Retry-After from backend if present (from 429 response)
|
|
326
|
+
if last_retry_after_seconds > 0:
|
|
327
|
+
actual_delay = min(last_retry_after_seconds, max_delay)
|
|
328
|
+
# Reset after use so next retry uses exponential backoff
|
|
329
|
+
last_retry_after_seconds = 0.0
|
|
330
|
+
logger.warning(
|
|
331
|
+
"Request failed (attempt %d/%d), honoring Retry-After %.2fs: %s",
|
|
332
|
+
attempt + 1,
|
|
333
|
+
max_retries + 1,
|
|
334
|
+
actual_delay,
|
|
335
|
+
type(exc).__name__,
|
|
336
|
+
)
|
|
337
|
+
else:
|
|
338
|
+
delay = min(base_delay * (backoff_factor ** attempt), max_delay)
|
|
339
|
+
jitter_amount = delay * jitter
|
|
340
|
+
# Standard jitter for retry delay -- not crypto-sensitive
|
|
341
|
+
actual_delay = delay + random.uniform(-jitter_amount, jitter_amount) # noqa: S311
|
|
342
|
+
actual_delay = max(0.0, actual_delay)
|
|
343
|
+
logger.warning(
|
|
344
|
+
"Request failed (attempt %d/%d), retrying in %.2fs: %s",
|
|
345
|
+
attempt + 1,
|
|
346
|
+
max_retries + 1,
|
|
347
|
+
actual_delay,
|
|
348
|
+
type(exc).__name__,
|
|
349
|
+
)
|
|
350
|
+
|
|
351
|
+
time.sleep(actual_delay)
|
|
352
|
+
|
|
353
|
+
raise BreakerTransportError(
|
|
354
|
+
f"Request failed after {max_retries + 1} attempts"
|
|
355
|
+
) from last_exc
|
|
356
|
+
|
|
357
|
+
# =============================================================================
|
|
358
|
+
# Fallback Modes (Phase 1 - SDK Resilience)
|
|
359
|
+
# =============================================================================
|
|
360
|
+
|
|
361
|
+
class FallbackMode:
    """
    What the SDK does when the Gateway cannot be reached.

    Production-critical: a Gateway outage should NOT block agent execution
    by accident, so the behavior is always an explicit, logged choice.
    """

    STRICT = "strict"  # block the call (intended for critical tools)
    PERMISSIVE = "permissive"  # allow the call and log locally (DEFAULT)
    CACHED = "cached"  # fall back to a previously cached decision
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
class DecisionSource:
    """Provenance tag recording where a decision originated."""

    GATEWAY, CACHED, FALLBACK, LOCAL = "gateway", "cached", "fallback", "local"
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
@dataclass
class FlushConfig:
    """Tunables for batching and background flushing of tracked events."""

    # Number of buffered events that triggers an immediate flush.
    batch_size: int = 50
    # Seconds between background flushes.
    flush_interval: float = 5.0
    max_retries: int = 3
    # Seconds.
    retry_delay: float = 1.0
    # Upper bound on buffered events; the oldest are dropped beyond it.
    max_buffer_size: int = 1000
    # Circuit breaker: failures tolerated before it stops trying.
    max_failed_flush: int = 10
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
@dataclass
class ExecuteConfig:
    """Tunables for execute calls (strict mode)."""

    # Behavior when the Gateway is unavailable.
    fallback_mode: str = FallbackMode.PERMISSIVE
    # Gateway timeout, seconds.
    timeout: float = 5.0
    # Retries for execute calls.
    max_retries: int = 2
    # How long a decision stays cached in CACHED mode, seconds.
    cache_ttl: float = 60.0
    # Maximum number of cached decisions.
    cache_max_size: int = 10000
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
class Transport:
|
|
413
|
+
"""
|
|
414
|
+
HTTP transport with batching support.
|
|
415
|
+
|
|
416
|
+
Features:
|
|
417
|
+
- Non-blocking track() calls (append to buffer)
|
|
418
|
+
- Background flush at intervals or when batch_size reached
|
|
419
|
+
- Retry logic for failed requests
|
|
420
|
+
- Thread-safe for sync usage
|
|
421
|
+
- HMAC request signing for secure authentication
|
|
422
|
+
- Distributed circuit breaker via Redis for multi-worker deployments
|
|
423
|
+
"""
|
|
424
|
+
|
|
425
|
+
def __init__(
    self,
    api_url: str,
    api_key: str | None = None,
    secret_key: str | None = None,
    config: FlushConfig | None = None,
    redis_client: Any = None,
):
    """Validate the endpoint, apply env overrides and build the HTTP client.

    Args:
        api_url: Base URL of the NullRun API (a trailing ``/`` is stripped).
            Plain ``http://`` is only accepted for loopback hosts.
        api_key: Client API key; only its first 8 characters are logged.
        secret_key: HMAC signing key.
        config: Flush/batching settings (defaults to ``FlushConfig()``).
            ``NULLRUN_BATCH_SIZE`` and ``NULLRUN_FLUSH_INTERVAL_MS``
            override fields on this object; invalid values are ignored with
            a warning.
        redis_client: Optional client passed to the ``CircuitBreaker``.

    Raises:
        InsecureTransportError: ``api_url`` uses ``http://`` with a
            non-loopback host.
    """
    self.api_url = api_url.rstrip("/")

    # TLS enforcement: reject non-localhost HTTP URLs. The check
    # must NOT be a startswith chain — that allowed homograph
    # attacks (http://127.0.0.1.attacker.com, http://localhost.evil.com)
    # and rejected legitimate inputs (http://[::1]:8080, http://LOCALHOST).
    # urllib.parse.urlparse extracts the canonical hostname, which is then
    # checked against a small allow-list plus the full IPv4 loopback range
    # (127.0.0.0/8) and IPv6 loopback (::1). ``ipaddress.ip_address``
    # rejects ``127.0.0.1.attacker.com`` (not an IP), so a string that
    # merely starts with "127." is not treated as loopback.
    from ipaddress import ip_address
    from urllib.parse import urlparse

    parsed = urlparse(self.api_url)
    if parsed.scheme == "http":
        host = (parsed.hostname or "").lower()
        allowed = host in ("localhost", "::1")
        if not allowed:
            try:
                allowed = ip_address(host).is_loopback
            except ValueError:
                allowed = False
        if not allowed:
            raise InsecureTransportError(
                f"Insecure URL detected: {self.api_url}. "
                f"HTTP is only allowed for localhost / 127.0.0.0/8 / ::1. "
                f"Use https:// for production."
            )

    self.api_key = api_key
    self.secret_key = secret_key  # HMAC signing key
    self.config = config or FlushConfig()
    # Phase 8 #8.4: allow env-var override of batch size and
    # flush interval. Useful for tuning high-throughput agents
    # without subclassing.
    # NOTE: the overrides mutate self.config in place, including a
    # FlushConfig instance supplied by the caller.
    if "NULLRUN_BATCH_SIZE" in os.environ:
        raw_batch_size = os.environ["NULLRUN_BATCH_SIZE"]
        try:
            batch_size = int(raw_batch_size)
        except ValueError:
            logger.warning(
                "NULLRUN_BATCH_SIZE=%r is not an int; ignoring",
                raw_batch_size,
            )
        else:
            if batch_size < 1:
                # 0 or a negative size satisfies ``len(buffer) >= batch_size``
                # on every track() call, forcing a synchronous flush each time.
                logger.warning(
                    "NULLRUN_BATCH_SIZE=%r must be >= 1; ignoring",
                    raw_batch_size,
                )
            else:
                self.config.batch_size = batch_size
    if "NULLRUN_FLUSH_INTERVAL_MS" in os.environ:
        raw_interval = os.environ["NULLRUN_FLUSH_INTERVAL_MS"]
        try:
            interval_ms = int(raw_interval)
        except ValueError:
            logger.warning(
                "NULLRUN_FLUSH_INTERVAL_MS=%r is not an int; ignoring",
                raw_interval,
            )
        else:
            if interval_ms <= 0:
                # flush_interval feeds time.sleep() in the flush loop: 0
                # makes that thread spin, a negative value raises ValueError.
                logger.warning(
                    "NULLRUN_FLUSH_INTERVAL_MS=%r must be > 0; ignoring",
                    raw_interval,
                )
            else:
                self.config.flush_interval = interval_ms / 1000.0

    self._buffer: list[dict[str, Any]] = []
    self._in_flight: dict[str, dict[str, Any]] = {}  # event_id -> event for retry dedup
    # RLock so re-entrant acquisition (e.g. test fixtures that hold the
    # lock while calling lock-acquiring methods) doesn't deadlock.
    self._lock = threading.RLock()
    self._flush_thread: threading.Thread | None = None
    self._running = False

    # mTLS client certificate support via NULLRUN_TLS_CLIENT_CERT /
    # NULLRUN_TLS_CLIENT_KEY, plus an optional custom CA bundle.
    client_cert_path = os.environ.get("NULLRUN_TLS_CLIENT_CERT")
    client_key_path = os.environ.get("NULLRUN_TLS_CLIENT_KEY")
    ca_cert_path = os.environ.get("NULLRUN_TLS_CA_CERT")  # Optional custom CA

    # httpx: ``verify`` is True or a CA bundle path; ``cert`` is a
    # (client_cert, client_key) tuple for client-certificate auth.
    verify_cert: bool | str = True
    client_cert: tuple[str, str] | None = None
    if client_cert_path and client_key_path:
        client_cert = (client_cert_path, client_key_path)
        verify_cert = ca_cert_path if ca_cert_path else True
        logger.debug("mTLS enabled: client_cert=%s", client_cert_path)
    else:
        if client_cert_path or client_key_path:
            # Only half of the pair is configured; mTLS cannot be enabled.
            logger.warning(
                "NULLRUN_TLS_CLIENT_CERT and NULLRUN_TLS_CLIENT_KEY must both "
                "be set to enable mTLS; client certificate ignored"
            )
        if ca_cert_path:
            # Custom CA certificate only (no client cert)
            verify_cert = ca_cert_path
            logger.debug("Custom CA configured: ca_cert=%s", ca_cert_path)

    self._client = httpx.Client(
        timeout=httpx.Timeout(
            connect=5.0,
            read=30.0,
            write=10.0,
            pool=5.0,
        ),
        verify=verify_cert,
        cert=client_cert,
        limits=httpx.Limits(
            max_connections=10,
            max_keepalive_connections=5,
            keepalive_expiry=30.0,
        ),
    )
    self._redis_client = redis_client
    self._circuit_breaker = CircuitBreaker(
        failure_threshold=self.config.max_failed_flush,
        recovery_timeout=30.0,
        redis_client=redis_client,
        name="transport",
    )
    self._stopped = False  # Track if stop() was called
    self._policy_cache = PolicyCache(
        maxsize=1000,
        ttl_seconds=300.0,
    )
    masked_key = api_key[:8] + "***" if api_key and len(api_key) >= 8 else "***"
    logger.debug("Transport initialized: api_url=%s, api_key=%s", self.api_url, masked_key)

    # OpenTelemetry tracer initialization (only if opentelemetry is installed)
    self._tracer = None
    self._propagator = None
    if _OTEL_AVAILABLE:
        self._tracer = trace.get_tracer("nullrun.transport")
        self._propagator = TraceContextTextMapPropagator()

    # Register final-flush hook via weakref.finalize so the
    # callback only fires if this Transport instance is still
    # alive at process exit. Replaces the previous
    # ``atexit.register`` (which accumulated one handler per
    # Transport in long-running deployments) and the previous
    # ``signal.signal`` handler (which hijacked SIGTERM/SIGINT
    # process-wide and called ``sys.exit(0)`` from inside the
    # signal context). The fix contract is pinned by
    # tests/test_signal_safety.py.
    self._finalizer = weakref.finalize(self, self._atexit_flush_safe)
|
|
565
|
+
|
|
566
|
+
@staticmethod
def _atexit_flush_safe(_self_id: int | None = None) -> None:
    """Callback registered with ``weakref.finalize``.

    It runs without access to the transport, so it cannot flush anything.
    The buffer, HTTP client and lock are unreachable from here. All it does
    is emit a DEBUG hint that ``stop()`` was never called. Use ``stop()``
    or the context-manager form to avoid losing events.

    ``_self_id`` is accepted and ignored, so callers may pass an instance
    id (e.g. ``_atexit_flush_safe(id(t))``) as well as nothing.
    """
    hint = (
        "Transport finalizer fired without explicit stop(); "
        "remaining events may be lost. Use Transport as a context "
        "manager or call stop() explicitly."
    )
    logger.debug(hint)
|
|
590
|
+
|
|
591
|
+
def _persist_to_wal(self) -> None:
    """Append still-buffered events to ``.nullrun.wal`` in the current directory.

    Each event is written as one JSON document per line, for
    ``_replay_from_wal`` to read back on the next ``start()``.

    The work is done under ``self._lock``, the same lock ``track()`` takes.
    Events therefore cannot be appended between writing the file and
    clearing the buffer, so none are lost.

    Every event is serialised before the file is opened. An unserialisable
    event therefore raises without leaving a partial batch in the WAL, and
    the buffer is left untouched.
    """
    with self._lock:
        if not self._buffer:
            return
        lines = [json.dumps(event) + "\n" for event in self._buffer]
        wal_path = os.path.join(os.getcwd(), ".nullrun.wal")
        with open(wal_path, "a", encoding="utf-8") as wal_file:
            wal_file.writelines(lines)
        self._buffer.clear()
    logger.debug("Persisted %d events to WAL at %s", len(lines), wal_path)
|
|
602
|
+
|
|
603
|
+
def _replay_from_wal(self) -> None:
    """Re-send events left in ``.nullrun.wal`` (current directory) by a previous run.

    Each non-blank line is parsed as JSON. Only JSON objects are replayed;
    corrupt lines and non-object values are skipped and counted in a
    warning. Recovered events are added to the buffer and flushed under
    ``self._lock``, and then the WAL file is removed.

    A WAL that contains nothing recoverable is also removed, so it is not
    re-read on every ``start()``. If the flush hits a
    ``BreakerTransportError``, ``_do_flush_locked`` re-queues the events
    in memory, and a later ``stop()`` persists them again.
    """
    wal_path = os.path.join(os.getcwd(), ".nullrun.wal")
    if not os.path.exists(wal_path):
        return

    events: list[dict[str, Any]] = []
    skipped = 0
    with open(wal_path, encoding="utf-8") as wal_file:
        for raw_line in wal_file:
            line = raw_line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                skipped += 1
                continue
            if isinstance(record, dict):
                events.append(record)
            else:
                # Downstream code calls event.get(...); only objects are events.
                skipped += 1

    if skipped:
        logger.warning("Skipped %d unreadable line(s) in WAL at %s", skipped, wal_path)

    if not events:
        os.remove(wal_path)
        return

    with self._lock:
        self._buffer.extend(events)
        self._do_flush_locked()
    os.remove(wal_path)  # Clean up WAL after replay
    logger.info("Replayed %d events from WAL", len(events))
|
|
620
|
+
|
|
621
|
+
def track(self, event: dict[str, Any]) -> None:
    """
    Enqueue ``event`` for delivery without waiting on the network.

    An event with a missing or falsy ``event_id`` gets a fresh UUID4 written
    into it; the caller's dict is modified. The event is recorded in
    ``_in_flight`` and appended to the buffer. When the buffer reaches
    ``config.batch_size``, it is flushed right away on the calling thread.
    Otherwise it waits for the periodic flush (every ``flush_interval``).
    """
    with self._lock:
        if not event.get("event_id"):
            event["event_id"] = str(uuid.uuid4())

        # Keep a handle for retry dedup until the backend accepts it.
        self._in_flight[event["event_id"]] = event

        self._buffer.append(event)
        metrics.inc_transport("events_enqueued")

        batch_is_full = len(self._buffer) >= self.config.batch_size
        if batch_is_full:
            self._do_flush_locked()
|
|
641
|
+
|
|
642
|
+
def start(self) -> None:
    """Launch the daemon flush thread; a no-op if it is already running.

    Events persisted to the WAL by an earlier run (e.g. after a crash)
    are replayed before the thread starts.
    """
    if not self._running:
        self._replay_from_wal()
        self._running = True
        worker = threading.Thread(target=self._flush_loop, daemon=True)
        self._flush_thread = worker
        worker.start()
        logger.info("Transport flush thread started")
|
|
652
|
+
|
|
653
|
+
def __enter__(self) -> "Transport":
|
|
654
|
+
"""Context-manager entry: start the flush thread and return self.
|
|
655
|
+
|
|
656
|
+
Pairs with ``__exit__`` so callers can write
|
|
657
|
+
``with Transport(...) as t:`` and rely on ``stop()`` running
|
|
658
|
+
on the way out. Replaces the manual ``start() / stop()`` pair
|
|
659
|
+
that was easy to forget in long-running services.
|
|
660
|
+
"""
|
|
661
|
+
self.start()
|
|
662
|
+
return self
|
|
663
|
+
|
|
664
|
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
|
665
|
+
"""Context-manager exit: stop the flush thread and persist WAL.
|
|
666
|
+
|
|
667
|
+
Always stops, regardless of whether the body raised. The
|
|
668
|
+
exception (if any) is NOT swallowed — the caller still sees
|
|
669
|
+
it after the with-block.
|
|
670
|
+
"""
|
|
671
|
+
try:
|
|
672
|
+
self.stop()
|
|
673
|
+
except Exception as e: # noqa: BLE001 — best-effort on context exit
|
|
674
|
+
logger.debug(f"Transport.__exit__: stop() raised: {e}")
|
|
675
|
+
|
|
676
|
+
def stop(self, timeout: float = 10.0) -> None:
    """Stop background flushing, flush what is left, and release resources.

    Steps, in order:
    1. Clear ``_running`` and set ``_stopped``.
    2. Join the flush thread, waiting at most ``timeout`` seconds.
    3. Run a final flush.
    4. Write any still-buffered events to the WAL.
    5. Close the HTTP client.
    6. Detach the weakref finalizer.

    The WAL write, the client close and the detach run in ``finally``
    blocks. They therefore happen even when the join or the final flush
    raises, and that exception still propagates to the caller. Previously
    such an exception skipped the WAL write (losing buffered events) and
    left the HTTP client open.
    """
    self._running = False
    self._stopped = True  # Mark as stopped to prevent double flush
    try:
        if self._flush_thread:
            self._flush_thread.join(timeout=timeout)
        self._do_flush()  # Final flush
    finally:
        try:
            self._persist_to_wal()  # WAL any remaining events
        finally:
            self._client.close()
            # Detach the weakref finalizer — stop() is the canonical
            # "I am done" path. After this point the finalizer will
            # silently no-op even if the interpreter is still alive.
            if getattr(self, "_finalizer", None) is not None and self._finalizer.alive:
                self._finalizer.detach()
    logger.info("Transport stopped")
|
|
691
|
+
|
|
692
|
+
def _flush_loop(self) -> None:
|
|
693
|
+
"""Background loop that periodically flushes."""
|
|
694
|
+
while self._running:
|
|
695
|
+
time.sleep(self.config.flush_interval)
|
|
696
|
+
if self._running:
|
|
697
|
+
self._do_flush()
|
|
698
|
+
|
|
699
|
+
def _do_flush(self) -> None:
|
|
700
|
+
"""Perform the actual flush."""
|
|
701
|
+
with self._lock:
|
|
702
|
+
self._do_flush_locked()
|
|
703
|
+
|
|
704
|
+
def _do_flush_locked(self) -> None:
    """Send the whole buffer as one batch. Must be called with ``_lock`` held.

    The buffer is snapshotted and cleared, and the batch is sent through
    ``self._circuit_breaker``.

    On success, events whose ids appear in the result's
    ``accepted_event_ids`` are removed from ``_in_flight``, and the send
    metrics are updated.

    On ``BreakerTransportError``, the batch is put back into the buffer.
    If that would exceed ``config.max_buffer_size``, the oldest events of
    the batch are dropped; they are also removed from ``_in_flight`` and
    counted in ``events_dropped``.

    Other exceptions are not caught here, and the batch is not re-queued.
    """
    if not self._buffer:
        logger.debug("Buffer empty, skipping flush")
        return

    batch = self._buffer[:]
    self._buffer.clear()
    logger.debug("Sending batch of %d events", len(batch))

    # Executed by the circuit breaker (proper 3-state breaker).
    def send_batch():
        result = self._send_batch_with_retry_info(batch)
        accepted = result.accepted_event_ids
        if accepted:
            # Accepted events no longer need to be kept for retry dedup.
            for event in batch:
                event_id = event.get("event_id")
                if event_id in accepted:
                    self._in_flight.pop(event_id, None)
        logger.debug("Flushed %d events", len(batch))
        # Update metrics on successful flush (thread-safe)
        metrics.inc_transport("batches_sent")
        metrics.inc_transport("events_sent", len(batch))
        metrics.set_transport("last_flush_at", time.monotonic())
        return result

    try:
        self._circuit_breaker.call(send_batch)
    except BreakerTransportError:
        logger.warning(
            "Circuit breaker OPEN. Batch of %d events will be re-queued.", len(batch)
        )
        # Enforce max_buffer_size before re-queueing to prevent unbounded
        # growth. Clamp the free space at 0: if the buffer is already at or
        # over capacity, the whole batch overflows. The old arithmetic
        # reported more drops than the batch contained in that case.
        available_space = max(self.config.max_buffer_size - len(self._buffer), 0)
        overflow = len(batch) - available_space
        if overflow > 0:
            # Drop the oldest events, i.e. the front of the batch.
            logger.warning(
                "Buffer overflow on CB OPEN: dropping %d oldest events from pending batch",
                overflow,
            )
            dropped, batch = batch[:overflow], batch[overflow:]
            # Dropped events will never be sent; stop tracking them so
            # _in_flight does not grow without bound.
            for event in dropped:
                self._in_flight.pop(event.get("event_id"), None)
            metrics.inc_transport("events_dropped", overflow)
        # Re-queue at the end of the buffer. Callers hold self._lock, so
        # other threads cannot have appended since the buffer was cleared
        # above.
        self._buffer.extend(batch)
        # Update metrics on failure (thread-safe)
        metrics.inc_transport("batches_failed")
|
|
750
|
+
|
|
751
|
+
def _drain_batch(self) -> list[dict[str, Any]] | None:
|
|
752
|
+
"""Round 2 (Phase 0.4.0): public, lock-acquiring snapshot of
|
|
753
|
+
the current buffer. Returns ``None`` when empty.
|
|
754
|
+
|
|
755
|
+
Used by ``tests/test_buffer_invariants.py``. The full flush
|
|
756
|
+
logic (CB, re-queue, metrics) lives in ``_do_flush_locked``;
|
|
757
|
+
this method is the read-only counterpart.
|
|
758
|
+
"""
|
|
759
|
+
with self._lock:
|
|
760
|
+
if not self._buffer:
|
|
761
|
+
return None
|
|
762
|
+
batch = list(self._buffer)
|
|
763
|
+
del self._buffer[:]
|
|
764
|
+
return batch
|
|
765
|
+
|
|
766
|
+
@dataclass
|
|
767
|
+
class SendResult:
|
|
768
|
+
accepted_event_ids: list
|
|
769
|
+
retry_after_ms: float | None = None
|
|
770
|
+
is_policy_limit: bool = False
|
|
771
|
+
|
|
772
|
+
def _add_hmac_headers(self, headers: dict[str, str], body: str) -> None:
|
|
773
|
+
"""
|
|
774
|
+
Add HMAC signing headers to request.
|
|
775
|
+
|
|
776
|
+
Adds:
|
|
777
|
+
- X-Signature-Timestamp: Unix timestamp for freshness
|
|
778
|
+
- X-Signature: HMAC-SHA256(api_key, secret, timestamp, body_hash)
|
|
779
|
+
|
|
780
|
+
Only adds signature if secret_key is configured.
|
|
781
|
+
"""
|
|
782
|
+
if not self.secret_key or not self.api_key:
|
|
783
|
+
return
|
|
784
|
+
|
|
785
|
+
timestamp = int(time.time())
|
|
786
|
+
signature = generate_hmac_signature(
|
|
787
|
+
self.api_key,
|
|
788
|
+
self.secret_key,
|
|
789
|
+
timestamp,
|
|
790
|
+
body,
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
headers["X-Signature-Timestamp"] = str(timestamp)
|
|
794
|
+
headers["X-Signature"] = signature
|
|
795
|
+
|
|
796
|
+
def _build_signed_headers(
    self,
    body: str | bytes | None = None,
    extra: dict[str, str] | None = None,
) -> dict[str, str]:
    """Return the canonical header dict for a signed NullRun request.

    Round 2 (Phase 0.4.0) one-call helper whose contract mirrors what
    ``tests/test_hmac_signing.py`` expects.

    Headers are applied in this order (later entries win on collision):
      1. ``Content-Type: application/json`` and ``X-API-Version``;
      2. ``X-API-Key`` when ``api_key`` is set;
      3. ``X-Signature-Timestamp`` / ``X-Signature`` when *body* is given
         and both ``secret_key`` and ``api_key`` are set. A non-``str``
         body is decoded as UTF-8 before signing;
      4. every entry of *extra*, so callers can override the defaults;
      5. W3C trace-context headers from ``_inject_trace_context``.
    """
    signed: dict[str, str] = {
        "Content-Type": "application/json",
        "X-API-Version": __api_version__,
    }
    if self.api_key:
        signed["X-API-Key"] = self.api_key

    can_sign = body is not None and self.secret_key and self.api_key
    if can_sign:
        text = body if isinstance(body, str) else body.decode("utf-8")
        self._add_hmac_headers(signed, text)

    if extra:
        signed.update(extra)
    self._inject_trace_context(signed)
    return signed
|
|
838
|
+
|
|
839
|
+
def _inject_trace_context(self, headers: dict[str, str]) -> None:
    """Merge W3C Trace Context headers from ``self._propagator`` into *headers*.

    Lets the backend correlate SDK requests with the caller's trace.
    Silently does nothing when OpenTelemetry is unavailable or no
    propagator is configured.
    """
    propagator = self._propagator if _OTEL_AVAILABLE else None
    if not propagator:
        return

    trace_headers: dict[str, str] = {}
    propagator.inject(trace_headers)
    headers.update(trace_headers)
|
|
852
|
+
|
|
853
|
+
def _extract_retry_after(self, response: httpx.Response) -> float | None:
|
|
854
|
+
"""Extract Retry-After header value as seconds.
|
|
855
|
+
|
|
856
|
+
Handles both:
|
|
857
|
+
- Integer seconds (e.g., "30")
|
|
858
|
+
- HTTP-date format (e.g., "Wed, 21 Oct 2015 07:28:00 GMT")
|
|
859
|
+
"""
|
|
860
|
+
retry_after = response.headers.get("Retry-After")
|
|
861
|
+
if not retry_after:
|
|
862
|
+
return None
|
|
863
|
+
|
|
864
|
+
# Try parsing as seconds (integer or float)
|
|
865
|
+
try:
|
|
866
|
+
return float(retry_after)
|
|
867
|
+
except ValueError:
|
|
868
|
+
pass
|
|
869
|
+
|
|
870
|
+
# Try parsing as HTTP datetime (RFC 7231)
|
|
871
|
+
try:
|
|
872
|
+
from email.utils import parsedate_to_datetime
|
|
873
|
+
dt = parsedate_to_datetime(retry_after)
|
|
874
|
+
from datetime import datetime, timezone
|
|
875
|
+
return (dt - datetime.now(timezone.utc)).total_seconds()
|
|
876
|
+
except Exception:
|
|
877
|
+
pass
|
|
878
|
+
|
|
879
|
+
return None
|
|
880
|
+
|
|
881
|
+
def _send_batch_with_retry_info(self, batch: list[dict[str, Any]]) -> 'SendResult':
    """POST *batch* to ``/api/v1/track/batch`` and return a ``SendResult``.

    Steps:
      1. Serialise ``{"events": batch}`` once and build headers with
         ``_build_signed_headers``: API key, API version, an HMAC signature
         over exactly this string, and W3C trace context. The same string
         is sent via ``content=`` so the wire bytes match the signed bytes
         (``json=`` would re-serialise it differently; see plan B6).
      2. Record retry hints on ``self`` before raising on HTTP errors.
         The ``Retry-After`` header wins; otherwise a numeric
         ``rejected.retry_after_ms`` from the JSON body is used.
         ``_last_failure_policy_limit`` is set when
         ``rejected.reason == "policy_limit"``.
      3. ``raise_for_status()`` for non-2xx responses (this covers 429).
      4. Dispatch each ``actions_taken`` entry to ``handle_action``.
         Failures there are logged, not raised.

    The body is parsed once. Non-JSON bodies and non-object JSON (e.g. a
    bare list) are treated as an empty object, so none of the lookups
    above can raise.

    Returns:
        SendResult with the server's ``accepted_event_ids`` (``[]`` when
        absent or null), ``retry_after_ms`` and ``is_policy_limit``.

    Raises:
        httpx.HTTPStatusError: via ``response.raise_for_status()``.
    """
    url = f"{self.api_url}/api/v1/track/batch"
    logger.debug(f"Sending batch of {len(batch)} events to {url}")

    body = json.dumps({"events": batch})
    headers = self._build_signed_headers(body)
    response = self._client.post(url, content=body, headers=headers)

    try:
        parsed = response.json()
    except ValueError:
        parsed = None
    data: dict[str, Any] = parsed if isinstance(parsed, dict) else {}

    # P0: Extract retry_after from response headers or body
    retry_after_seconds = self._extract_retry_after(response)
    retry_after_ms: float | None = None
    is_policy_limit = False
    rejected = data.get("rejected")
    if isinstance(rejected, dict):
        body_ms = rejected.get("retry_after_ms")
        # Ignore non-numeric hints rather than failing the division below.
        if isinstance(body_ms, (int, float)) and not isinstance(body_ms, bool):
            retry_after_ms = body_ms
        is_policy_limit = rejected.get("reason") == "policy_limit"

    # Store for next retry calculation (prefer header seconds, fallback to body ms).
    # This runs before raise_for_status, so a 429's Retry-After is kept too.
    if retry_after_seconds is not None:
        self._last_retry_after_seconds = retry_after_seconds
        retry_after_ms = retry_after_seconds * 1000
    elif retry_after_ms is not None:
        self._last_retry_after_seconds = retry_after_ms / 1000.0
    else:
        self._last_retry_after_seconds = 0.0
    self._last_failure_policy_limit = is_policy_limit

    response.raise_for_status()

    # Process actions_taken from server response
    try:
        for action in data.get("actions_taken") or []:
            action_type = action.get("type", "")
            if action_type:
                handle_action(
                    action_type,
                    action.get("workflow_id", "unknown"),
                    action.get("reason", ""),
                )
    except Exception as e:
        logger.warning(f"Failed to process actions_taken: {e}")

    # Return accepted event_ids for retry dedup
    accepted_event_ids = data.get("accepted_event_ids") or []
    logger.debug(f"Batch track: sent {len(batch)} events")
    return self.SendResult(
        accepted_event_ids=accepted_event_ids,
        retry_after_ms=retry_after_ms,
        is_policy_limit=is_policy_limit,
    )
|
|
968
|
+
|
|
969
|
+
def flush_now(self) -> None:
    """Flush buffered events synchronously on the calling thread.

    Performs the same work as one tick of the background loop, taking the
    buffer lock via ``_do_flush``.
    """
    flush = self._do_flush
    flush()
|
|
972
|
+
|
|
973
|
+
# =============================================================================
|
|
974
|
+
# Pre-execution gate: execute() and check() - Phase 1
|
|
975
|
+
# =============================================================================
|
|
976
|
+
|
|
977
|
+
def execute(
    self,
    organization_id: str,
    execution_id: str,
    trace_id: str,
    tool: str,
    input_data: dict[str, Any],
    mode: str = "auto",
    fallback_mode: str = FallbackMode.PERMISSIVE,
    operation_id: str | None = None,
    on_transport_error: Callable[[Exception], dict[str, Any]] | str | None = None,
) -> dict[str, Any]:
    """
    Pre-execution policy evaluation via unified gate endpoint.

    This is the PRIMARY enforcement point - decision is made BEFORE execution.
    Uses /api/v1/gate endpoint for unified execute + check functionality.

    Args:
        organization_id: Organization identifier
        execution_id: Execution identifier
        trace_id: Distributed trace ID
        tool: Tool to execute
        input_data: Tool input
        mode: Execution mode ("auto", "inline", "strict")
        fallback_mode: What to do if Gateway unavailable
        operation_id: Optional idempotency key (a UUID4 is generated if omitted)
        on_transport_error: How to handle ``BreakerTransportError``
            (Phase 5 #5.10 / Round 3):
              - ``"raise"``: raise ``NullRunTransportError`` (also applied
                to ``httpx.RequestError``);
              - ``"open"``: return a synthetic ``allow``;
              - ``"closed"``: return a synthetic ``block``;
              - a callable: called with the error and its return value is
                returned verbatim. The decorator's
                ``_enforce_sensitive_tool`` uses this to convert the error
                into a ``NullRunBlockedException`` (fail-CLOSED);
              - ``None``: fall through to the ``fallback_mode`` default.

    Returns:
        Dict with:
        - decision: "allow" | "block" | "flag" | "pause" | "require_approval"
        - decision_source: "gateway" | "cached" | "fallback"
        - explanation: Human-readable explanation
        - policy_version: Policy version used
        - decision_context: Context for replay (if available)

    Raises:
        NullRunTransportError: when ``on_transport_error == "raise"`` and
            the gateway is unreachable, or when already raised by the
            retry helper.
        NullRunAuthenticationError: propagated unchanged (no fallback).
        httpx.RequestError: propagated when not wrapped by the breaker and
            ``on_transport_error != "raise"``.
    """
    gate_request = {
        "organization_id": organization_id,
        "execution_id": execution_id,
        "trace_id": trace_id,
        "tool": tool,
        "input": input_data,
        "mode": mode,
        "operation_id": operation_id or str(uuid.uuid4()),
    }

    # Serialise once via the canonical-bytes helper and send exactly those
    # bytes (content=body) so the wire body matches the HMAC-signed body.
    # The canonical header helper adds API key, X-API-Version (which check()
    # also sends to this endpoint), the HMAC signature and trace context.
    body = _signed_request_body(gate_request)
    headers = self._build_signed_headers(body)

    def do_gate_request() -> httpx.Response:
        return self._client.post(
            f"{self.api_url}/api/v1/gate",
            content=body,
            headers=headers,
            timeout=5.0,
        )

    # Try Gateway with retry backoff
    try:
        response = _retry_with_backoff(
            do_gate_request,
            max_retries=2,
            base_delay=0.5,
            on_transport_error=on_transport_error,
        )

        if response.status_code == 200:
            data = response.json()
            data["decision_source"] = DecisionSource.GATEWAY
            # Cache successful decision for CACHED mode
            cache_key = self._policy_cache.make_key(
                organization_id,
                data.get("policy_version")
            )
            self._policy_cache.set(
                cache_key,
                data.get("decision", "allow"),
                data.get("policy_id"),
                data.get("policy_version")
            )
            return data  # type: ignore[no-any-return]
        elif response.status_code >= 400:
            # 4xx/5xx - don't retry, return block
            return {
                "decision": "block",
                "decision_source": DecisionSource.FALLBACK,
                "explanation": f"Gateway returned {response.status_code}",
                "policy_version": 0,
            }
        # NOTE(review): any other status (non-200 2xx, 3xx) leaves this try
        # and is treated as "gateway unavailable" by the fallback below --
        # confirm that is intended.

    except BreakerTransportError as exc:
        if on_transport_error == "raise":
            # Re-raise as a classified transport error.
            raise NullRunTransportError(
                f"Gateway unreachable on /execute: {exc}",
                source=TransportErrorSource.NETWORK_ERROR,
                endpoint="execute",
            ) from exc
        if callable(on_transport_error):
            return on_transport_error(exc)
        if on_transport_error == "open":
            return {
                "decision": "allow",
                "decision_source": TransportErrorSource.NETWORK_ERROR,
                "explanation": f"Gateway unreachable: {exc}",
                "policy_version": 0,
            }
        if on_transport_error == "closed":
            return {
                "decision": "block",
                "decision_source": TransportErrorSource.NETWORK_ERROR,
                "explanation": f"Gateway unreachable: {exc}",
                "policy_version": 0,
            }
        # None (or any other value): fall through to fallback mode.
    except NullRunTransportError:
        raise  # Already classified -- propagate as-is
    except httpx.RequestError as exc:
        # Round 3: classify httpx network errors at the call site.
        if on_transport_error == "raise":
            raise NullRunTransportError(
                f"Network error on /execute: {exc}",
                source=TransportErrorSource.NETWORK_ERROR,
                endpoint="execute",
            ) from exc
        raise
    except NullRunAuthenticationError:
        raise  # Don't fall back on auth errors

    # All attempts failed - apply fallback mode.
    # Sprint 3 follow-up (B24): bump ``fallback_mode_activations`` every time
    # we reach this branch (gateway unreachable). The operator alerts on a
    # spike here as a proxy for backend unavailability.
    metrics.inc_transport("fallback_mode_activations")
    if fallback_mode == FallbackMode.STRICT:
        return {
            "decision": "block",
            "decision_source": DecisionSource.FALLBACK,
            "explanation": "Gateway unavailable, fallback=STRICT",
            "policy_version": 0,
        }
    elif fallback_mode == FallbackMode.CACHED:
        # Use cached decision if available.
        # NOTE(review): entries are stored under
        # make_key(organization_id, policy_version) but looked up with
        # make_key(organization_id); confirm PolicyCache.make_key ignores the
        # version, otherwise this lookup can never hit.
        cache_key = self._policy_cache.make_key(organization_id)
        cached = self._policy_cache.get(cache_key)
        if cached:
            logger.warning("Gateway unreachable, using cached decision for %s", tool)
            return {
                "decision": cached.decision,
                "decision_source": DecisionSource.CACHED,
                "explanation": "Gateway unavailable, using cached decision",
                "policy_version": cached.policy_version or 0,
            }
        else:
            logger.warning(
                "Gateway unreachable, no cache for %s, "
                "falling back to PERMISSIVE",
                tool
            )
            return {
                "decision": "allow",
                "decision_source": DecisionSource.FALLBACK,
                "explanation": "Gateway unavailable, no cache available",
                "policy_version": 0,
            }
    else:  # PERMISSIVE (default)
        return {
            "decision": "allow",
            "decision_source": DecisionSource.FALLBACK,
            "explanation": "Gateway unavailable, fallback=PERMISSIVE",
            "policy_version": 0,
        }
|
|
1175
|
+
|
|
1176
|
+
def check(
    self,
    check_request: dict[str, Any],
    on_transport_error: Callable[[Exception], dict[str, Any]] | str | None = None,
) -> dict[str, Any]:
    """Ask ``/api/v1/gate`` for a pre-execution budget decision.

    The request goes through the unified gate endpoint with ``check_type``
    set for budget validation. ``operation_id`` provides idempotency; a
    UUID4 is generated when the caller does not supply one.

    Args:
        check_request: Dict with:
            - organization_id: Organization identifier
            - execution_id: Execution identifier
            - operation_id: Operation identifier (for idempotency)
            - check_type: "llm" or "tool"
            - model: Model name (for LLM checks)
            - tool_name: Tool name (for tool checks); ``tool`` is also accepted
            - estimated_tokens: Token count (for LLM checks)
            - input: Optional input data
        on_transport_error: Only the value ``"raise"`` changes behaviour
            here. With it, network errors and 5xx responses raise
            ``NullRunTransportError``; without it they become a synthetic
            ``block``. 4xx responses are always a synthetic ``block``.

    Returns:
        Dict with:
        - decision: "allow" | "block" | "throttle"
        - reservation_id: Optional reservation ID
        - remaining_budget_cents: Remaining budget
        - projected_cost_cents: Projected cost for this operation
        - explanations: List of explanation strings
        - suggestions: List of suggestion strings
    """
    field = check_request.get
    gate_request = {
        "organization_id": field("organization_id"),
        "execution_id": field("execution_id"),
        "trace_id": field("trace_id", str(uuid.uuid4())),
        "tool": field("tool_name") or field("tool"),
        "input": field("input"),
        "mode": "auto",
        "check_type": field("check_type"),
        "model": field("model"),
        "estimated_tokens": field("estimated_tokens"),
        "operation_id": field("operation_id") or str(uuid.uuid4()),
    }

    headers: dict[str, str] = {"Content-Type": "application/json"}
    if self.api_key:
        headers["X-API-Key"] = self.api_key
    headers["X-API-Version"] = __api_version__

    # Sign and send the same canonical bytes (content=body) so the wire
    # body matches the HMAC-signed body, then add W3C trace context.
    body = _signed_request_body(gate_request)
    self._add_hmac_headers(headers, body.decode("utf-8"))
    self._inject_trace_context(headers)

    strict = on_transport_error == "raise"

    def synthetic_block(reason: str) -> dict[str, Any]:
        # Legacy fail-closed response shape shared by every error path.
        return {
            "decision": "block",
            "decision_source": DecisionSource.FALLBACK,
            "reservation_id": None,
            "remaining_budget_cents": 0,
            "projected_cost_cents": 0,
            "explanations": [reason],
            "suggestions": ["Check API availability"],
        }

    try:
        response = self._client.post(
            f"{self.api_url}/api/v1/gate",
            content=body,
            headers=headers,
            timeout=5.0,
        )
    except httpx.RequestError as e:
        if strict:
            raise NullRunTransportError(
                f"Network error on /check: {e}",
                source=TransportErrorSource.NETWORK_ERROR,
                endpoint="check",
            ) from e
        logger.warning(f"Gate request failed: {e}")
        return synthetic_block(f"Gate request failed: {e}")

    status = response.status_code
    if status == 200:
        return response.json()  # type: ignore[no-any-return]
    if strict and status >= 500:
        raise NullRunTransportError(
            f"Gateway returned {status}",
            source=TransportErrorSource.GATEWAY_ERROR,
            endpoint="check",
            status_code=status,
        )
    return synthetic_block(f"Gate endpoint returned {status}")
|
|
1285
|
+
|
|
1286
|
+
# =============================================================================
|
|
1287
|
+
# Policy cache + WebSocket control plane (Task 6 - WebSocket Push)
|
|
1288
|
+
# =============================================================================
|
|
1289
|
+
|
|
1290
|
+
def clear_policy_cache(self) -> None:
    """Drop every cached policy decision.

    The next gate/execute call must then fetch a fresh decision from the
    backend. Does nothing if this transport has no ``_policy_cache``.
    """
    if not hasattr(self, '_policy_cache'):
        return
    self._policy_cache._cache.clear()
    logger.debug("Policy cache cleared")
|
|
1295
|
+
|
|
1296
|
+
async def connect_websocket(
    self,
    organization_id: str,
    on_state_change: Callable[[dict[str, Any]], None] | None = None,
    on_policy_invalidated: Callable[[str, str, int], None] | None = None,
    on_key_rotated: Callable[[str, str, int], None] | None = None,
) -> "WebSocketConnection":
    """
    Connect to WebSocket control plane for real-time workflow state updates.

    This replaces polling GET /status/{workflow_id} with WebSocket push.
    When the workflow state changes (KILL/PAUSE), the server pushes the update.

    The URL is ``ws(s)://<host><api_url base path>/ws/control/<org>``. The
    scheme follows ``api_url`` (https -> wss, http -> ws), and any base path
    of ``api_url`` is kept, just as the HTTP endpoints are built as
    ``f"{api_url}/api/v1/..."``. The organization id is percent-encoded as a
    single path segment.

    Args:
        organization_id: Organization identifier
        on_state_change: Optional callback for state change notifications
        on_policy_invalidated: Optional callback for policy cache invalidation.
            Before it is called, the local policy cache is cleared, so the
            next gate/execute fetches fresh policy from the backend.
            Args: (organization_id, policy_id, new_version)
        on_key_rotated: Optional callback for HMAC key rotation.
            Called after ``_refetch_credentials`` has re-fetched
            secret_key from /auth/verify.
            Args: (organization_id, key_id, new_version)

    Returns:
        WebSocketConnection instance

    Raises:
        ValueError: If ``api_url`` is not http/https.
        ConnectionError: If WebSocket connection fails
    """
    # Phase 6 #6.6: build the WS URL via urllib.parse instead of
    # string replace. Reject unknown schemes with a clear error.
    from urllib.parse import quote, urlparse, urlunparse

    from nullrun.transport_websocket import WebSocketConnection

    parsed = urlparse(self.api_url)
    if parsed.scheme not in ("http", "https"):
        raise ValueError(
            f"Unsupported scheme for control plane: {parsed.scheme!r}"
        )
    ws_scheme = "wss" if parsed.scheme == "https" else "ws"
    base_path = parsed.path.rstrip("/")
    org_segment = quote(str(organization_id), safe="")
    ws_url = urlunparse(
        parsed._replace(
            scheme=ws_scheme,
            path=f"{base_path}/ws/control/{org_segment}",
            params="",
            query="",
            fragment="",
        )
    )

    headers = {"Content-Type": "application/json"}
    if self.api_key:
        headers["X-API-Key"] = self.api_key

    # Wrap the policy invalidated callback to clear local cache
    async def wrapped_policy_invalidated(ws_id: str, policy_id: str, new_version: int) -> None:
        logger.info(f"Policy {policy_id} invalidated (v{new_version}), clearing policy cache")
        self.clear_policy_cache()
        if on_policy_invalidated:
            on_policy_invalidated(ws_id, policy_id, new_version)

    # Wrap the key rotated callback to re-fetch credentials
    async def wrapped_key_rotated(ws_id: str, key_id: str, new_version: int) -> None:
        logger.info(f"Key {key_id} rotated (v{new_version}), re-fetching credentials")
        await self._refetch_credentials()
        if on_key_rotated:
            on_key_rotated(ws_id, key_id, new_version)

    conn = WebSocketConnection(
        url=ws_url,
        headers=headers,
        api_key=self.api_key,
        secret_key=self.secret_key,
        on_state_change=on_state_change,
        on_policy_invalidated=wrapped_policy_invalidated,
        on_key_rotated=wrapped_key_rotated,
    )
    await conn.connect()
    return conn
|
|
1376
|
+
|
|
1377
|
+
async def _refetch_credentials(self) -> None:
    """
    Re-fetch credentials from /auth/verify after key rotation.

    Called when the server notifies us via WebSocket that our HMAC
    secret_key has been rotated. On a 200 response containing
    ``secret_key``, ``self.secret_key`` is replaced. Every failure is
    logged and swallowed (best effort).

    The request goes through ``self._client``, so the shared TLS
    configuration and connection pool apply (Sprint 2.4 / B20). The body
    is serialised via ``_signed_request_body`` so the wire bytes match the
    HMAC-signed bytes.

    ``self._client`` is a synchronous client, so the POST runs in a worker
    thread via ``asyncio.to_thread``. Calling it directly would block the
    event loop, and the WebSocket connection it serves, for up to the
    10 s timeout.
    """
    import asyncio

    try:
        payload = {"api_key": self.api_key}
        body = _signed_request_body(payload)
        headers: dict[str, str] = {
            "Content-Type": "application/json",
            "X-API-Key": self.api_key or "",
        }
        # Re-use the same HMAC headers as /gate and /track so
        # the server's auth-verify path is consistent.
        self._add_hmac_headers(headers, body.decode("utf-8"))

        response = await asyncio.to_thread(
            self._client.post,
            f"{self.api_url}/auth/verify",
            content=body,
            headers=headers,
            timeout=10.0,
        )
        if response.status_code == 200:
            data = response.json()
            new_secret = data.get("secret_key")
            if new_secret:
                logger.info("Successfully fetched new secret_key from /auth/verify")
                self.secret_key = new_secret
            else:
                logger.warning("/auth/verify did not return secret_key in response")
        else:
            logger.warning(f"Failed to refetch credentials: {response.status_code}")
    except Exception as e:
        logger.error(f"Error refetching credentials: {e}")
|
|
1428
|
+
|
|
1429
|
+
|
|
1430
|
+
def _parse_error_envelope(
    response: httpx.Response,
    endpoint: str,
) -> Exception:
    """Translate a non-2xx ``httpx.Response`` into the right exception
    subclass per the canonical ``contracts/errors.ts`` envelope.

    The exception is *returned*, not raised, so the caller decides whether
    and where to raise it. Mapping by status code:

    * 401 / 403            -> ``NullRunAuthenticationError``
    * 429                  -> ``RateLimitError`` (with ``retry_after`` in
      seconds when a usable ``Retry-After`` header is present)
    * 5xx                  -> ``NullRunTransportError`` ("Gateway error")
    * anything else (4xx)  -> ``NullRunTransportError`` ("Client error")

    Callers can therefore branch on type instead of string-matching
    ``str(exc)``.

    Module-level helper (not a Transport method) so it can be called
    from background threads that do not carry a Transport instance.

    Args:
        response: The non-2xx response to translate.
        endpoint: Endpoint label embedded in the message and exception.

    Returns:
        An exception instance describing the failure.
    """
    status = response.status_code
    try:
        body = response.json()
    except Exception:
        # Non-JSON (or empty) error bodies are common for proxies/LBs.
        body = None
    if not isinstance(body, dict):
        body = {}
    error_slug: str = body.get("error", "") or ""
    message: str = (
        body.get("message")
        or response.text
        or f"HTTP {status}"
    )

    if status in (401, 403):
        return NullRunAuthenticationError(
            f"Auth failed on {endpoint} (status {status}, "
            f"error={error_slug!r}): {message}"
        )

    if status == 429:
        import math

        retry_after: float | None = None
        ra_header = response.headers.get("Retry-After")
        if ra_header:
            ra_header = ra_header.strip()
            try:
                # Retry-After as delay-seconds.
                retry_after = float(ra_header)
            except ValueError:
                # Retry-After may also be an HTTP-date.
                try:
                    from datetime import datetime, timezone
                    from email.utils import parsedate_to_datetime

                    dt = parsedate_to_datetime(ra_header)
                    if dt.tzinfo is None:
                        # A "-0000" zone parses as a naive datetime; HTTP
                        # dates are GMT, so pin it to UTC rather than
                        # failing the aware-minus-naive subtraction below.
                        dt = dt.replace(tzinfo=timezone.utc)
                    retry_after = (
                        dt - datetime.now(timezone.utc)
                    ).total_seconds()
                except Exception:
                    retry_after = None
            if retry_after is not None:
                if not math.isfinite(retry_after):
                    # "nan" / "inf" / overflowing literals are unusable.
                    retry_after = None
                else:
                    # A date in the past (clock skew) or a negative value
                    # means "retry now"; never hand callers a negative
                    # delay (time.sleep rejects it with ValueError).
                    retry_after = max(0.0, retry_after)
        # ``body`` is always a dict at this point (normalised above).
        upgrade_url = body.get("upgrade_url")
        return RateLimitError(
            f"Rate limited on {endpoint} (status 429, error={error_slug!r}): "
            f"{message}",
            source=TransportErrorSource.GATEWAY_ERROR,
            endpoint=endpoint,
            retry_after=retry_after,
            upgrade_url=upgrade_url,
            body=body,
        )

    if 500 <= status < 600:
        return NullRunTransportError(
            f"Gateway error on {endpoint} (status {status}, "
            f"error={error_slug!r}): {message}",
            source=TransportErrorSource.GATEWAY_ERROR,
            endpoint=endpoint,
            status_code=status,
            error_slug=error_slug,
        )

    return NullRunTransportError(
        f"Client error on {endpoint} (status {status}, "
        f"error={error_slug!r}): {message}",
        source=TransportErrorSource.GATEWAY_ERROR,
        endpoint=endpoint,
        status_code=status,
        error_slug=error_slug,
    )
|
|
1509
|
+
|