nullrun 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nullrun/transport.py ADDED
@@ -0,0 +1,1509 @@
1
+ """
2
+ Transport layer for NullRun SDK.
3
+
4
+ Handles HTTP communication with batching and background flush.
5
+ Includes fallback modes for Gateway unavailability.
6
+ """
7
+
8
+ import hashlib
9
+ import hmac
10
+ import json
11
+ import logging
12
+ import os
13
+ import random
14
+ import threading
15
+ import time
16
+ import uuid
17
+ import weakref
18
+ from collections import OrderedDict
19
+ from collections.abc import Callable
20
+ from dataclasses import dataclass
21
+ from typing import Any
22
+
23
+ import httpx
24
+
25
+ from nullrun.actions import handle_action
26
+ from nullrun.breaker.circuit_breaker import CircuitBreaker
27
+ from nullrun.breaker.exceptions import (
28
+ BreakerTransportError,
29
+ InsecureTransportError,
30
+ NullRunAuthenticationError,
31
+ NullRunTransportError,
32
+ RateLimitError,
33
+ TransportErrorSource,
34
+ )
35
+ from nullrun.observability import metrics
36
+
37
+ # OpenTelemetry imports (lazy-loaded to support optional dependency)
38
+ try:
39
+ from opentelemetry import trace
40
+ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
41
+ _OTEL_AVAILABLE = True
42
+ except ImportError:
43
+ _OTEL_AVAILABLE = False
44
+ trace = None # type: ignore[assignment]
45
+ TraceContextTextMapPropagator = None # type: ignore[assignment]
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
+
51
+
52
+
53
+ __api_version__ = "1.0"
54
+
55
+
56
+ # =============================================================================
57
+ # HMAC Request Signing (Task 11)
58
+ # =============================================================================
59
+
60
def generate_hmac_signature(
    api_key: str,
    secret_key: str,
    timestamp: int,
    body: str,
) -> str:
    """
    Compute the hex-encoded HMAC-SHA256 signature for a request.

    The signed message is ``"<timestamp>:<api_key>:<body_hash>"`` where
    ``body_hash`` is the hex SHA-256 digest of the UTF-8 encoded body.
    The HMAC key is the UTF-8 encoded ``secret_key``.

    Args:
        api_key: Client's API key (identifier, part of the signed message)
        secret_key: Client's secret key (HMAC key)
        timestamp: Unix timestamp in seconds (part of the signed message)
        body: Request body exactly as it will be sent

    Returns:
        64-character lowercase hex digest
    """
    # Hash the body first so the signed message has a fixed, short length
    # regardless of payload size.
    digest_of_body = hashlib.sha256(body.encode("utf-8")).hexdigest()
    signed_message = f"{timestamp}:{api_key}:{digest_of_body}"

    mac = hmac.new(
        secret_key.encode("utf-8"),
        signed_message.encode("utf-8"),
        hashlib.sha256,
    )
    return mac.hexdigest()
96
+
97
+
98
def verify_hmac_signature(
    api_key: str,
    secret_key: str,
    timestamp: int,
    body: str,
    signature: str,
    max_age_seconds: int = 300,
) -> bool:
    """
    Verify an HMAC signature and the freshness of its timestamp.

    Args:
        api_key: Client's API key
        secret_key: Client's secret key
        timestamp: Unix timestamp from request
        body: Request body as JSON string
        signature: HMAC signature to verify (hex string supplied by the peer)
        max_age_seconds: Maximum allowed clock distance, in either
            direction, between ``timestamp`` and now (default 5 min)

    Returns:
        True if the signature matches and the timestamp is within the
        window; False otherwise. Malformed signatures (non-string or
        containing non-ASCII characters) yield False rather than raising.
    """
    # Freshness check: abs() rejects timestamps too far in the past AND
    # too far in the future.
    current_time = int(time.time())
    if abs(current_time - timestamp) > max_age_seconds:
        logger.warning(
            "Request timestamp too old: %s vs current %s", timestamp, current_time
        )
        return False

    # The signature is untrusted input; anything that is not a string can
    # never match the hex digest we compute.
    if not isinstance(signature, str):
        return False

    expected = generate_hmac_signature(api_key, secret_key, timestamp, body)

    # Constant-time comparison to prevent timing attacks. We compare bytes:
    # hmac.compare_digest raises TypeError for str arguments containing
    # non-ASCII characters, which would turn a bad signature into a crash.
    # errors="replace" maps unencodable characters to "?", which can never
    # equal a hex digit, so the comparison still fails safely.
    return hmac.compare_digest(
        expected.encode("utf-8"),
        signature.encode("utf-8", "replace"),
    )
131
+
132
+
133
+ # =============================================================================
134
+ # Policy Cache for CACHED fallback mode
135
+ # =============================================================================
136
+
137
+ class CachedDecision:
138
+ """Represents a cached execute decision."""
139
+
140
+ def __init__(
141
+ self,
142
+ decision: str,
143
+ policy_id: str | None = None,
144
+ ttl_seconds: float = 300.0,
145
+ policy_version: int | None = None,
146
+ ):
147
+ self.decision = decision
148
+ self.policy_id = policy_id
149
+ self.cached_at = time.monotonic()
150
+ self.ttl_seconds = ttl_seconds
151
+ # Phase 5 #5.2: dedicated field, not a `ttl_seconds` repurpose.
152
+ self.policy_version = policy_version
153
+
154
+ def is_expired(self) -> bool:
155
+ return time.monotonic() - self.cached_at > self.ttl_seconds
156
+
157
+
158
+ class PolicyCache:
159
+ """
160
+ LRU cache for execute decisions. Used in CACHED fallback mode.
161
+
162
+ Cache key is (organization_id, policy_version) to prevent cache thrashing.
163
+ At 1000+ users with unique workflow_ids, keying by tool caused constant eviction.
164
+ Now we key by organization + policy version, so all tools in an organization share
165
+ the same policy cached entry until the policy version changes.
166
+ """
167
+
168
+ def __init__(self, maxsize: int = 1000, ttl_seconds: float = 300.0):
169
+ self._cache: OrderedDict[str, CachedDecision] = OrderedDict()
170
+ self._maxsize = maxsize
171
+ self._ttl = ttl_seconds
172
+ self._hits = 0
173
+ self._misses = 0
174
+
175
+ def get(self, key: str) -> CachedDecision | None:
176
+ decision = self._cache.get(key)
177
+ if decision is None:
178
+ self._misses += 1
179
+ return None
180
+ if decision.is_expired():
181
+ del self._cache[key]
182
+ self._misses += 1
183
+ return None
184
+ self._cache.move_to_end(key)
185
+ self._hits += 1
186
+ return decision
187
+
188
+ def set(self, key: str, decision: str, policy_id: str = None, policy_version: int = None) -> None:
189
+ if key in self._cache:
190
+ self._cache.move_to_end(key)
191
+ elif len(self._cache) >= self._maxsize:
192
+ self._cache.popitem(last=False)
193
+ # Phase 5 #5.2: pass policy_version as a dedicated field.
194
+ # The previous implementation wrote it into ttl_seconds, which
195
+ # corrupted the cache-lifetime check (see plan #5.2).
196
+ self._cache[key] = CachedDecision(
197
+ decision=decision,
198
+ policy_id=policy_id,
199
+ ttl_seconds=self._ttl,
200
+ policy_version=policy_version,
201
+ )
202
+
203
+ def make_key(self, organization_id: str, policy_version: int = None) -> str:
204
+ """Generate cache key from organization_id and policy_version."""
205
+ if policy_version is not None:
206
+ return f"{organization_id}:{policy_version}"
207
+ return f"{organization_id}:0" # Default to version 0 if not provided
208
+
209
+ def get_stats(self) -> dict:
210
+ """Get cache statistics for observability."""
211
+ total = self._hits + self._misses
212
+ hit_rate = self._hits / total if total > 0 else 0.0
213
+ return {
214
+ "size": len(self._cache),
215
+ "hits": self._hits,
216
+ "misses": self._misses,
217
+ "hit_rate": hit_rate,
218
+ }
219
+
220
+ def __len__(self) -> int:
221
+ return len(self._cache)
222
+
223
+
224
+
225
+
226
+ def _signed_request_body(payload: dict[str, Any]) -> bytes:
227
+ """Serialise a JSON payload to the canonical bytes the HMAC
228
+ signature is computed over.
229
+
230
+ All three signed POST call sites (``_send_batch_with_retry_info``,
231
+ ``Transport.execute``, ``Transport.check``) MUST serialise via this
232
+ helper and pass the result with ``content=body`` to
233
+ ``httpx.Client.post``. Sending via ``json=...`` lets httpx
234
+ re-serialise with its default compact separators, which produces
235
+ a body that does NOT match the body the HMAC signature was
236
+ computed over. The Rust server at
237
+ ``backend/src/auth/hmac.rs:466-518`` is strict -- it recomputes
238
+ ``sha256(body)`` from the raw wire bytes and rejects with 401
239
+ on mismatch.
240
+ """
241
+ return json.dumps(payload, separators=(",", ":")).encode("utf-8")
242
+
243
+ # =============================================================================
244
+ # Retry with exponential backoff + jitter
245
+ # =============================================================================
246
+
247
+ """
248
+ Retry with exponential backoff + jitter + Retry-After header support
249
+ """
250
+
251
def _retry_with_backoff(
    func: Callable[[], Any],
    max_retries: int = 3,
    base_delay: float = 0.5,
    max_delay: float = 30.0,
    backoff_factor: float = 2.0,
    jitter: float = 0.1,
    last_retry_after_seconds: float = 0.0,
    on_transport_error: str | Callable[[Exception], dict[str, Any]] | None = None,
) -> Any:
    """
    Retry with exponential backoff and jitter, honoring Retry-After header.

    When Retry-After is provided (from backend 429 response), use it directly
    instead of exponential backoff to prevent retry storms.

    Formula (without Retry-After): delay = min(base_delay * backoff_factor^attempt, max_delay)
    delay += random.uniform(-jitter * delay, jitter * delay)
    Formula (with Retry-After): actual_delay = min(last_retry_after_seconds, max_delay)

    Args:
        func: Zero-argument callable performing the request. If its result
            has a ``status_code`` attribute it is inspected as described below.
        max_retries: Number of retries after the first attempt, so ``func``
            is called at most ``max_retries + 1`` times.
        base_delay: Delay in seconds before the first retry.
        max_delay: Upper bound in seconds for the un-jittered delay and for
            a honoured Retry-After value.
        backoff_factor: Multiplier applied per attempt.
        jitter: Fraction of the delay used as the +/- random spread.
        last_retry_after_seconds: Retry-After value supplied by the caller.
            It is used for the first retry only and then reset to 0, so
            later retries fall back to exponential backoff.
        on_transport_error: Only the literal value ``"raise"`` is examined
            here; it turns a 5xx result into an immediate
            ``NullRunTransportError`` instead of a retry.

    Returns:
        Whatever ``func`` returned on the first attempt that did not fail.

    Raises:
        NullRunAuthenticationError: The result had status code 401.
        NullRunTransportError: The result had status >= 500 and
            ``on_transport_error == "raise"``.
        BreakerTransportError: Every attempt failed; chained to the last
            exception via ``from``. Also re-raised untouched when ``func``
            itself raises it.
    """
    last_exc: Exception | None = None

    for attempt in range(max_retries + 1):
        try:
            result = func()

            if hasattr(result, "status_code"):
                if result.status_code == 401:
                    raise NullRunAuthenticationError("Invalid API key")
                if result.status_code >= 500 and on_transport_error == "raise":
                    # Round 3 (Phase 0.4.0): 5xx is a classified
                    # GATEWAY_ERROR. Don't retry -- this is a server
                    # bug, not a network blip. Only raise when the
                    # caller has opted into the typed-error contract
                    # via on_transport_error="raise".
                    # NOTE(review): endpoint is hard-coded to "execute"
                    # regardless of which call site invoked this helper.
                    raise NullRunTransportError(
                        f"Gateway returned {result.status_code}",
                        source=TransportErrorSource.GATEWAY_ERROR,
                        endpoint="execute",
                        status_code=result.status_code,
                    )
                if result.status_code >= 400:
                    # Any exception raised here falls into the generic
                    # handler below, so remaining 4xx/5xx results are
                    # retried like network failures.
                    result.raise_for_status()

            return result

        except (BreakerTransportError, NullRunAuthenticationError, NullRunTransportError):
            # Typed SDK errors are final: never retried, never wrapped.
            raise

        except Exception as exc:
            last_exc = exc
            # Sprint 3 follow-up (B24): bump ``last_error`` so the
            # operator can read the most recent failure type without
            # grepping logs. The string is the exception class
            # name plus the message — short, searchable, and
            # doesn't leak request bodies.
            metrics.set_transport("last_error", f"{type(exc).__name__}: {exc}")
            # ``timeouts`` is a specific subcategory of retry
            # trigger — distinguished so an SRE can alert on
            # ``timeouts > N per minute`` separately from
            # generic 5xx retries.
            if isinstance(exc, (httpx.TimeoutException, httpx.ConnectTimeout, httpx.ReadTimeout)):
                metrics.inc_transport("timeouts")

            # Out of attempts: leave the loop and raise below.
            if attempt >= max_retries:
                break

            # Bump ``retries_total`` for every retry attempt
            # (not for the final failure). The counter is
            # distinct from the final BreakerTransportError —
            # it measures how often the SDK had to retry
            # because the backend was flaky.
            metrics.inc_transport("retries_total")

            # Honor Retry-After from backend if present (from 429 response)
            if last_retry_after_seconds > 0:
                actual_delay = min(last_retry_after_seconds, max_delay)
                # Reset after use so next retry uses exponential backoff
                last_retry_after_seconds = 0.0
                logger.warning(
                    "Request failed (attempt %d/%d), honoring Retry-After %.2fs: %s",
                    attempt + 1,
                    max_retries + 1,
                    actual_delay,
                    type(exc).__name__,
                )
            else:
                delay = min(base_delay * (backoff_factor ** attempt), max_delay)
                jitter_amount = delay * jitter
                # Standard jitter for retry delay -- not crypto-sensitive
                actual_delay = delay + random.uniform(-jitter_amount, jitter_amount)  # noqa: S311
                # Jitter is applied after the max_delay cap, so the sleep
                # can exceed max_delay by up to jitter_amount; it is only
                # clamped from below.
                actual_delay = max(0.0, actual_delay)
                logger.warning(
                    "Request failed (attempt %d/%d), retrying in %.2fs: %s",
                    attempt + 1,
                    max_retries + 1,
                    actual_delay,
                    type(exc).__name__,
                )

            time.sleep(actual_delay)

    raise BreakerTransportError(
        f"Request failed after {max_retries + 1} attempts"
    ) from last_exc
356
+
357
+ # =============================================================================
358
+ # Fallback Modes (Phase 1 - SDK Resilience)
359
+ # =============================================================================
360
+
361
class FallbackMode:
    """
    String constants naming how the SDK behaves when the Gateway is
    unavailable.

    Gateway unavailability should not block agent execution by accident;
    the chosen behavior must be explicit and logged.
    """
    # Refuse to proceed while the Gateway is unreachable (critical tools).
    STRICT = "strict"
    # Proceed and log locally while the Gateway is unreachable (DEFAULT).
    PERMISSIVE = "permissive"
    # Reuse a previously cached decision while the Gateway is unreachable.
    CACHED = "cached"
374
+
375
+
376
class DecisionSource:
    """
    String constants recording where a decision came from, for
    provenance tracking.
    """
    GATEWAY = "gateway"
    CACHED = "cached"
    FALLBACK = "fallback"
    LOCAL = "local"
384
+
385
+
386
@dataclass
class FlushConfig:
    """Tunable settings for buffering and flushing tracked events."""
    # Buffer length at which a flush is triggered.
    batch_size: int = 50
    # Seconds between periodic background flushes.
    flush_interval: float = 5.0
    max_retries: int = 3
    # Seconds.
    retry_delay: float = 1.0
    # Max events kept in the buffer before the oldest are dropped.
    max_buffer_size: int = 1000
    # Circuit breaker: stop trying after this many failures.
    max_failed_flush: int = 10
395
+
396
+
397
@dataclass
class ExecuteConfig:
    """Tunable settings for execute (strict mode) calls."""
    # One of the FallbackMode constants; applies when the Gateway is unavailable.
    fallback_mode: str = FallbackMode.PERMISSIVE
    # Seconds to wait for the Gateway.
    timeout: float = 5.0
    # Retries allowed for a single execute call.
    max_retries: int = 2
    # Lifetime in seconds of cached decisions (CACHED mode).
    cache_ttl: float = 60.0
    # Maximum number of cached decisions.
    cache_max_size: int = 10000
410
+
411
+
412
+ class Transport:
413
+ """
414
+ HTTP transport with batching support.
415
+
416
+ Features:
417
+ - Non-blocking track() calls (append to buffer)
418
+ - Background flush at intervals or when batch_size reached
419
+ - Retry logic for failed requests
420
+ - Thread-safe for sync usage
421
+ - HMAC request signing for secure authentication
422
+ - Distributed circuit breaker via Redis for multi-worker deployments
423
+ """
424
+
425
def __init__(
    self,
    api_url: str,
    api_key: str | None = None,
    secret_key: str | None = None,
    config: FlushConfig | None = None,
    redis_client: Any = None,
):
    """Create a transport bound to ``api_url``.

    Args:
        api_url: Base URL of the API; a trailing slash is stripped.
        api_key: API key sent as ``X-API-Key`` when set.
        secret_key: HMAC signing key; requests are signed only when both
            ``api_key`` and ``secret_key`` are set.
        config: Flush settings. Defaults to ``FlushConfig()``. Note that
            valid ``NULLRUN_BATCH_SIZE`` / ``NULLRUN_FLUSH_INTERVAL_MS``
            environment overrides are written onto this object.
        redis_client: Passed through to the circuit breaker.

    Raises:
        InsecureTransportError: ``api_url`` uses plain ``http`` for a host
            that is not localhost or a loopback address.
    """
    self.api_url = api_url.rstrip("/")

    # TLS enforcement: reject non-localhost HTTP URLs. The check
    # must NOT be a startswith chain — that allowed homograph
    # attacks (http://127.0.0.1.attacker.com, http://localhost.evil.com)
    # and rejected legitimate inputs (http://[::1]:8080, http://LOCALHOST).
    # We use urllib.parse.urlparse to extract the canonical hostname,
    # then check the host against a small allow-list that includes the
    # full IPv4 loopback range (127.0.0.0/8) and IPv6 loopback (::1).
    # For IPv4 we use ``ipaddress.ip_address`` so that
    # ``127.0.0.1.attacker.com`` (a string that happens to start
    # with "127.") is NOT mistakenly treated as a loopback IP.
    from ipaddress import ip_address
    from urllib.parse import urlparse

    parsed = urlparse(self.api_url)
    if parsed.scheme == "http":
        host = (parsed.hostname or "").lower()
        allowed = host == "localhost" or host == "::1"
        if not allowed:
            try:
                addr = ip_address(host)
                allowed = addr.is_loopback
            except ValueError:
                allowed = False
        if not allowed:
            raise InsecureTransportError(
                f"Insecure URL detected: {self.api_url}. "
                f"HTTP is only allowed for localhost / 127.0.0.0/8 / ::1. "
                f"Use https:// for production."
            )

    self.api_key = api_key
    self.secret_key = secret_key  # HMAC signing key
    self.config = config or FlushConfig()
    # Phase 8 #8.4: allow env-var override of batch size and
    # flush interval. Useful for tuning high-throughput agents
    # without subclassing. Values must be positive: a negative flush
    # interval would make time.sleep() raise inside the flush thread
    # and a zero interval would busy-loop it.
    raw_batch_size = os.environ.get("NULLRUN_BATCH_SIZE")
    if raw_batch_size is not None:
        try:
            batch_size = int(raw_batch_size)
        except ValueError:
            logger.warning(
                "NULLRUN_BATCH_SIZE=%r is not an int; ignoring",
                raw_batch_size,
            )
        else:
            if batch_size >= 1:
                self.config.batch_size = batch_size
            else:
                logger.warning(
                    "NULLRUN_BATCH_SIZE=%r must be >= 1; ignoring",
                    raw_batch_size,
                )
    raw_interval_ms = os.environ.get("NULLRUN_FLUSH_INTERVAL_MS")
    if raw_interval_ms is not None:
        try:
            interval_ms = int(raw_interval_ms)
        except ValueError:
            logger.warning(
                "NULLRUN_FLUSH_INTERVAL_MS=%r is not an int; ignoring",
                raw_interval_ms,
            )
        else:
            if interval_ms > 0:
                self.config.flush_interval = interval_ms / 1000.0
            else:
                logger.warning(
                    "NULLRUN_FLUSH_INTERVAL_MS=%r must be > 0; ignoring",
                    raw_interval_ms,
                )
    self._buffer: list[dict[str, Any]] = []
    self._in_flight: dict[str, dict[str, Any]] = {}  # event_id -> event for retry dedup
    # RLock so re-entrant acquisition (e.g. test fixtures that hold the
    # lock while calling lock-acquiring methods) doesn't deadlock.
    self._lock = threading.RLock()
    self._flush_thread: threading.Thread | None = None
    self._running = False

    # mTLS client certificate support
    # NULLRUN_TLS_CLIENT_CERT and NULLRUN_TLS_CLIENT_KEY env vars for client cert auth
    client_cert_path = os.environ.get("NULLRUN_TLS_CLIENT_CERT")
    client_key_path = os.environ.get("NULLRUN_TLS_CLIENT_KEY")
    ca_cert_path = os.environ.get("NULLRUN_TLS_CA_CERT")  # Optional custom CA

    # Build SSL configuration for mTLS
    # For client cert auth: verify is a CA cert, cert is tuple of (client_cert, client_key)
    verify_cert: bool | str = True
    client_cert: tuple[str, str] | None = None
    if client_cert_path and client_key_path:
        # Client certificate authentication (mTLS)
        client_cert = (client_cert_path, client_key_path)
        verify_cert = ca_cert_path if ca_cert_path else True
        logger.debug(f"mTLS enabled: client_cert={client_cert_path}")
    elif ca_cert_path:
        # Custom CA certificate only (no client cert)
        verify_cert = ca_cert_path
        logger.debug(f"Custom CA configured: ca_cert={ca_cert_path}")

    self._client = httpx.Client(
        timeout=httpx.Timeout(
            connect=5.0,
            read=30.0,
            write=10.0,
            pool=5.0,
        ),
        verify=verify_cert,
        cert=client_cert,
        limits=httpx.Limits(
            max_connections=10,
            max_keepalive_connections=5,
            keepalive_expiry=30.0,
        ),
    )
    self._redis_client = redis_client
    self._circuit_breaker = CircuitBreaker(
        failure_threshold=self.config.max_failed_flush,
        recovery_timeout=30.0,
        redis_client=redis_client,
        name="transport",
    )
    self._stopped = False  # Track if stop() was called
    self._policy_cache = PolicyCache(
        maxsize=1000,
        ttl_seconds=300.0,
    )
    # Never log more than half of the key: the 8-character prefix is shown
    # only for keys of at least 16 characters. (The previous ">= 8" rule
    # printed an 8-character key in full.)
    _masked = api_key[:8] + "***" if api_key and len(api_key) >= 16 else "***"
    logger.debug(f"Transport initialized: api_url={self.api_url}, api_key={_masked}")

    # OpenTelemetry tracer initialization (lazy - only if opentelemetry is installed)
    self._tracer = None
    self._propagator = None
    if _OTEL_AVAILABLE:
        self._tracer = trace.get_tracer("nullrun.transport")
        self._propagator = TraceContextTextMapPropagator()

    # Register final-flush hook via weakref.finalize so the
    # callback only fires if this Transport instance is still
    # alive at process exit. Replaces the previous
    # ``atexit.register`` (which accumulated one handler per
    # Transport in long-running deployments) and the previous
    # ``signal.signal`` handler (which hijacked SIGTERM/SIGINT
    # process-wide and called ``sys.exit(0)`` from inside the
    # signal context). The fix contract is pinned by
    # tests/test_signal_safety.py.
    self._finalizer = weakref.finalize(self, self._atexit_flush_safe)
565
+
566
@staticmethod
def _atexit_flush_safe(_self_id: int | None = None) -> None:
    """Entry point registered with ``weakref.finalize``.

    By the time the finalizer runs the transport instance is no longer
    reachable, so nothing can be flushed from here: this function only
    emits a DEBUG reminder and returns. Callers are expected to call
    ``stop()`` themselves or to use ``Transport`` as a context manager.

    The single optional positional parameter is ignored. It exists so
    the function can be invoked both with no arguments (as
    ``weakref.finalize`` does) and as ``_atexit_flush_safe(id(t))``
    from tests.
    """
    reminder = (
        "Transport finalizer fired without explicit stop(); "
        "remaining events may be lost. Use Transport as a context "
        "manager or call stop() explicitly."
    )
    logger.debug(reminder)
590
+
591
def _persist_to_wal(self) -> None:
    """Persist unflushed events to WAL file for replay on restart.

    Appends one JSON document per line to ``.nullrun.wal`` in the current
    working directory, then empties the buffer. Does nothing when the
    buffer is empty.

    Raises:
        TypeError / ValueError: an event is not JSON-serialisable. Nothing
            is written and the buffer is left untouched in that case.
        OSError: the WAL file could not be opened or written.
    """
    # The buffer is guarded by _lock everywhere else; take it here too so
    # a concurrent track()/flush cannot interleave with the snapshot.
    with self._lock:
        if not self._buffer:
            return
        # Serialise everything BEFORE touching the file: if one event is
        # not serialisable we must not leave a partial WAL behind while
        # the buffer still holds all events (that would duplicate the
        # already-written ones on the next persist).
        lines = [json.dumps(event) + "\n" for event in self._buffer]
        wal_path = os.path.join(os.getcwd(), ".nullrun.wal")
        # json.dumps escapes non-ASCII by default, so the explicit encoding
        # only pins behavior across platforms; it does not change the bytes.
        with open(wal_path, "a", encoding="utf-8") as wal_file:
            wal_file.writelines(lines)
        self._buffer.clear()
    logger.debug(f"Persisted {len(lines)} events to WAL at {wal_path}")
602
+
603
def _replay_from_wal(self) -> None:
    """Replay events from WAL file on startup.

    Reads ``.nullrun.wal`` from the current working directory, appends
    every line that parses as a JSON object to the buffer, triggers a
    flush, and removes the file. Lines that are blank, not valid JSON, or
    not JSON objects are skipped. Does nothing when the file is absent.
    """
    wal_path = os.path.join(os.getcwd(), ".nullrun.wal")
    events: list[dict[str, Any]] = []
    try:
        # EAFP instead of os.path.exists(): avoids the window in which the
        # file disappears between the check and the open. The WAL is
        # written as ASCII-only JSON, so undecodable bytes mean corruption;
        # errors="replace" keeps one bad line from aborting the replay.
        with open(wal_path, encoding="utf-8", errors="replace") as wal_file:
            for raw_line in wal_file:
                text = raw_line.strip()
                if not text:
                    continue
                try:
                    record = json.loads(text)
                except json.JSONDecodeError:
                    continue
                # Events are dicts everywhere else in the transport; a
                # stray scalar or list line must not enter the buffer.
                if isinstance(record, dict):
                    events.append(record)
    except FileNotFoundError:
        return
    if events:
        with self._lock:
            self._buffer.extend(events)
        self._do_flush()
    # The events now live in the buffer (or were sent), so the WAL can go.
    try:
        os.remove(wal_path)
    except FileNotFoundError:
        pass
    logger.info(f"Replayed {len(events)} events from WAL")
620
+
621
def track(self, event: dict[str, Any]) -> None:
    """
    Queue an event for delivery.

    The event dict is mutated in place: a UUID4 ``event_id`` is assigned
    when the key is missing or falsy. The event is recorded in the
    in-flight map and appended to the buffer. When the buffer reaches
    ``config.batch_size`` a flush runs immediately in the calling thread
    while the lock is held; otherwise the call returns straight away.
    """
    with self._lock:
        # Assign an id when the caller did not provide a usable one.
        if not event.get("event_id"):
            event["event_id"] = str(uuid.uuid4())
        event_id = event["event_id"]

        # Remember the event by id for retry dedup.
        self._in_flight[event_id] = event

        self._buffer.append(event)
        metrics.inc_transport("events_enqueued")

        if len(self._buffer) < self.config.batch_size:
            return
        self._do_flush_locked()
641
+
642
def start(self) -> None:
    """Replay any WAL left on disk, then launch the daemon flush thread.

    Calling this while the transport is already running is a no-op.
    """
    if not self._running:
        # Pick up events persisted by a previous run before new ones arrive.
        self._replay_from_wal()
        self._running = True
        worker = threading.Thread(target=self._flush_loop, daemon=True)
        self._flush_thread = worker
        worker.start()
        logger.info("Transport flush thread started")
652
+
653
+ def __enter__(self) -> "Transport":
654
+ """Context-manager entry: start the flush thread and return self.
655
+
656
+ Pairs with ``__exit__`` so callers can write
657
+ ``with Transport(...) as t:`` and rely on ``stop()`` running
658
+ on the way out. Replaces the manual ``start() / stop()`` pair
659
+ that was easy to forget in long-running services.
660
+ """
661
+ self.start()
662
+ return self
663
+
664
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
665
+ """Context-manager exit: stop the flush thread and persist WAL.
666
+
667
+ Always stops, regardless of whether the body raised. The
668
+ exception (if any) is NOT swallowed — the caller still sees
669
+ it after the with-block.
670
+ """
671
+ try:
672
+ self.stop()
673
+ except Exception as e: # noqa: BLE001 — best-effort on context exit
674
+ logger.debug(f"Transport.__exit__: stop() raised: {e}")
675
+
676
def stop(self, timeout: float = 10.0) -> None:
    """Stop background flush thread and flush remaining events.

    Order of operations: signal the flush loop to end, wait up to
    ``timeout`` seconds for the thread, perform a final flush, write
    whatever is still buffered to the WAL, close the HTTP client and
    detach the finalizer.

    The WAL write, client close and finalizer detach run even when an
    earlier step raises, so a failed final flush cannot leak the
    connection pool or skip persisting the remaining buffer. The
    original exception still propagates to the caller.

    Args:
        timeout: Maximum seconds to wait for the flush thread to exit.
    """
    self._running = False
    self._stopped = True  # Mark as stopped to prevent double flush
    try:
        if self._flush_thread:
            self._flush_thread.join(timeout=timeout)
        self._do_flush()  # Final flush
    finally:
        try:
            self._persist_to_wal()  # WAL any remaining events
        finally:
            self._client.close()
            # Detach the weakref finalizer — stop() is the canonical
            # "I am done" path. After this point the finalizer will
            # silently no-op even if the interpreter is still alive.
            finalizer = getattr(self, "_finalizer", None)
            if finalizer is not None and finalizer.alive:
                finalizer.detach()
    logger.info("Transport stopped")
691
+
692
+ def _flush_loop(self) -> None:
693
+ """Background loop that periodically flushes."""
694
+ while self._running:
695
+ time.sleep(self.config.flush_interval)
696
+ if self._running:
697
+ self._do_flush()
698
+
699
+ def _do_flush(self) -> None:
700
+ """Perform the actual flush."""
701
+ with self._lock:
702
+ self._do_flush_locked()
703
+
704
def _do_flush_locked(self) -> None:
    """Flush under lock. Must be called with _lock held.

    Takes the whole buffer as one batch and sends it through the circuit
    breaker. On success, events the server accepted are removed from the
    in-flight map. If the breaker raises ``BreakerTransportError`` the
    batch is put back into the buffer, dropping its oldest events when it
    would not fit within ``config.max_buffer_size``; dropped events are
    also forgotten from the in-flight map since they will never be retried.
    """
    if not self._buffer:
        logger.debug("Buffer empty, skipping flush")
        return

    batch = self._buffer[:]
    self._buffer.clear()
    logger.debug(f"Sending batch of {len(batch)} events")

    # Circuit breaker wrapped send - uses proper 3-state circuit breaker
    def send_batch():
        result = self._send_batch_with_retry_info(batch)
        # Remove accepted events from in-flight
        if result.accepted_event_ids:
            # Set lookup keeps this linear in the batch size; the ids are
            # the same values used as _in_flight dict keys.
            accepted_ids = set(result.accepted_event_ids)
            for event in batch:
                event_id = event.get("event_id")
                if event_id in accepted_ids:
                    self._in_flight.pop(event_id, None)
        logger.debug(f"Flushed {len(batch)} events")
        # Update metrics on successful flush (thread-safe)
        metrics.inc_transport("batches_sent")
        metrics.inc_transport("events_sent", len(batch))
        metrics.set_transport("last_flush_at", time.monotonic())
        return result

    try:
        self._circuit_breaker.call(send_batch)
    except BreakerTransportError:
        # Circuit breaker is open - re-add batch to buffer for retry later
        logger.warning(
            f"Circuit breaker OPEN. Batch of {len(batch)} events will be re-queued."
        )
        # Enforce max buffer size BEFORE re-queue to prevent unbounded growth.
        # Clamp at zero so an already over-full buffer cannot make us count
        # more dropped events than the batch actually contains.
        available_space = max(self.config.max_buffer_size - len(self._buffer), 0)
        overflow = len(batch) - available_space
        if overflow > 0:
            # Drop the oldest events of the pending batch to make room.
            logger.warning(f"Buffer overflow on CB OPEN: dropping {overflow} oldest events from pending batch")
            dropped, batch = batch[:overflow], batch[overflow:]
            # Dropped events are gone for good; keeping them in _in_flight
            # would leak one entry per dropped event.
            for event in dropped:
                self._in_flight.pop(event.get("event_id"), None)
            metrics.inc_transport("events_dropped", overflow)
        # Re-queue what is left of the batch at the end of the buffer.
        self._buffer.extend(batch)
        # Update metrics on failure (thread-safe)
        metrics.inc_transport("batches_failed")
750
+
751
+ def _drain_batch(self) -> list[dict[str, Any]] | None:
752
+ """Round 2 (Phase 0.4.0): public, lock-acquiring snapshot of
753
+ the current buffer. Returns ``None`` when empty.
754
+
755
+ Used by ``tests/test_buffer_invariants.py``. The full flush
756
+ logic (CB, re-queue, metrics) lives in ``_do_flush_locked``;
757
+ this method is the read-only counterpart.
758
+ """
759
+ with self._lock:
760
+ if not self._buffer:
761
+ return None
762
+ batch = list(self._buffer)
763
+ del self._buffer[:]
764
+ return batch
765
+
766
+ @dataclass
767
+ class SendResult:
768
+ accepted_event_ids: list
769
+ retry_after_ms: float | None = None
770
+ is_policy_limit: bool = False
771
+
772
+ def _add_hmac_headers(self, headers: dict[str, str], body: str) -> None:
773
+ """
774
+ Add HMAC signing headers to request.
775
+
776
+ Adds:
777
+ - X-Signature-Timestamp: Unix timestamp for freshness
778
+ - X-Signature: HMAC-SHA256(api_key, secret, timestamp, body_hash)
779
+
780
+ Only adds signature if secret_key is configured.
781
+ """
782
+ if not self.secret_key or not self.api_key:
783
+ return
784
+
785
+ timestamp = int(time.time())
786
+ signature = generate_hmac_signature(
787
+ self.api_key,
788
+ self.secret_key,
789
+ timestamp,
790
+ body,
791
+ )
792
+
793
+ headers["X-Signature-Timestamp"] = str(timestamp)
794
+ headers["X-Signature"] = signature
795
+
796
def _build_signed_headers(
    self,
    body: str | bytes | None = None,
    extra: dict[str, str] | None = None,
) -> dict[str, str]:
    """Build the canonical signed-headers dict for a request.

    Round 2 (Phase 0.4.0): the canonical one-call helper used
    by every signed POST. Mirrors the contract the test
    framework in ``tests/test_hmac_signing.py`` expects.

    Always includes:
    - Content-Type: application/json
    - X-API-Version: __api_version__
    - X-API-Key: when api_key is set

    Adds HMAC signature headers when secret_key and api_key are set and
    a body is provided. Signing is delegated to ``_add_hmac_headers`` so
    there is a single implementation of the signature headers.

    ``extra`` is merged ON TOP of the defaults so callers can
    override Content-Type or add custom headers. Trace-context headers
    are injected last.

    Args:
        body: Exact request body; bytes are decoded as UTF-8 for signing.
        extra: Additional headers that take precedence over the defaults.

    Returns:
        A new headers dict.
    """
    headers: dict[str, str] = {
        "Content-Type": "application/json",
        "X-API-Version": __api_version__,
    }
    if self.api_key:
        headers["X-API-Key"] = self.api_key
    # The credential check is repeated here so an unsigned request never
    # attempts to decode its body.
    if body is not None and self.secret_key and self.api_key:
        body_str = body if isinstance(body, str) else body.decode("utf-8")
        self._add_hmac_headers(headers, body_str)
    if extra:
        headers.update(extra)
    # Inject trace context (W3C) as well — matches the
    # end-to-end behaviour of every signed POST.
    self._inject_trace_context(headers)
    return headers
838
+
839
def _inject_trace_context(self, headers: dict[str, str]) -> None:
    """
    Add trace-context headers to ``headers`` in place.

    The configured propagator fills a scratch dict, which is then merged
    into ``headers``. Nothing happens when OpenTelemetry could not be
    imported or no propagator is set. The propagator is created as a
    ``TraceContextTextMapPropagator`` (W3C Trace Context) in ``__init__``.
    """
    # Only touch self._propagator when OpenTelemetry was importable.
    propagator = self._propagator if _OTEL_AVAILABLE else None
    if propagator:
        trace_headers: dict[str, str] = {}
        propagator.inject(trace_headers)
        headers.update(trace_headers)
852
+
853
+ def _extract_retry_after(self, response: httpx.Response) -> float | None:
854
+ """Extract Retry-After header value as seconds.
855
+
856
+ Handles both:
857
+ - Integer seconds (e.g., "30")
858
+ - HTTP-date format (e.g., "Wed, 21 Oct 2015 07:28:00 GMT")
859
+ """
860
+ retry_after = response.headers.get("Retry-After")
861
+ if not retry_after:
862
+ return None
863
+
864
+ # Try parsing as seconds (integer or float)
865
+ try:
866
+ return float(retry_after)
867
+ except ValueError:
868
+ pass
869
+
870
+ # Try parsing as HTTP datetime (RFC 7231)
871
+ try:
872
+ from email.utils import parsedate_to_datetime
873
+ dt = parsedate_to_datetime(retry_after)
874
+ from datetime import datetime, timezone
875
+ return (dt - datetime.now(timezone.utc)).total_seconds()
876
+ except Exception:
877
+ pass
878
+
879
+ return None
880
+
881
def _send_batch_with_retry_info(self, batch: list[dict[str, Any]]) -> 'SendResult':
    """Send ``batch`` to the batch endpoint and return a ``SendResult``.

    Side effects, applied before any HTTP error is raised so the retry
    logic can read them:
    - ``self._last_retry_after_seconds``: the ``Retry-After`` header if
      present, else ``rejected.retry_after_ms`` from the body, else 0.0
    - ``self._last_failure_policy_limit``: True when the body reports
      ``rejected.reason == "policy_limit"``

    Args:
        batch: Events to send; serialised as ``{"events": batch}``.

    Returns:
        ``self.SendResult`` carrying the accepted event ids, the retry
        delay in milliseconds (or None) and the policy-limit flag.

    Raises:
        httpx.HTTPStatusError: For any 4xx/5xx response (including 429).
    """
    logger.debug(f"Sending batch of {len(batch)} events to {self.api_url}/api/v1/track/batch")
    headers = {"Content-Type": "application/json", "X-API-Version": __api_version__}
    if self.api_key:
        headers["X-API-Key"] = self.api_key

    # Add HMAC signature headers
    body = json.dumps({"events": batch})
    self._add_hmac_headers(headers, body)

    # Inject trace context for distributed tracing (W3C Trace Context)
    self._inject_trace_context(headers)

    # Use batch endpoint for efficiency - single request for all events.
    # We send ``content=body`` (the exact bytes that were HMAC-signed
    # above) rather than ``json=...`` — the latter re-serialises the
    # payload with httpx defaults (compact separators) and produces
    # a body that does not match the body the HMAC signature was
    # computed over. See plan B6.
    response = self._client.post(
        f"{self.api_url}/api/v1/track/batch",
        content=body,
        headers=headers,
    )

    # Parse the body exactly once. Anything that is not a JSON object
    # (empty body, HTML error page, bare list) is treated as "no
    # information" instead of crashing later on ``.get``.
    try:
        parsed = response.json()
    except Exception:  # noqa: S110 - body is optional, best effort
        parsed = None
    data: dict[str, Any] = parsed if isinstance(parsed, dict) else {}

    # Retry hints: header (seconds or HTTP-date) and/or body (ms).
    retry_after_seconds = self._extract_retry_after(response)
    retry_after_ms: float | None = None
    is_policy_limit = False
    rejected = data.get("rejected")
    if isinstance(rejected, dict):
        body_ms = rejected.get("retry_after_ms")
        # Only accept real numbers; a string or null here would break
        # the division below.
        if isinstance(body_ms, (int, float)) and not isinstance(body_ms, bool):
            retry_after_ms = body_ms
        is_policy_limit = rejected.get("reason") == "policy_limit"

    # Store for next retry calculation (prefer header seconds, fallback to body ms)
    if retry_after_seconds is not None:
        self._last_retry_after_seconds = retry_after_seconds
        retry_after_ms = retry_after_seconds * 1000
    elif retry_after_ms is not None:
        self._last_retry_after_seconds = retry_after_ms / 1000.0
    else:
        self._last_retry_after_seconds = 0.0
    self._last_failure_policy_limit = is_policy_limit

    # The retry state above is already recorded, so a 429 needs no
    # special casing: every 4xx/5xx raises here.
    response.raise_for_status()

    if not isinstance(parsed, dict):
        logger.warning("Failed to process actions_taken: response body is not a JSON object")

    # Process actions_taken from server response
    try:
        for action in data.get("actions_taken") or []:
            action_type = action.get("type", "")
            workflow_id = action.get("workflow_id", "unknown")
            reason = action.get("reason", "")
            if action_type:
                handle_action(action_type, workflow_id, reason)
    except Exception as e:
        logger.warning(f"Failed to process actions_taken: {e}")

    # Return accepted event_ids for retry dedup
    accepted_event_ids = data.get("accepted_event_ids") or []
    logger.debug(f"Batch track: sent {len(batch)} events")
    return self.SendResult(
        accepted_event_ids=accepted_event_ids,
        retry_after_ms=retry_after_ms,
        is_policy_limit=is_policy_limit,
    )
968
+
969
def flush_now(self) -> None:
    """Flush right away by delegating to ``self._do_flush()``."""
    self._do_flush()
972
+
973
+ # =============================================================================
974
+ # Execute (Strict Mode) - Phase 1
975
+ # =============================================================================
976
+
977
def execute(
    self,
    organization_id: str,
    execution_id: str,
    trace_id: str,
    tool: str,
    input_data: dict[str, Any],
    mode: str = "auto",
    fallback_mode: str = FallbackMode.PERMISSIVE,
    operation_id: str | None = None,
    on_transport_error: Callable[[Exception], dict[str, Any]] | str | None = None,
) -> dict[str, Any]:
    """
    Pre-execution policy evaluation via unified gate endpoint.

    This is the PRIMARY enforcement point - decision is made BEFORE execution.
    Uses /api/v1/gate endpoint for unified execute + check functionality.

    Args:
        organization_id: Organization identifier
        execution_id: Execution identifier
        trace_id: Distributed trace ID
        tool: Tool to execute
        input_data: Tool input
        mode: Execution mode ("auto", "inline", "strict")
        fallback_mode: What to do if Gateway unavailable
        operation_id: Optional idempotency key; a random UUID is
            generated when omitted
        on_transport_error: How to react to a ``BreakerTransportError``:
            - callable: called with the error, its return value is
              returned verbatim
            - "raise": raise ``NullRunTransportError`` (also applied to
              ``httpx.RequestError``)
            - "open": return a synthetic "allow"
            - "closed": return a synthetic "block"
            - None (default): fall through to ``fallback_mode``

    Returns:
        Dict with:
        - decision: "allow" | "block" | "flag" | "pause" | "require_approval"
        - decision_source: "gateway" | "cached" | "fallback"
        - explanation: Human-readable explanation
        - policy_version: Policy version used
        - decision_context: Context for replay (if available)

    Raises:
        NullRunTransportError: When ``on_transport_error == "raise"`` and
            the gateway is unreachable, or when one is raised below.
        httpx.RequestError: Network errors when "raise" was not requested.
        NullRunAuthenticationError: Never converted into a fallback.
    """
    gate_request = {
        "organization_id": organization_id,
        "execution_id": execution_id,
        "trace_id": trace_id,
        "tool": tool,
        "input": input_data,
        "mode": mode,
        "operation_id": operation_id or str(uuid.uuid4()),
    }

    # Same default headers as ``check`` and the batch sender, which both
    # advertise the API version.
    headers = {"Content-Type": "application/json", "X-API-Version": __api_version__}
    if self.api_key:
        headers["X-API-Key"] = self.api_key

    # HMAC fix: serialise via the canonical-bytes helper and send
    # via content=body so the wire bytes match the signed bytes.
    # See ``_signed_request_body`` for the rationale.
    body = _signed_request_body(gate_request)
    self._add_hmac_headers(headers, body.decode("utf-8"))

    # Inject trace context for distributed tracing (W3C Trace Context)
    self._inject_trace_context(headers)

    def do_gate_request() -> httpx.Response:
        return self._client.post(
            f"{self.api_url}/api/v1/gate",
            content=body,
            headers=headers,
            timeout=5.0,
        )

    # Try Gateway with retry backoff
    try:
        response = _retry_with_backoff(
            do_gate_request,
            max_retries=2,
            base_delay=0.5,
            on_transport_error=on_transport_error,
        )

        if response.status_code == 200:
            data = response.json()
            data["decision_source"] = DecisionSource.GATEWAY
            # Cache successful decision for CACHED mode.
            # NOTE(review): the key is built WITH policy_version here,
            # but the CACHED fallback below looks it up with
            # make_key(organization_id) only - confirm make_key maps
            # both calls to the same key.
            cache_key = self._policy_cache.make_key(
                organization_id,
                data.get("policy_version"),
            )
            self._policy_cache.set(
                cache_key,
                data.get("decision", "allow"),
                data.get("policy_id"),
                data.get("policy_version"),
            )
            return data  # type: ignore[no-any-return]
        elif response.status_code >= 400:
            # The gateway answered with an error status (4xx or 5xx):
            # fail closed regardless of fallback_mode.
            return {
                "decision": "block",
                "decision_source": DecisionSource.FALLBACK,
                "explanation": f"Gateway returned {response.status_code}",
                "policy_version": 0,
            }
        # Any other status (non-200 2xx, 3xx) falls through to the
        # fallback-mode handling below.

    except BreakerTransportError as exc:
        # Phase 5 #5.10: ADR-008 lets callers opt into a
        # classified-error handler. Round 3 (Phase 0.4.0):
        # on_transport_error accepts both callables AND strings:
        #   "raise"  -> raise NullRunTransportError (classified)
        #   "open"   -> return synthetic allow
        #   "closed" -> return synthetic block
        #   callable -> call with the breaker error, return the result
        #   None     -> fall through to the legacy fallback-mode default
        if on_transport_error == "raise":
            # Re-raise as a classified transport error.
            raise NullRunTransportError(
                f"Gateway unreachable on /execute: {exc}",
                source=TransportErrorSource.NETWORK_ERROR,
                endpoint="execute",
            ) from exc
        if callable(on_transport_error):
            return on_transport_error(exc)
        if on_transport_error == "open":
            return {
                "decision": "allow",
                "decision_source": TransportErrorSource.NETWORK_ERROR,
                "explanation": f"Gateway unreachable: {exc}",
                "policy_version": 0,
            }
        if on_transport_error == "closed":
            return {
                "decision": "block",
                "decision_source": TransportErrorSource.NETWORK_ERROR,
                "explanation": f"Gateway unreachable: {exc}",
                "policy_version": 0,
            }
        # Anything else: fall through to fallback mode.
    except NullRunTransportError:
        raise  # Already classified -- propagate as-is
    except httpx.RequestError as exc:
        # Round 3: classify httpx network errors at the call site.
        if on_transport_error == "raise":
            raise NullRunTransportError(
                f"Network error on /execute: {exc}",
                source=TransportErrorSource.NETWORK_ERROR,
                endpoint="execute",
            ) from exc
        raise
    except NullRunAuthenticationError:
        raise  # Don't fall back on auth errors

    # All attempts failed - apply fallback mode
    # Sprint 3 follow-up (B24): bump ``fallback_mode_activations``
    # every time we reach this branch (gateway unreachable).
    # The operator alerts on a spike here as a proxy for
    # backend unavailability.
    metrics.inc_transport("fallback_mode_activations")
    if fallback_mode == FallbackMode.STRICT:
        return {
            "decision": "block",
            "decision_source": DecisionSource.FALLBACK,
            "explanation": "Gateway unavailable, fallback=STRICT",
            "policy_version": 0,
        }
    elif fallback_mode == FallbackMode.CACHED:
        # Use cached decision if available
        cache_key = self._policy_cache.make_key(organization_id)
        cached = self._policy_cache.get(cache_key)
        if cached:
            logger.warning("Gateway unreachable, using cached decision for %s", tool)
            return {
                "decision": cached.decision,
                "decision_source": DecisionSource.CACHED,
                "explanation": "Gateway unavailable, using cached decision",
                "policy_version": cached.policy_version or 0,
            }
        else:
            logger.warning(
                "Gateway unreachable, no cache for %s, "
                "falling back to PERMISSIVE",
                tool,
            )
            return {
                "decision": "allow",
                "decision_source": DecisionSource.FALLBACK,
                "explanation": "Gateway unavailable, no cache available",
                "policy_version": 0,
            }
    else:  # PERMISSIVE (default)
        return {
            "decision": "allow",
            "decision_source": DecisionSource.FALLBACK,
            "explanation": "Gateway unavailable, fallback=PERMISSIVE",
            "policy_version": 0,
        }
1175
+
1176
def check(
    self,
    check_request: dict[str, Any],
    on_transport_error: Callable[[Exception], dict[str, Any]] | str | None = None,
) -> dict[str, Any]:
    """
    Pre-execution budget check against the /api/v1/gate endpoint.

    ``check_request`` is translated into a gate request (with
    ``mode="auto"``) and POSTed with a 5 second timeout. Missing
    ``trace_id`` / ``operation_id`` values are replaced by random UUIDs.

    Args:
        check_request: Dict with:
            - organization_id: Organization identifier
            - execution_id: Execution identifier
            - operation_id: Operation identifier (for idempotency)
            - check_type: "llm" or "tool"
            - model: Model name (for LLM checks)
            - tool_name: Tool name (for tool checks); ``tool`` is used
              when ``tool_name`` is absent or empty
            - estimated_tokens: Token count (for LLM checks)
            - input: Optional input data
            - trace_id: Optional distributed trace ID
        on_transport_error: Only the value "raise" is acted on here: it
            turns 5xx responses and network errors into
            ``NullRunTransportError``. Any other value keeps the legacy
            synthetic-block behaviour.

    Returns:
        The gateway's JSON body on HTTP 200, e.g.:
        - decision: "allow" | "block" | "throttle"
        - reservation_id: Optional reservation ID
        - remaining_budget_cents: Remaining budget
        - projected_cost_cents: Projected cost for this operation
        - explanations: List of explanation strings
        - suggestions: List of suggestion strings
        Otherwise a synthetic "block" dict with zeroed budget fields.

    Raises:
        NullRunTransportError: Only when ``on_transport_error == "raise"``
            and the gateway returned 5xx or the request failed.
    """

    def synthetic_block(explanation: str) -> dict[str, Any]:
        # Fail-closed answer shared by the HTTP-error and network-error paths.
        return {
            "decision": "block",
            "decision_source": DecisionSource.FALLBACK,
            "reservation_id": None,
            "remaining_budget_cents": 0,
            "projected_cost_cents": 0,
            "explanations": [explanation],
            "suggestions": ["Check API availability"],
        }

    field = check_request.get
    gate_request = {
        "organization_id": field("organization_id"),
        "execution_id": field("execution_id"),
        "trace_id": field("trace_id", str(uuid.uuid4())),
        "tool": field("tool_name") or field("tool"),
        "input": field("input"),
        "mode": "auto",
        "check_type": field("check_type"),
        "model": field("model"),
        "estimated_tokens": field("estimated_tokens"),
        "operation_id": field("operation_id") or str(uuid.uuid4()),
    }

    headers = {"Content-Type": "application/json"}
    if self.api_key:
        headers["X-API-Key"] = self.api_key
    headers["X-API-Version"] = __api_version__

    # Sign the canonical bytes and send those same bytes (content=...)
    # so the body on the wire matches what the HMAC covers.
    body = _signed_request_body(gate_request)
    self._add_hmac_headers(headers, body.decode("utf-8"))

    # W3C Trace Context for distributed tracing.
    self._inject_trace_context(headers)

    raise_on_failure = on_transport_error == "raise"
    try:
        response = self._client.post(
            f"{self.api_url}/api/v1/gate",
            content=body,
            headers=headers,
            timeout=5.0,
        )
        status = response.status_code
        if status == 200:
            return response.json()  # type: ignore[no-any-return]

        # 4xx always -> synthetic block. 5xx raises only when the
        # caller opted into the typed-error contract; otherwise it is
        # a synthetic block as well (legacy behaviour).
        if status >= 500 and raise_on_failure:
            raise NullRunTransportError(
                f"Gateway returned {status}",
                source=TransportErrorSource.GATEWAY_ERROR,
                endpoint="check",
                status_code=status,
            )
        return synthetic_block(f"Gate endpoint returned {status}")
    except httpx.RequestError as e:
        # Network failure: classified error on opt-in, else block.
        if raise_on_failure:
            raise NullRunTransportError(
                f"Network error on /check: {e}",
                source=TransportErrorSource.NETWORK_ERROR,
                endpoint="check",
            ) from e
        logger.warning(f"Gate request failed: {e}")
        return synthetic_block(f"Gate request failed: {e}")
1285
+
1286
+ # =============================================================================
1287
+ # WebSocket Connection (Task 6 - WebSocket Push)
1288
+ # =============================================================================
1289
+
1290
def clear_policy_cache(self) -> None:
    """Empty the local policy cache so the next gate/execute call
    cannot be answered from a stale cached decision.

    Does nothing when this instance has no ``_policy_cache`` attribute.
    """
    if not hasattr(self, "_policy_cache"):
        return
    self._policy_cache._cache.clear()
    logger.debug("Policy cache cleared")
1295
+
1296
async def connect_websocket(
    self,
    organization_id: str,
    on_state_change: Callable[[dict[str, Any]], None] | None = None,
    on_policy_invalidated: Callable[[str, str, int], None] | None = None,
    on_key_rotated: Callable[[str, str, int], None] | None = None,
) -> "WebSocketConnection":
    """
    Connect to WebSocket control plane for real-time workflow state updates.

    This replaces polling GET /status/{workflow_id} with WebSocket push.
    When the workflow state changes (KILL/PAUSE), the server pushes the update.

    The URL is derived from ``self.api_url``: http -> ws, https -> wss,
    with the path replaced by ``/ws/control/<organization_id>`` and any
    params, query and fragment dropped.

    Args:
        organization_id: Organization identifier. Percent-encoded before
            it is placed in the URL path, so characters such as ``/``,
            ``?`` or ``#`` cannot alter the path.
        on_state_change: Optional callback for state change notifications
        on_policy_invalidated: Optional callback for policy cache invalidation.
                               The local policy cache is cleared first,
                               then the callback is called.
                               Args: (organization_id, policy_id, new_version)
        on_key_rotated: Optional callback for HMAC key rotation. The
                        secret_key is re-fetched from /auth/verify first,
                        then the callback is called.
                        Args: (organization_id, key_id, new_version)

    Returns:
        WebSocketConnection instance, already connected

    Raises:
        ValueError: If ``self.api_url`` is neither http nor https
        ConnectionError: If WebSocket connection fails
    """
    # Phase 6 #6.6: build the WS URL via urllib.parse instead of
    # string replace. Reject unknown schemes with a clear error.
    from urllib.parse import quote, urlparse, urlunparse

    from nullrun.transport_websocket import WebSocketConnection

    parsed = urlparse(self.api_url)
    if parsed.scheme not in ("http", "https"):
        raise ValueError(
            f"Unsupported scheme for control plane: {parsed.scheme!r}"
        )
    ws_scheme = "wss" if parsed.scheme == "https" else "ws"
    # safe="" also escapes "/", keeping the id inside one path segment.
    org_segment = quote(str(organization_id), safe="")
    ws_url = urlunparse(
        parsed._replace(
            scheme=ws_scheme,
            path=f"/ws/control/{org_segment}",
            params="",
            query="",
            fragment="",
        )
    )

    headers = {"Content-Type": "application/json"}
    if self.api_key:
        headers["X-API-Key"] = self.api_key

    # Wrap the policy invalidated callback to clear local cache
    async def wrapped_policy_invalidated(ws_id: str, policy_id: str, new_version: int) -> None:
        logger.info(f"Policy {policy_id} invalidated (v{new_version}), clearing policy cache")
        self.clear_policy_cache()
        if on_policy_invalidated:
            on_policy_invalidated(ws_id, policy_id, new_version)

    # Wrap the key rotated callback to re-fetch credentials
    async def wrapped_key_rotated(ws_id: str, key_id: str, new_version: int) -> None:
        logger.info(f"Key {key_id} rotated (v{new_version}), re-fetching credentials")
        await self._refetch_credentials()
        if on_key_rotated:
            on_key_rotated(ws_id, key_id, new_version)

    conn = WebSocketConnection(
        url=ws_url,
        headers=headers,
        api_key=self.api_key,
        secret_key=self.secret_key,
        on_state_change=on_state_change,
        on_policy_invalidated=wrapped_policy_invalidated,
        on_key_rotated=wrapped_key_rotated,
    )
    await conn.connect()
    return conn
1376
+
1377
async def _refetch_credentials(self) -> None:
    """
    Re-fetch credentials from /auth/verify after key rotation.

    This is called when the server notifies us via WebSocket that
    our HMAC secret_key has been rotated. We need to get the new
    secret_key from the /auth/verify endpoint.

    Sprint 2.4 (B20): the previous implementation used
    ``import requests`` and bypassed every transport-layer
    invariant — the shared ``httpx.Client`` (mTLS, connection
    pool), the circuit breaker, the HMAC body signature, and
    the retry policy. It also pulled in ``requests`` as a new
    dependency that is not in ``pyproject.toml`` (a runtime
    ImportError waiting to happen on any environment where
    ``requests`` is not installed transitively).

    Post-fix: route through ``self._client`` so the same TLS
    configuration, connection pool, and HMAC signing path
    apply. Body is serialised via ``_signed_request_body`` so
    the wire bytes match the signed bytes.

    ``self._client.post`` is a blocking call (10 second timeout), so
    it runs in a worker thread; the event loop that delivers the
    WebSocket messages stays responsive while we wait.

    Never raises for ordinary failures: every ``Exception`` is logged
    and swallowed, leaving ``self.secret_key`` unchanged.
    """
    import asyncio

    try:
        payload = {"api_key": self.api_key}
        body = _signed_request_body(payload)
        headers: dict[str, str] = {
            "Content-Type": "application/json",
            "X-API-Key": self.api_key or "",
        }
        # Re-use the same HMAC headers as /gate and /track so
        # the server's auth-verify path is consistent.
        self._add_hmac_headers(headers, body.decode("utf-8"))

        response = await asyncio.to_thread(
            self._client.post,
            f"{self.api_url}/auth/verify",
            content=body,
            headers=headers,
            timeout=10.0,
        )
        if response.status_code != 200:
            logger.warning(f"Failed to refetch credentials: {response.status_code}")
            return

        data = response.json()
        # A body that is not a JSON object cannot carry a secret_key.
        new_secret = data.get("secret_key") if isinstance(data, dict) else None
        if new_secret:
            logger.info("Successfully fetched new secret_key from /auth/verify")
            self.secret_key = new_secret
        else:
            logger.warning("/auth/verify did not return secret_key in response")
    except Exception as e:
        logger.error(f"Error refetching credentials: {e}")
1428
+
1429
+
1430
def _retry_after_header_seconds(raw: str | None) -> float | None:
    """Convert a ``Retry-After`` header value into non-negative seconds.

    Accepts a number of seconds or an HTTP-date. Returns ``None`` for a
    missing, unparseable or non-finite value; a date in the past (or a
    negative number) becomes ``0.0``.
    """
    import math
    from datetime import datetime, timezone
    from email.utils import parsedate_to_datetime

    if not raw:
        return None
    try:
        seconds = float(raw)
    except ValueError:
        try:
            when = parsedate_to_datetime(raw)
            if when.tzinfo is None:
                # "-0000" parses to a naive datetime; HTTP dates are GMT.
                when = when.replace(tzinfo=timezone.utc)
            seconds = (when - datetime.now(timezone.utc)).total_seconds()
        except Exception:
            return None
    # float() accepts "nan"/"inf"; neither is a usable delay.
    if not math.isfinite(seconds):
        return None
    return max(0.0, seconds)


def _parse_error_envelope(
    response: httpx.Response,
    endpoint: str,
) -> Exception:
    """Translate a non-2xx ``httpx.Response`` into the right exception
    subclass per the canonical ``contracts/errors.ts`` envelope.

    Mapping:
    - 401 / 403 -> ``NullRunAuthenticationError``
    - 429       -> ``RateLimitError`` (with ``retry_after`` seconds from
                   the ``Retry-After`` header and ``upgrade_url`` from
                   the body, when present)
    - 5xx       -> ``NullRunTransportError`` ("Gateway error")
    - otherwise -> ``NullRunTransportError`` ("Client error")

    so callers branch on type instead of string-matching ``str(exc)``.

    Module-level helper (not a Transport method) so it can be called
    from background threads that do not carry a Transport instance.

    Args:
        response: The HTTP response to translate.
        endpoint: Endpoint label used in the message and on the exception.

    Returns:
        The exception instance; it is returned, not raised.
    """
    status = response.status_code
    try:
        body = response.json()
    except Exception:
        body = None
    # Anything that is not a JSON object carries no envelope fields.
    if not isinstance(body, dict):
        body = {}
    error_slug: str = body.get("error", "") or ""
    message: str = (
        body.get("message")
        or response.text
        or f"HTTP {status}"
    )

    if status in (401, 403):
        return NullRunAuthenticationError(
            f"Auth failed on {endpoint} (status {status}, "
            f"error={error_slug!r}): {message}"
        )

    if status == 429:
        return RateLimitError(
            f"Rate limited on {endpoint} (status 429, error={error_slug!r}): "
            f"{message}",
            source=TransportErrorSource.GATEWAY_ERROR,
            endpoint=endpoint,
            retry_after=_retry_after_header_seconds(
                response.headers.get("Retry-After")
            ),
            upgrade_url=body.get("upgrade_url"),
            body=body,
        )

    # 5xx and every remaining status differ only in the message prefix.
    label = "Gateway error" if 500 <= status < 600 else "Client error"
    return NullRunTransportError(
        f"{label} on {endpoint} (status {status}, "
        f"error={error_slug!r}): {message}",
        source=TransportErrorSource.GATEWAY_ERROR,
        endpoint=endpoint,
        status_code=status,
        error_slug=error_slug,
    )
1509
+