bylaw-python 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bylaw_python/client.py ADDED
@@ -0,0 +1,1595 @@
1
+ # Ledgix ALCV — Client
2
+ # Sync + async HTTP client for Vault communication and A-JWT verification
3
+
4
+ from __future__ import annotations
5
+
6
+ import base64
7
+ import hashlib
8
+ import json
9
+ import math
10
+ import random
11
+ import struct
12
+ import threading
13
+ import time
14
+ import uuid
15
+ from collections.abc import Awaitable, Callable
16
+ from typing import Any
17
+ from urllib.parse import urlencode
18
+
19
+ import httpx
20
+ import jwt
21
+
22
# HTTP statuses the retry helpers treat as transient. 429 is listed for
# completeness, but the retry loops intercept it first in a dedicated
# Retry-After-aware branch.
_RETRYABLE_STATUS_CODES: frozenset[int] = frozenset((429, 500, 502, 503, 504))

# Vault backpressure (Scale & Reliability §2.1): once its clearance queue is
# past the configured watermark, the Vault answers 429 with a Retry-After
# header. The SDK honours that delay, capped at MAX_RETRY_AFTER_SECONDS so a
# misbehaving server cannot pin a caller for minutes. These waves are NOT
# charged against max_retries, because they are cooperative backoff rather
# than transport failures. MAX_CONSECUTIVE_429 is a separate ceiling that
# stops an endless loop if the Vault is genuinely melting down.
MAX_RETRY_AFTER_SECONDS: float = 60.0
MAX_CONSECUTIVE_429: int = 10
33
+
34
+ from .config import VaultConfig
35
+ from .exceptions import (
36
+ ClearanceDeniedError,
37
+ ManualReviewTimeoutError,
38
+ PolicyRegistrationError,
39
+ QueueSaturatedError,
40
+ ReplayDetectedError,
41
+ ReviewPendingError,
42
+ TokenVerificationError,
43
+ VaultConnectionError,
44
+ )
45
+
46
+
47
+ def _parse_retry_after(value: str | None) -> float | None:
48
+ """Parse a Retry-After header value. Vault emits seconds; the HTTP spec
49
+ also allows HTTP-date but we don't need that today. Returns None on parse
50
+ failure so callers fall back to jittered backoff."""
51
+ if not value:
52
+ return None
53
+ try:
54
+ secs = float(value.strip())
55
+ except (TypeError, ValueError):
56
+ return None
57
+ if secs < 0:
58
+ return None
59
+ return min(secs, MAX_RETRY_AFTER_SECONDS)
60
+ from .pending import PendingApproval
61
+ from .models import (
62
+ ClearanceRequest,
63
+ ClearanceResponse,
64
+ ConsistencyProof,
65
+ InclusionProof,
66
+ LedgerEntry,
67
+ LedgerCheckpoint,
68
+ LedgerKeyVersion,
69
+ LedgerProofBundle,
70
+ LedgerManifest,
71
+ LedgerVerificationResult,
72
+ PolicyRegistration,
73
+ PolicyRegistrationResponse,
74
+ _MISSING,
75
+ )
76
+
77
+
78
+ class LedgixClient:
79
+ """Sync + async client for the ALCV Vault.
80
+
81
+ Usage (sync)::
82
+
83
+ client = LedgixClient()
84
+ resp = client.request_clearance(ClearanceRequest(tool_name="stripe_refund", tool_args={"amount": 45}))
85
+
86
+ Usage (async)::
87
+
88
+ client = LedgixClient()
89
+ resp = await client.arequest_clearance(ClearanceRequest(tool_name="stripe_refund", tool_args={"amount": 45}))
90
+ """
91
+
92
def __init__(self, config: VaultConfig | None = None, *, _parent_jti: str | None = None) -> None:
    """Create a client bound to *config*, or to a default ``VaultConfig()`` when omitted.

    Args:
        config: Vault connection and behaviour settings.
        _parent_jti: Private. When set, ``_enrich_request`` injects it as
            ``parent_jti`` into clearance requests that lack one (see
            ``create_delegated_client``).
    """
    cfg = config or VaultConfig()
    self.config = cfg
    self._parent_jti = _parent_jti

    # HTTP transports are created lazily by _get_sync_client / _get_async_client.
    self._sync_client: httpx.Client | None = None
    self._async_client: httpx.AsyncClient | None = None

    # JWKS state:
    # - _jwks_cache: raw JWKS response.
    # - _jwks_keys_by_kid: kid -> JWK index, rebuilt by _index_jwks_by_kid.
    # - _jwks_fetched_at: time.monotonic() of the last fetch.
    self._jwks_cache: dict[str, Any] | None = None
    self._jwks_keys_by_kid: dict[str, Any] = {}
    self._jwks_fetched_at: float = 0.0

    # Memo of approved decisions: a cachetools.TTLCache when enabled below, else None.
    self._decision_cache: Any = None
    self._decision_cache_lock = threading.Lock()

    from cachetools import TTLCache  # already a hard dep via decision_cache path

    # Bounded, expiring set of seen token ids. It is presumably used for
    # replay detection by verification code outside this view — confirm there.
    self._seen_jtis: TTLCache = TTLCache(
        maxsize=cfg.replay_cache_size,
        ttl=cfg.max_token_lifetime_seconds,
    )
    self._seen_jtis_lock = threading.Lock()

    # The asyncio.Lock is built on the first afetch_jwks() call rather than
    # here. The threading lock serialises sync JWKS fetches.
    self._jwks_async_lock: Any = None
    self._jwks_sync_lock = threading.Lock()

    if cfg.decision_cache_enabled:
        self._decision_cache = TTLCache(
            maxsize=cfg.decision_cache_max_entries,
            ttl=cfg.decision_cache_ttl_seconds,
        )
119
+
120
+ # ------------------------------------------------------------------
121
+ # Internal HTTP helpers
122
+ # ------------------------------------------------------------------
123
+
124
def _headers(self) -> dict[str, str]:
    """Build the default headers attached to every Vault request.

    A JSON ``Content-Type`` is always present. ``X-Vault-API-Key`` is added
    only when a truthy API key is configured.
    """
    api_key = self.config.vault_api_key
    auth = {"X-Vault-API-Key": api_key} if api_key else {}
    return {"Content-Type": "application/json", **auth}
129
+
130
def _get_sync_client(self) -> httpx.Client:
    """Return the shared ``httpx.Client``, rebuilding it if it is missing or closed.

    The client is configured with the Vault base URL, the default headers
    from ``_headers()`` and ``config.vault_timeout``.
    """
    current = self._sync_client
    if current is not None and not current.is_closed:
        return current
    cfg = self.config
    self._sync_client = httpx.Client(
        base_url=cfg.vault_url,
        headers=self._headers(),
        timeout=cfg.vault_timeout,
    )
    return self._sync_client
138
+
139
def _get_async_client(self) -> httpx.AsyncClient:
    """Return the shared ``httpx.AsyncClient``, rebuilding it if it is missing or closed.

    The client is configured with the Vault base URL, the default headers
    from ``_headers()`` and ``config.vault_timeout``.
    """
    current = self._async_client
    if current is not None and not current.is_closed:
        return current
    cfg = self.config
    self._async_client = httpx.AsyncClient(
        base_url=cfg.vault_url,
        headers=self._headers(),
        timeout=cfg.vault_timeout,
    )
    return self._async_client
147
+
148
+ # ------------------------------------------------------------------
149
+ # Retry helpers
150
+ # ------------------------------------------------------------------
151
+
152
def _backoff_delay(self, attempt: int) -> float:
    """Return a "full jitter" sleep for retry number *attempt* (0-based).

    The upper bound is ``config.retry_base_delay * 2**attempt``, capped at 30
    seconds. The result is drawn uniformly from ``[0, bound]``.
    """
    ceiling = self.config.retry_base_delay * (2 ** attempt)
    bound = min(30.0, ceiling)
    return random.uniform(0.0, bound)
156
+
157
def _sync_retry(self, fn: Callable[[], httpx.Response]) -> httpx.Response:
    """Execute an HTTP callable with retry and exponential backoff.

    * ``httpx.TransportError`` (network errors, timeouts) and statuses in
      ``_RETRYABLE_STATUS_CODES`` are retried with jittered exponential
      backoff, for up to ``config.max_retries`` extra attempts.
    * 429 responses are Vault backpressure (§2.1). The SDK sleeps for the
      ``Retry-After`` value, or, when that header is missing or unparseable,
      for a jittered delay whose window widens with each consecutive 429.
      These waves do NOT consume the ``max_retries`` budget. Any non-429
      response ends the streak.
    * Any other response is returned unchanged, including a retryable status
      once the budget is spent.

    Args:
        fn: Zero-argument callable that performs one HTTP request.

    Returns:
        The first response that is not retried.

    Raises:
        VaultConnectionError: transport errors persisted past ``max_retries``.
        QueueSaturatedError: more than ``MAX_CONSECUTIVE_429`` consecutive 429s.
    """
    attempt = 0
    consecutive_429 = 0
    last_retry_after: float | None = None
    while True:
        try:
            response = fn()
        except httpx.TransportError as exc:
            if attempt < self.config.max_retries:
                time.sleep(self._backoff_delay(attempt))
                attempt += 1
                continue
            raise VaultConnectionError(str(exc)) from exc

        if response.status_code == 429:
            consecutive_429 += 1
            if consecutive_429 > MAX_CONSECUTIVE_429:
                raise QueueSaturatedError(consecutive_429 - 1, last_retry_after)
            retry_after = _parse_retry_after(response.headers.get("Retry-After"))
            if retry_after is not None:
                last_retry_after = retry_after
                time.sleep(retry_after)
            else:
                # `attempt` does not advance on 429s. Keying the jitter on it
                # alone would keep every header-less wave in the same narrow
                # window, so widen the window with the streak length instead.
                time.sleep(self._backoff_delay(max(attempt, consecutive_429 - 1)))
            continue

        # Any non-429 response ends the backpressure streak. Forget its
        # Retry-After too, so a later QueueSaturatedError reports only the
        # streak that actually saturated.
        consecutive_429 = 0
        last_retry_after = None

        if response.status_code in _RETRYABLE_STATUS_CODES and attempt < self.config.max_retries:
            time.sleep(self._backoff_delay(attempt))
            attempt += 1
            continue
        return response
207
+
208
async def _async_retry(self, fn: Callable[[], Awaitable[httpx.Response]]) -> httpx.Response:
    """Async variant of ``_sync_retry`` with the same semantics.

    Transport errors and retryable statuses consume the ``max_retries``
    budget. 429 backpressure waves honour ``Retry-After`` (or a jittered
    delay that widens with the streak) and are bounded by
    ``MAX_CONSECUTIVE_429`` instead.

    Raises:
        VaultConnectionError: transport errors persisted past ``max_retries``.
        QueueSaturatedError: more than ``MAX_CONSECUTIVE_429`` consecutive 429s.
    """
    import asyncio

    attempt = 0
    consecutive_429 = 0
    last_retry_after: float | None = None
    while True:
        try:
            response = await fn()
        except httpx.TransportError as exc:
            if attempt < self.config.max_retries:
                await asyncio.sleep(self._backoff_delay(attempt))
                attempt += 1
                continue
            raise VaultConnectionError(str(exc)) from exc

        if response.status_code == 429:
            consecutive_429 += 1
            if consecutive_429 > MAX_CONSECUTIVE_429:
                raise QueueSaturatedError(consecutive_429 - 1, last_retry_after)
            retry_after = _parse_retry_after(response.headers.get("Retry-After"))
            if retry_after is not None:
                last_retry_after = retry_after
                await asyncio.sleep(retry_after)
            else:
                # `attempt` does not advance on 429s, so widen the jitter
                # window with the streak length instead.
                await asyncio.sleep(self._backoff_delay(max(attempt, consecutive_429 - 1)))
            continue

        # A non-429 response ends the streak and its Retry-After memory.
        consecutive_429 = 0
        last_retry_after = None

        if response.status_code in _RETRYABLE_STATUS_CODES and attempt < self.config.max_retries:
            await asyncio.sleep(self._backoff_delay(attempt))
            attempt += 1
            continue
        return response
244
+
245
+ # ------------------------------------------------------------------
246
+ # Decision cache helpers
247
+ # ------------------------------------------------------------------
248
+
249
def _enrich_request(self, request: ClearanceRequest) -> ClearanceRequest:
    """Fill request defaults from this client and from counterparty hints.

    * ``human_principal`` is set from ``config.principal_id`` when the request has none.
    * ``parent_jti`` is set from the delegated parent jti when the request has none.
    * ``destination_uri`` / ``destination_provider`` / ``destination_account_ref``
      are set from ``counterparty.extract(tool_name, tool_args)``, but only
      for fields the caller left as ``None`` (caller values always win).

    Returns the same object when nothing changes, otherwise a ``model_copy``
    with the updates applied.
    """
    updates: dict[str, Any] = {}
    if request.human_principal is None and self.config.principal_id:
        updates["human_principal"] = self.config.principal_id
    if request.parent_jti is None and self._parent_jti:
        updates["parent_jti"] = self._parent_jti

    destination_fields = ("destination_uri", "destination_provider", "destination_account_ref")
    missing = [name for name in destination_fields if getattr(request, name) is None]
    if missing:
        from .counterparty import extract as _extract_counterparty

        inferred = _extract_counterparty(request.tool_name, request.tool_args)
        updates.update({name: inferred[name] for name in missing if name in inferred})

    if not updates:
        return request
    return request.model_copy(update=updates)
273
+
274
def create_delegated_client(self, parent_jti: str) -> "LedgixClient":
    """Build a sibling client that stamps *parent_jti* onto its clearance requests.

    The child reuses this client's ``VaultConfig`` object. It gets its own HTTP
    connections, JWKS state, replay cache and decision cache, because all of
    those are created fresh in ``__init__``.
    """
    child = LedgixClient(config=self.config, _parent_jti=parent_jti)
    return child
281
+
282
def _build_cache_key(self, request: ClearanceRequest) -> str:
    """Return a stable hex cache key for *request*, or ``''`` if it is not cacheable.

    The key is a SHA-256 over every request input that can influence the
    Vault's decision:

    * agent, tool name, canonicalised tool arguments and policy id;
    * the principal and delegation parent (``human_principal``, ``parent_jti``);
    * counterparty (``destination_*``) and data-processing fields
      (``data_categories``, ``purpose``, ``processing_register_ref``,
      ``dataset_ref``).

    A cache hit skips the judge and only calls ``/mint-token``. Leaving any of
    these inputs out would let an approval granted for one principal,
    destination or purpose be reused for another.

    Requests whose arguments or scope cannot be JSON-serialised, or whose
    canonical arguments exceed 65,536 characters, are not cached.
    """
    try:
        canonical_args = json.dumps(
            request.tool_args or {},
            sort_keys=True,
            separators=(",", ":"),
            default=str,
        )
    except Exception:
        return ""  # best effort: unserialisable args simply bypass the cache
    if len(canonical_args) > 65_536:
        return ""

    decision_scope = {
        "human_principal": request.human_principal or "",
        "parent_jti": request.parent_jti or "",
        "destination_uri": request.destination_uri or "",
        "destination_provider": request.destination_provider or "",
        "destination_account_ref": request.destination_account_ref or "",
        "data_categories": request.data_categories or [],
        "purpose": request.purpose or "",
        "processing_register_ref": request.processing_register_ref or "",
        "dataset_ref": request.dataset_ref or "",
    }
    try:
        canonical_scope = json.dumps(
            decision_scope,
            sort_keys=True,
            separators=(",", ":"),
            default=str,
        )
    except Exception:
        return ""

    agent_id = request.agent_id or self.config.agent_id or ""
    policy_id = (request.context or {}).get("policy_id") or ""
    material = f"{agent_id}\x00{request.tool_name}\x00{canonical_args}\x00{policy_id}\x00{canonical_scope}"
    return hashlib.sha256(material.encode()).hexdigest()
299
+
300
def _cache_get(self, key: str) -> dict[str, Any] | None:
    """Look up a cached decision envelope.

    Returns ``None`` when caching is disabled, *key* is empty, or the key is
    not cached.
    """
    if self._decision_cache is not None and key:
        with self._decision_cache_lock:
            return self._decision_cache.get(key)
    return None
305
+
306
def _cache_put(self, key: str, envelope: dict[str, Any]) -> None:
    """Store *envelope* under *key*.

    Silently skipped when caching is disabled or *key* is empty (an
    uncacheable request).
    """
    if self._decision_cache is not None and key:
        with self._decision_cache_lock:
            self._decision_cache[key] = envelope
311
+
312
def clear_cache(self) -> None:
    """Drop every memoised decision envelope (a no-op when the decision cache is disabled)."""
    if self._decision_cache is not None:
        with self._decision_cache_lock:
            self._decision_cache.clear()
318
+
319
@staticmethod
def _is_cacheable(clearance: ClearanceResponse) -> bool:
    """Tell whether *clearance* may be memoised in the decision cache.

    Requires all of the following:
    * both ``decision_status`` and ``status`` are ``"approved"``;
    * a non-empty ``policy_version_id``;
    * a token.
    """
    if clearance.token is None or not clearance.policy_version_id:
        return False
    return clearance.decision_status == "approved" and clearance.status == "approved"
327
+
328
def _make_envelope(self, clearance: ClearanceResponse) -> dict[str, Any]:
    """Snapshot the decision fields of *clearance* for the decision cache.

    A missing policy version id or content hash is stored as ``""``. The
    clearance's ``request_id`` is kept as ``original_request_id`` so minted
    tokens can refer back to the judged request.
    """
    return dict(
        decision_status=clearance.decision_status,
        reason=clearance.reason,
        policy_version_id=clearance.policy_version_id or "",
        policy_content_hash=clearance.policy_content_hash or "",
        confidence_bucket=clearance.confidence_bucket,
        minimum_confidence_bucket=clearance.minimum_confidence_bucket,
        original_request_id=clearance.request_id,
    )
338
+
339
def _mint_token(self, request: ClearanceRequest, envelope: dict[str, Any]) -> ClearanceResponse:
    """Exchange a cached decision *envelope* for a fresh A-JWT via ``/mint-token`` (sync).

    The POST body combines the request's tool call, agent/session and
    compliance fields with the cached decision fields from *envelope*.
    Missing optional request fields are sent as ``""``, or ``[]`` for
    ``data_categories``.

    Returns:
        An approved ``ClearanceResponse`` carrying the minted token and the
        cached decision metadata.

    Raises:
        VaultConnectionError: ``/mint-token`` answered with an HTTP error
            status, or the Vault was unreachable after retries.
    """
    cfg = self.config
    body: dict[str, Any] = {
        "tool_name": request.tool_name,
        "tool_args": request.tool_args or {},
        "agent_id": request.agent_id or cfg.agent_id or "",
        "session_id": request.session_id or cfg.session_id or "",
        "policy_id": (request.context or {}).get("policy_id") or "",
    }
    # Decision fields come straight from the cached envelope.
    for field in ("policy_version_id", "policy_content_hash", "original_request_id", "confidence_bucket", "reason"):
        body[field] = envelope[field]
    body["human_principal"] = request.human_principal or cfg.principal_id
    for field in ("destination_uri", "destination_provider", "destination_account_ref"):
        body[field] = getattr(request, field) or ""
    body["data_categories"] = request.data_categories or []
    for field in ("purpose", "processing_register_ref", "dataset_ref"):
        body[field] = getattr(request, field) or ""

    # One Idempotency-Key for the whole retry sequence.
    headers = {"Idempotency-Key": str(uuid.uuid4())}

    def send():
        return self._get_sync_client().post("/mint-token", json=body, headers=headers)

    try:
        response = self._sync_retry(send)
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise VaultConnectionError(
            f"Vault /mint-token returned HTTP {exc.response.status_code}: {exc.response.text}"
        ) from exc

    minted = response.json()
    fields: dict[str, Any] = {
        "status": "approved",
        "decision_status": "approved",
        "requires_manual_review": False,
        "token": minted.get("token"),
        "reason": minted.get("reason", envelope["reason"]),
        "request_id": minted.get("request_id", ""),
        "confidence_bucket": envelope["confidence_bucket"],
        "minimum_confidence_bucket": envelope.get("minimum_confidence_bucket", "high"),
        "policy_version_id": envelope["policy_version_id"],
        "policy_content_hash": envelope["policy_content_hash"],
    }
    return ClearanceResponse(**fields)
384
+
385
async def _amint_token(self, request: ClearanceRequest, envelope: dict[str, Any]) -> ClearanceResponse:
    """Exchange a cached decision *envelope* for a fresh A-JWT via ``/mint-token`` (async).

    Builds the same body as ``_mint_token`` and sends it through
    ``_async_retry`` with one Idempotency-Key for the whole retry sequence.

    Raises:
        VaultConnectionError: ``/mint-token`` answered with an HTTP error
            status, or the Vault was unreachable after retries.
    """
    cfg = self.config
    body: dict[str, Any] = {
        "tool_name": request.tool_name,
        "tool_args": request.tool_args or {},
        "agent_id": request.agent_id or cfg.agent_id or "",
        "session_id": request.session_id or cfg.session_id or "",
        "policy_id": (request.context or {}).get("policy_id") or "",
    }
    for field in ("policy_version_id", "policy_content_hash", "original_request_id", "confidence_bucket", "reason"):
        body[field] = envelope[field]
    body["human_principal"] = request.human_principal or cfg.principal_id
    for field in ("destination_uri", "destination_provider", "destination_account_ref"):
        body[field] = getattr(request, field) or ""
    body["data_categories"] = request.data_categories or []
    for field in ("purpose", "processing_register_ref", "dataset_ref"):
        body[field] = getattr(request, field) or ""

    headers = {"Idempotency-Key": str(uuid.uuid4())}

    def send():
        return self._get_async_client().post("/mint-token", json=body, headers=headers)

    try:
        response = await self._async_retry(send)
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise VaultConnectionError(
            f"Vault /mint-token returned HTTP {exc.response.status_code}: {exc.response.text}"
        ) from exc

    minted = response.json()
    fields: dict[str, Any] = {
        "status": "approved",
        "decision_status": "approved",
        "requires_manual_review": False,
        "token": minted.get("token"),
        "reason": minted.get("reason", envelope["reason"]),
        "request_id": minted.get("request_id", ""),
        "confidence_bucket": envelope["confidence_bucket"],
        "minimum_confidence_bucket": envelope.get("minimum_confidence_bucket", "high"),
        "policy_version_id": envelope["policy_version_id"],
        "policy_content_hash": envelope["policy_content_hash"],
    }
    return ClearanceResponse(**fields)
430
+
431
+ # ------------------------------------------------------------------
432
+ # Clearance — sync
433
+ # ------------------------------------------------------------------
434
+
435
def request_clearance(self, request: ClearanceRequest) -> ClearanceResponse:
    """Obtain Vault clearance for a tool call (sync, blocking).

    Flow:
      1. Enrich *request* with client defaults (principal, parent jti,
         counterparty hints).
      2. On a decision-cache hit (only possible when
         ``decision_cache_enabled=True``), skip the judge and call
         ``/mint-token`` for a fresh A-JWT.
      3. Otherwise POST ``/request-clearance`` with an Idempotency-Key.
         Resolve any pending review according to ``config.review_mode``,
         require approval, and memoise cacheable approvals.

    In both paths a returned token is checked with ``verify_token`` when
    ``config.verify_jwt`` is set.

    Raises:
        ClearanceDeniedError: the Vault did not approve the request.
        ReviewPendingError: the decision is pending and ``review_mode`` is ``"detach"``.
        ManualReviewTimeoutError: a pending review did not finish within ``review_timeout``.
        VaultConnectionError: the Vault is unreachable or answered with an HTTP error.
    """
    request = self._enrich_request(request)
    cache_key = self._build_cache_key(request)

    cached = self._cache_get(cache_key)
    if cached is not None:
        minted = self._mint_token(request, cached)
        if self.config.verify_jwt and minted.token:
            self.verify_token(minted.token)
        return minted

    payload = request.model_dump_json()
    headers = {"Idempotency-Key": str(uuid.uuid4())}

    def send():
        return self._get_sync_client().post("/request-clearance", content=payload, headers=headers)

    try:
        response = self._sync_retry(send)
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise VaultConnectionError(
            f"Vault returned HTTP {exc.response.status_code}: {exc.response.text}"
        ) from exc

    resolved = self._resolve_pending_clearance(ClearanceResponse.model_validate(response.json()))
    if isinstance(resolved, PendingApproval):
        raise ReviewPendingError(resolved)

    if not resolved.is_approved:
        raise ClearanceDeniedError(reason=resolved.reason, request_id=resolved.request_id)

    if self.config.verify_jwt and resolved.token:
        self.verify_token(resolved.token)

    if self._is_cacheable(resolved):
        self._cache_put(cache_key, self._make_envelope(resolved))

    return resolved
489
+
490
+ # ------------------------------------------------------------------
491
+ # Clearance — async
492
+ # ------------------------------------------------------------------
493
+
494
async def arequest_clearance(self, request: ClearanceRequest) -> ClearanceResponse:
    """Obtain Vault clearance for a tool call (async).

    Mirrors ``request_clearance``:
    * a decision-cache hit mints a fresh A-JWT via ``/mint-token``;
    * otherwise ``/request-clearance`` runs, pending reviews are resolved per
      ``config.review_mode``, and cacheable approvals are memoised.

    Tokens are verified with ``averify_token`` when ``config.verify_jwt`` is set.

    Raises:
        ClearanceDeniedError: the Vault did not approve the request.
        ReviewPendingError: the decision is pending and ``review_mode`` is ``"detach"``.
        ManualReviewTimeoutError: a pending review did not finish within ``review_timeout``.
        VaultConnectionError: the Vault is unreachable or answered with an HTTP error.
    """
    request = self._enrich_request(request)
    cache_key = self._build_cache_key(request)

    cached = self._cache_get(cache_key)
    if cached is not None:
        minted = await self._amint_token(request, cached)
        if self.config.verify_jwt and minted.token:
            await self.averify_token(minted.token)
        return minted

    payload = request.model_dump_json()
    headers = {"Idempotency-Key": str(uuid.uuid4())}

    def send():
        return self._get_async_client().post("/request-clearance", content=payload, headers=headers)

    try:
        response = await self._async_retry(send)
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise VaultConnectionError(
            f"Vault returned HTTP {exc.response.status_code}: {exc.response.text}"
        ) from exc

    resolved = await self._aresolve_pending_clearance(ClearanceResponse.model_validate(response.json()))
    if isinstance(resolved, PendingApproval):
        raise ReviewPendingError(resolved)

    if not resolved.is_approved:
        raise ClearanceDeniedError(reason=resolved.reason, request_id=resolved.request_id)

    if self.config.verify_jwt and resolved.token:
        await self.averify_token(resolved.token)

    if self._is_cacheable(resolved):
        self._cache_put(cache_key, self._make_envelope(resolved))

    return resolved
548
+
549
+ # ------------------------------------------------------------------
550
+ # Policy registration
551
+ # ------------------------------------------------------------------
552
+
553
def register_policy(self, policy: PolicyRegistration) -> PolicyRegistrationResponse:
    """Register *policy* with the Vault and return its registration record (sync).

    The POST carries one ``Idempotency-Key`` that is reused across retries.

    Raises:
        PolicyRegistrationError: the Vault answered with an HTTP error status.
        VaultConnectionError: the Vault was unreachable after retries.
    """
    payload = policy.model_dump_json()
    headers = {"Idempotency-Key": str(uuid.uuid4())}

    def send():
        return self._get_sync_client().post("/register-policy", content=payload, headers=headers)

    try:
        response = self._sync_retry(send)
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise PolicyRegistrationError(
            f"Vault returned HTTP {exc.response.status_code}: {exc.response.text}"
        ) from exc
    return PolicyRegistrationResponse.model_validate(response.json())
571
+
572
async def aregister_policy(self, policy: PolicyRegistration) -> PolicyRegistrationResponse:
    """Register *policy* with the Vault and return its registration record (async).

    The POST carries one ``Idempotency-Key`` that is reused across retries.

    Raises:
        PolicyRegistrationError: the Vault answered with an HTTP error status.
        VaultConnectionError: the Vault was unreachable after retries.
    """
    payload = policy.model_dump_json()
    headers = {"Idempotency-Key": str(uuid.uuid4())}

    def send():
        return self._get_async_client().post("/register-policy", content=payload, headers=headers)

    try:
        response = await self._async_retry(send)
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise PolicyRegistrationError(
            f"Vault returned HTTP {exc.response.status_code}: {exc.response.text}"
        ) from exc
    return PolicyRegistrationResponse.model_validate(response.json())
590
+
591
+ # ------------------------------------------------------------------
592
+ # JWKS + A-JWT verification
593
+ # ------------------------------------------------------------------
594
+
595
def _resolve_pending_clearance(
    self, clearance: ClearanceResponse
) -> ClearanceResponse | PendingApproval:
    """Drive a not-yet-final clearance to a final state (sync).

    * Final statuses (anything other than ``"processing"`` or
      ``"pending_review"``) are returned unchanged.
    * With ``review_mode == "detach"``, a ``PendingApproval`` handle is
      returned instead of waiting.
    * Otherwise ``/clearance-status/<request_id>`` is polled every
      ``review_poll_interval`` seconds until a final status arrives or
      ``review_timeout`` elapses. The last sleep is clamped so the final
      poll happens at the deadline rather than after it.

    Each poll goes through ``_sync_retry``, so transient 5xx/transport
    failures and 429 backpressure are retried like every other Vault call.
    The path always uses the *original* request id, even if a status
    response omits or changes it.

    Raises:
        ManualReviewTimeoutError: still pending when ``review_timeout`` elapsed.
        VaultConnectionError: a poll returned an HTTP error status, or the
            Vault was unreachable after retries.
    """
    pending_states = {"processing", "pending_review"}
    if clearance.status not in pending_states:
        return clearance

    if self.config.review_mode == "detach":
        return PendingApproval(clearance.request_id, self, clearance)

    request_id = clearance.request_id
    path = f"/clearance-status/{request_id}"
    deadline = time.monotonic() + self.config.review_timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(self.config.review_poll_interval, remaining))
        try:
            response = self._sync_retry(lambda: self._get_sync_client().get(path))
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise VaultConnectionError(
                f"Vault /clearance-status returned HTTP {exc.response.status_code}: {exc.response.text}"
            ) from exc
        clearance = ClearanceResponse.model_validate(response.json())
        if clearance.status not in pending_states:
            return clearance
    raise ManualReviewTimeoutError(request_id)
613
+
614
async def _aresolve_pending_clearance(
    self, clearance: ClearanceResponse
) -> ClearanceResponse | PendingApproval:
    """Drive a not-yet-final clearance to a final state (async).

    Same contract as ``_resolve_pending_clearance``:
    * polls go through ``_async_retry``;
    * they always target the original request id;
    * the final sleep is clamped to the deadline.

    Raises:
        ManualReviewTimeoutError: still pending when ``review_timeout`` elapsed.
        VaultConnectionError: a poll returned an HTTP error status, or the
            Vault was unreachable after retries.
    """
    pending_states = {"processing", "pending_review"}
    if clearance.status not in pending_states:
        return clearance

    if self.config.review_mode == "detach":
        return PendingApproval(clearance.request_id, self, clearance)

    import asyncio

    request_id = clearance.request_id
    path = f"/clearance-status/{request_id}"
    deadline = time.monotonic() + self.config.review_timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        await asyncio.sleep(min(self.config.review_poll_interval, remaining))
        try:
            response = await self._async_retry(lambda: self._get_async_client().get(path))
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise VaultConnectionError(
                f"Vault /clearance-status returned HTTP {exc.response.status_code}: {exc.response.text}"
            ) from exc
        clearance = ClearanceResponse.model_validate(response.json())
        if clearance.status not in pending_states:
            return clearance
    raise ManualReviewTimeoutError(request_id)
634
+
635
def fetch_jwks(self) -> dict[str, Any]:
    """Download the Vault's JWKS and refresh the kid index (sync).

    Fetches are serialised on ``_jwks_sync_lock``. The fetch timestamp is
    sampled before waiting on the lock. If another thread completed a fetch
    while this one waited, that result is returned instead of issuing a
    second request.

    Raises:
        VaultConnectionError: the JWKS endpoint answered with an HTTP error,
            or the Vault was unreachable after retries.
    """
    observed = self._jwks_fetched_at
    with self._jwks_sync_lock:
        refreshed_meanwhile = self._jwks_fetched_at > observed
        if refreshed_meanwhile and self._jwks_cache is not None:
            return self._jwks_cache

        try:
            response = self._sync_retry(lambda: self._get_sync_client().get("/.well-known/jwks.json"))
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise VaultConnectionError(
                f"Failed to fetch JWKS: HTTP {exc.response.status_code}"
            ) from exc

        document = response.json()
        self._jwks_cache = document
        self._index_jwks_by_kid(document)
        return self._jwks_cache
658
+
659
async def afetch_jwks(self) -> dict[str, Any]:
    """Download the Vault's JWKS and refresh the kid index (async).

    Concurrent callers are serialised on an ``asyncio.Lock``, created on the
    first call. As in ``fetch_jwks``, a caller that waited while another
    coroutine refreshed the keys gets that fresh result instead of issuing a
    second request.

    Raises:
        VaultConnectionError: the JWKS endpoint answered with an HTTP error,
            or the Vault was unreachable after retries.
    """
    import asyncio

    lock = self._jwks_async_lock
    if lock is None:
        lock = self._jwks_async_lock = asyncio.Lock()
    observed = self._jwks_fetched_at
    async with lock:
        refreshed_meanwhile = self._jwks_fetched_at > observed
        if refreshed_meanwhile and self._jwks_cache is not None:
            return self._jwks_cache

        try:
            response = await self._async_retry(
                lambda: self._get_async_client().get("/.well-known/jwks.json")
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise VaultConnectionError(
                f"Failed to fetch JWKS: HTTP {exc.response.status_code}"
            ) from exc

        document = response.json()
        self._jwks_cache = document
        self._index_jwks_by_kid(document)
        return self._jwks_cache
688
+
689
def _index_jwks_by_kid(self, jwks: dict[str, Any]) -> None:
    """Rebuild ``_jwks_keys_by_kid`` from a raw JWKS document and stamp ``_jwks_fetched_at``.

    Keys without a truthy ``kid`` are skipped, and a later duplicate ``kid``
    overwrites an earlier one. Only when NO key carries a ``kid`` is the
    first key stored under ``"__default__"``, as a legacy fallback for
    kid-less tokens (the Vault always sets ``kid`` today).
    """
    entries = jwks.get("keys", [])
    index: dict[str, Any] = {entry["kid"]: entry for entry in entries if entry.get("kid")}
    if not index and entries:
        index["__default__"] = entries[0]
    self._jwks_keys_by_kid = index
    self._jwks_fetched_at = time.monotonic()
702
+
703
def fetch_ledger(self, limit: int = 100) -> list[LedgerEntry]:
    """Return up to *limit* recent ledger entries for the authenticated tenant (sync).

    *limit* is clamped to 1..500 before it is sent.

    Raises:
        VaultConnectionError: the Vault answered with an HTTP error status,
            or was unreachable after retries.
    """
    url = "/ledger?" + urlencode({"limit": max(1, min(limit, 500))})
    try:
        response = self._sync_retry(lambda: self._get_sync_client().get(url))
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise VaultConnectionError(
            f"Failed to fetch ledger: HTTP {exc.response.status_code}"
        ) from exc

    raw_entries = response.json().get("entries", [])
    return list(map(LedgerEntry.model_validate, raw_entries))
716
+
717
async def afetch_ledger(self, limit: int = 100) -> list[LedgerEntry]:
    """Return up to *limit* recent ledger entries for the authenticated tenant (async).

    *limit* is clamped to 1..500 before it is sent.

    Raises:
        VaultConnectionError: the Vault answered with an HTTP error status,
            or was unreachable after retries.
    """
    url = "/ledger?" + urlencode({"limit": max(1, min(limit, 500))})
    try:
        response = await self._async_retry(lambda: self._get_async_client().get(url))
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise VaultConnectionError(
            f"Failed to fetch ledger: HTTP {exc.response.status_code}"
        ) from exc

    raw_entries = response.json().get("entries", [])
    return list(map(LedgerEntry.model_validate, raw_entries))
730
+
731
def fetch_ledger_checkpoints(self, limit: int = 24) -> list[LedgerCheckpoint]:
    """Return up to *limit* recent signed ledger checkpoints for the authenticated tenant (sync).

    *limit* is clamped to 1..500 before it is sent.

    Raises:
        VaultConnectionError: the Vault answered with an HTTP error status,
            or was unreachable after retries.
    """
    url = "/ledger/checkpoints?" + urlencode({"limit": max(1, min(limit, 500))})
    try:
        response = self._sync_retry(lambda: self._get_sync_client().get(url))
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise VaultConnectionError(
            f"Failed to fetch ledger checkpoints: HTTP {exc.response.status_code}"
        ) from exc

    raw_checkpoints = response.json().get("checkpoints", [])
    return list(map(LedgerCheckpoint.model_validate, raw_checkpoints))
746
+
747
async def afetch_ledger_checkpoints(self, limit: int = 24) -> list[LedgerCheckpoint]:
    """Return up to *limit* recent signed ledger checkpoints for the authenticated tenant (async).

    *limit* is clamped to 1..500 before it is sent.

    Raises:
        VaultConnectionError: the Vault answered with an HTTP error status,
            or was unreachable after retries.
    """
    url = "/ledger/checkpoints?" + urlencode({"limit": max(1, min(limit, 500))})
    try:
        response = await self._async_retry(lambda: self._get_async_client().get(url))
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise VaultConnectionError(
            f"Failed to fetch ledger checkpoints: HTTP {exc.response.status_code}"
        ) from exc

    raw_checkpoints = response.json().get("checkpoints", [])
    return list(map(LedgerCheckpoint.model_validate, raw_checkpoints))
762
+
763
def fetch_ledger_manifests(self, limit: int = 24) -> list[LedgerManifest]:
    """Alias of ``fetch_ledger_checkpoints(limit)`` exposed under the manifest name (sync)."""
    checkpoints = self.fetch_ledger_checkpoints(limit)
    return checkpoints
765
+
766
async def afetch_ledger_manifests(self, limit: int = 24) -> list[LedgerManifest]:
    """Alias of ``afetch_ledger_checkpoints(limit)`` exposed under the manifest name (async)."""
    checkpoints = await self.afetch_ledger_checkpoints(limit)
    return checkpoints
768
+
769
def fetch_ledger_inclusion_proof(self, request_id: str) -> InclusionProof:
    """Fetch the ledger inclusion proof for *request_id* (sync).

    *request_id* is percent-encoded into the query string with ``urlencode``,
    as the other ledger calls do. Ids containing reserved characters (``&``,
    ``=``, ``#``, ``+``, spaces) therefore cannot inject extra query
    parameters or truncate the URL.

    Raises:
        httpx.HTTPStatusError: the Vault answered with an HTTP error status.
        VaultConnectionError: the Vault was unreachable after retries.
    """
    path = "/ledger/proof/inclusion?" + urlencode({"request_id": request_id})
    response = self._sync_retry(lambda: self._get_sync_client().get(path))
    response.raise_for_status()
    return InclusionProof.model_validate(response.json())
775
+
776
async def afetch_ledger_inclusion_proof(self, request_id: str) -> InclusionProof:
    """Fetch the Merkle inclusion proof for one ledger event (async).

    Args:
        request_id: Request whose ledger event should be proven.

    Returns:
        The parsed ``InclusionProof``.

    Raises:
        httpx.HTTPStatusError: If the Vault responds with an error status.
    """
    # URL-encode the identifier so reserved characters cannot break or
    # extend the query string.
    query = urlencode({"request_id": request_id})
    response = await self._async_retry(
        lambda: self._get_async_client().get(f"/ledger/proof/inclusion?{query}")
    )
    response.raise_for_status()
    return InclusionProof.model_validate(response.json())
+
783
def fetch_ledger_consistency_proof(self, from_checkpoint_id: int, to_checkpoint_id: int) -> ConsistencyProof:
    """Fetch a Merkle consistency proof between two checkpoints (sync).

    Raises:
        httpx.HTTPStatusError: If the Vault responds with an error status.
    """
    path = (
        "/ledger/proof/consistency"
        f"?from={from_checkpoint_id}&to={to_checkpoint_id}"
    )
    response = self._sync_retry(lambda: self._get_sync_client().get(path))
    response.raise_for_status()
    body = response.json()
    return ConsistencyProof.model_validate(body)
+
792
async def afetch_ledger_consistency_proof(self, from_checkpoint_id: int, to_checkpoint_id: int) -> ConsistencyProof:
    """Fetch a Merkle consistency proof between two checkpoints (async).

    Raises:
        httpx.HTTPStatusError: If the Vault responds with an error status.
    """
    path = (
        "/ledger/proof/consistency"
        f"?from={from_checkpoint_id}&to={to_checkpoint_id}"
    )
    response = await self._async_retry(lambda: self._get_async_client().get(path))
    response.raise_for_status()
    body = response.json()
    return ConsistencyProof.model_validate(body)
+
801
def fetch_ledger_proof_bundle(self, request_id: str) -> LedgerProofBundle:
    """Fetch the self-contained proof bundle for one ledger event (sync).

    Args:
        request_id: Request whose ledger event should be proven.

    Returns:
        The parsed ``LedgerProofBundle``.

    Raises:
        httpx.HTTPStatusError: If the Vault responds with an error status.
    """
    # URL-encode the identifier so reserved characters cannot break or
    # extend the query string.
    query = urlencode({"request_id": request_id})
    response = self._sync_retry(
        lambda: self._get_sync_client().get(f"/ledger/proof/bundle?{query}")
    )
    response.raise_for_status()
    return LedgerProofBundle.model_validate(response.json())
+
808
async def afetch_ledger_proof_bundle(self, request_id: str) -> LedgerProofBundle:
    """Fetch the self-contained proof bundle for one ledger event (async).

    Args:
        request_id: Request whose ledger event should be proven.

    Returns:
        The parsed ``LedgerProofBundle``.

    Raises:
        httpx.HTTPStatusError: If the Vault responds with an error status.
    """
    # URL-encode the identifier so reserved characters cannot break or
    # extend the query string.
    query = urlencode({"request_id": request_id})
    response = await self._async_retry(
        lambda: self._get_async_client().get(f"/ledger/proof/bundle?{query}")
    )
    response.raise_for_status()
    return LedgerProofBundle.model_validate(response.json())
+
815
def verify_ledger_proof(
    self,
    entries: list[LedgerEntry | dict[str, Any]] | None = None,
    manifests: list[LedgerManifest | dict[str, Any]] | None = None,
) -> LedgerVerificationResult:
    """Verify ledger event receipts and checkpoint signatures offline using the Vault JWKS.

    Args:
        entries: Ledger entries as models or raw dicts; fetched with
            ``fetch_ledger()`` when ``None``.
        manifests: Checkpoints as models or raw dicts; fetched with
            ``fetch_ledger_checkpoints()`` when ``None``.

    Returns:
        The result of ``_verify_ledger_proof``. The JWKS is fetched first when
        nothing is cached yet.
    """
    if entries is None:
        ledger_entries = self.fetch_ledger()
    else:
        ledger_entries = [
            raw if isinstance(raw, LedgerEntry) else LedgerEntry.model_validate(raw)
            for raw in entries
        ]

    if manifests is None:
        ledger_checkpoints = self.fetch_ledger_checkpoints()
    else:
        ledger_checkpoints = [
            raw if isinstance(raw, LedgerCheckpoint) else LedgerCheckpoint.model_validate(raw)
            for raw in manifests
        ]

    if self._jwks_cache is None:
        self.fetch_jwks()
    return self._verify_ledger_proof(ledger_entries, ledger_checkpoints)
+
835
async def averify_ledger_proof(
    self,
    entries: list[LedgerEntry | dict[str, Any]] | None = None,
    manifests: list[LedgerManifest | dict[str, Any]] | None = None,
) -> LedgerVerificationResult:
    """Async variant of ``verify_ledger_proof``.

    Missing entries/checkpoints are fetched with ``afetch_ledger()`` /
    ``afetch_ledger_checkpoints()``, and the JWKS with ``afetch_jwks()`` when
    nothing is cached yet.
    """
    if entries is None:
        ledger_entries = await self.afetch_ledger()
    else:
        ledger_entries = [
            raw if isinstance(raw, LedgerEntry) else LedgerEntry.model_validate(raw)
            for raw in entries
        ]

    if manifests is None:
        ledger_checkpoints = await self.afetch_ledger_checkpoints()
    else:
        ledger_checkpoints = [
            raw if isinstance(raw, LedgerCheckpoint) else LedgerCheckpoint.model_validate(raw)
            for raw in manifests
        ]

    if self._jwks_cache is None:
        await self.afetch_jwks()
    return self._verify_ledger_proof(ledger_entries, ledger_checkpoints)
+
855
def verify_ledger_proof_bundle(
    self,
    bundle: LedgerProofBundle | dict[str, Any],
) -> LedgerVerificationResult:
    """Verify a single-event ledger proof bundle offline (sync).

    Raw dicts are validated into ``LedgerProofBundle`` first. The Vault JWKS is
    fetched only when the bundle embeds no keys and no JWKS is cached.
    """
    if isinstance(bundle, LedgerProofBundle):
        proof_bundle = bundle
    else:
        proof_bundle = LedgerProofBundle.model_validate(bundle)

    needs_jwks = not proof_bundle.keys and self._jwks_cache is None
    if needs_jwks:
        self.fetch_jwks()
    return self._verify_ledger_proof_bundle(proof_bundle)
+
868
async def averify_ledger_proof_bundle(
    self,
    bundle: LedgerProofBundle | dict[str, Any],
) -> LedgerVerificationResult:
    """Verify a single-event ledger proof bundle offline (async).

    Raw dicts are validated into ``LedgerProofBundle`` first. The Vault JWKS is
    fetched only when the bundle embeds no keys and no JWKS is cached.
    """
    if isinstance(bundle, LedgerProofBundle):
        proof_bundle = bundle
    else:
        proof_bundle = LedgerProofBundle.model_validate(bundle)

    needs_jwks = not proof_bundle.keys and self._jwks_cache is None
    if needs_jwks:
        await self.afetch_jwks()
    return self._verify_ledger_proof_bundle(proof_bundle)
+
881
def verify_token(self, token: str) -> dict[str, Any]:
    """Verify an A-JWT using the Vault's public key (sync).

    The JWKS is (re)fetched first when nothing is cached yet or when the
    token's ``kid`` is not present in the cached key index.

    Returns:
        The decoded token payload on success.

    Raises:
        TokenVerificationError: If the token is invalid, expired, or
            the JWKS cannot be fetched.
        ReplayDetectedError: If this jti has already been consumed.
    """
    token_kid = self._peek_token_kid(token)
    cache_usable = self._jwks_cache is not None and self._has_key(token_kid)
    if not cache_usable:
        self.fetch_jwks()
    return self._decode_token(token)
+
896
async def averify_token(self, token: str) -> dict[str, Any]:
    """Verify an A-JWT using the Vault's public key (async).

    The JWKS is (re)fetched first when nothing is cached yet or when the
    token's ``kid`` is not present in the cached key index.

    Raises:
        TokenVerificationError: If the token is invalid, expired, or
            the JWKS cannot be fetched.
        ReplayDetectedError: If this jti has already been consumed.
    """
    token_kid = self._peek_token_kid(token)
    cache_usable = self._jwks_cache is not None and self._has_key(token_kid)
    if not cache_usable:
        await self.afetch_jwks()
    return self._decode_token(token)
+
909
def _peek_token_kid(self, token: str) -> str | None:
    """Return the ``kid`` header of a token without verifying its signature.

    Returns ``None`` when the header cannot be parsed or validated, or when it
    carries no string ``kid``. Signature and claim verification are left to
    the caller.
    """
    try:
        header = jwt.get_unverified_header(token)
    except jwt.InvalidTokenError:
        # DecodeError (a subclass) covers unparseable tokens, but PyJWT's header
        # validation raises plain InvalidTokenError, e.g. for a non-string kid.
        # Catching only DecodeError let that escape verify_token as a raw jwt
        # exception instead of the documented TokenVerificationError.
        return None
    kid = header.get("kid")
    return kid if isinstance(kid, str) else None
+
917
def _has_key(self, kid: str | None) -> bool:
    """Return True if the given kid is indexed in the current JWKS cache.

    A falsy ``kid`` (kid-less token) is looked up under ``"__default__"``.
    An empty or missing index always yields False.
    """
    index = self._jwks_keys_by_kid
    if not index:
        return False
    lookup = kid or "__default__"
    return lookup in index
+
925
def _decode_token(self, token: str) -> dict[str, Any]:
    """Verify an A-JWT against the cached JWKS and check jti replay.

    Security invariants enforced here:
    - Kid matching: a token that carries a ``kid`` header is verified only
      with the JWK indexed under that exact kid; unknown kids are rejected
      fail-closed. Only kid-less (legacy) tokens use the ``__default__`` key.
    - Algorithm pinned to EdDSA — RS256/HS256 confusion attacks are impossible.
    - Claims: ``exp``, ``iss``, ``aud`` and ``sub`` are required; ``aud`` and
      ``iss`` must match ``self.config``; ``sub`` must be ``"clearance"``.
    - jti replay: each accepted jti is recorded in ``self._seen_jtis`` under
      ``self._seen_jtis_lock``. A missing jti raises TokenVerificationError;
      a re-presented jti raises ReplayDetectedError.

    JWKS must already be populated before calling this method.

    Returns:
        The decoded token claims.

    Raises:
        TokenVerificationError: On missing JWKS/keys, unknown kid, unusable key
            material, signature or claim failures, or a missing jti.
        ReplayDetectedError: If the jti has already been consumed.
    """
    if not self._jwks_cache:
        raise TokenVerificationError("No JWKS available from Vault")
    if not self._jwks_keys_by_kid:
        raise TokenVerificationError("JWKS contains no keys")

    try:
        kid = self._peek_token_kid(token)
        # Exact kid match only; "__default__" is reserved for kid-less tokens.
        # Previously an unknown kid silently fell back to "__default__",
        # contradicting the fail-closed invariant documented above.
        key_data = self._jwks_keys_by_kid.get(kid if kid else "__default__")
        if key_data is None:
            raise TokenVerificationError(
                f"A-JWT kid={kid!r} not found in JWKS — key may have rotated; "
                "refetch JWKS or upgrade Vault"
            )

        public_key = jwt.algorithms.OKPAlgorithm.from_jwk(json.dumps(key_data))

        decoded = jwt.decode(
            token,
            public_key,
            algorithms=["EdDSA"],
            audience=self.config.jwt_audience,
            issuer=self.config.jwt_issuer,
            options={"verify_exp": True, "require": ["exp", "iss", "aud", "sub"]},
        )
        if decoded.get("sub") != "clearance":
            raise TokenVerificationError("Invalid A-JWT: unexpected subject")

    except TokenVerificationError:
        raise
    except jwt.ExpiredSignatureError as exc:
        raise TokenVerificationError("A-JWT has expired") from exc
    except jwt.InvalidTokenError as exc:
        raise TokenVerificationError(f"Invalid A-JWT: {exc}") from exc
    except jwt.exceptions.PyJWTError as exc:
        # e.g. InvalidKeyError from from_jwk on a malformed JWKS entry, which is
        # not an InvalidTokenError and would otherwise escape unwrapped.
        raise TokenVerificationError(f"Invalid A-JWT: {exc}") from exc

    # jti replay detection — fail-closed: a missing jti is rejected.
    jti = decoded.get("jti")
    if not jti:
        raise TokenVerificationError("A-JWT missing jti claim")
    with self._seen_jtis_lock:
        if jti in self._seen_jtis:
            raise ReplayDetectedError(jti)
        # Per the original design this is a TTL cache (cachetools.TTLCache);
        # inserting a dummy value records the jti for its TTL.
        self._seen_jtis[jti] = True

    return decoded
+
990
def _verify_ledger_proof(
    self,
    entries: list[LedgerEntry],
    checkpoints: list[LedgerCheckpoint],
    key_records: list[dict[str, Any]] | None = None,
) -> LedgerVerificationResult:
    """Verify ledger entries and checkpoints against Ed25519 verification keys.

    Checks, in order:
    1. Per entry (sorted by ``seq``): the event hash (skipped for redacted
       entries without ``intent_hash``), the leaf hash, and the Ed25519
       receipt signature over the rebuilt receipt payload.
    2. Per checkpoint (sorted by ``checkpoint_id``): the ``prev_checkpoint_hash``
       chain (starting from ``""``), the rebuilt payload, the checkpoint hash,
       and the signature.
    3. When the number of entries with a ``leaf_index`` equals the latest
       checkpoint's ``tree_size``, the Merkle root over those leaves.

    Args:
        entries: Ledger entries to verify.
        checkpoints: Signed checkpoints to verify.
        key_records: JWK dicts to verify with; when empty or ``None`` the keys
            come from ``_resolve_verification_keys()``.

    Returns:
        ``intact=True`` with counts, latest hashes and an optional coverage
        note on success. Verification failures are reported as
        ``intact=False`` with ``error`` set rather than raised.
    """

    def failed(error: str) -> LedgerVerificationResult:
        # Uniform "not intact" result; counts and hashes are zeroed on failure.
        return LedgerVerificationResult(
            intact=False,
            verified_entries=0,
            verified_checkpoints=0,
            verified_manifests=0,
            latest_leaf_hash=None,
            latest_checkpoint_hash=None,
            latest_manifest_hash=None,
            coverage_note=None,
            error=error,
        )

    verification_keys = key_records or self._resolve_verification_keys()
    if not verification_keys:
        return failed("No JWKS available from Vault")

    try:
        key_cache: dict[str, Any] = {}

        def key_for_kid(kid: str) -> Any:
            # Resolve (and memoize) the public key object for a signer kid.
            if kid in key_cache:
                return key_cache[kid]
            match = next(
                (
                    item
                    for item in verification_keys
                    if isinstance(item, dict) and item.get("kid") == kid
                ),
                None,
            )
            if match is None:
                raise TokenVerificationError(f"No public key found for kid {kid}")
            public_key = jwt.algorithms.OKPAlgorithm.from_jwk(json.dumps(match))
            key_cache[kid] = public_key
            return public_key

        sorted_entries = sorted(entries, key=lambda item: item.seq)
        # Only entries that already have a Merkle leaf position take part in
        # the root check below.
        sequenced_entries = sorted(
            (entry for entry in sorted_entries if entry.leaf_index is not None),
            key=lambda item: item.leaf_index or 0,
        )

        latest_leaf_hash: str | None = None
        coverage_notes: list[str] = []
        redacted_entry_count = 0
        for entry in sorted_entries:
            if self._has_protected_event_fields(entry):
                expected_event_hash = self._build_event_hash(entry)
                if expected_event_hash != entry.event_hash:
                    raise TokenVerificationError(f"Ledger event hash mismatch at seq {entry.seq}")
            else:
                # Redacted entry: the event body cannot be rehashed, but the
                # leaf hash and receipt signature below are still checked.
                redacted_entry_count += 1
            expected_leaf_hash = self._hash_leaf(entry.event_hash)
            if expected_leaf_hash != entry.leaf_hash:
                raise TokenVerificationError(f"Ledger leaf hash mismatch at seq {entry.seq}")
            if entry.receipt_algorithm != "Ed25519":
                raise TokenVerificationError(
                    f"Unsupported ledger receipt algorithm {entry.receipt_algorithm}"
                )
            if not entry.receipt_payload or not entry.receipt_signature or not entry.receipt_key_id:
                raise TokenVerificationError(f"Missing receipt proof data at seq {entry.seq}")
            payload_bytes = self._decode_base64url(entry.receipt_payload)
            rebuilt_payload = self._build_receipt_payload(entry)
            if payload_bytes != rebuilt_payload:
                raise TokenVerificationError(f"Ledger receipt payload mismatch at seq {entry.seq}")
            key_for_kid(entry.receipt_key_id).verify(
                self._decode_base64url(entry.receipt_signature),
                payload_bytes,
            )
            latest_leaf_hash = entry.leaf_hash
        if redacted_entry_count > 0:
            coverage_notes.append(
                "Event-body hash recomputation was skipped for "
                f"{redacted_entry_count} redacted public ledger entr"
                f"{'y' if redacted_entry_count == 1 else 'ies'}; receipt and checkpoint proofs still verified."
            )

        sorted_checkpoints = sorted(checkpoints, key=lambda item: item.checkpoint_id)
        # NOTE(review): the chain is anchored at an empty prev hash, so the
        # first supplied checkpoint must be the genesis checkpoint. A window of
        # recent checkpoints (e.g. fetch_ledger_checkpoints(limit=24)) whose
        # first element has a predecessor fails here — confirm this is intended.
        previous_checkpoint_hash = ""
        latest_checkpoint_hash: str | None = None
        for checkpoint in sorted_checkpoints:
            if checkpoint.prev_checkpoint_hash != previous_checkpoint_hash:
                raise TokenVerificationError(
                    f"Ledger checkpoint chain broken at checkpoint {checkpoint.checkpoint_id}"
                )
            if checkpoint.signature_algorithm != "Ed25519":
                raise TokenVerificationError(
                    f"Unsupported checkpoint signature algorithm {checkpoint.signature_algorithm}"
                )
            if (
                not checkpoint.checkpoint_payload
                or not checkpoint.checkpoint_signature
                or not checkpoint.signer_key_id
            ):
                raise TokenVerificationError(
                    f"Missing checkpoint proof data at checkpoint {checkpoint.checkpoint_id}"
                )
            payload_bytes = self._decode_base64url(checkpoint.checkpoint_payload)
            rebuilt_payload = self._build_checkpoint_payload(checkpoint)
            if payload_bytes != rebuilt_payload:
                raise TokenVerificationError(
                    f"Ledger checkpoint payload mismatch at checkpoint {checkpoint.checkpoint_id}"
                )
            if self._hash_checkpoint_payload(payload_bytes) != checkpoint.checkpoint_hash:
                raise TokenVerificationError(
                    f"Ledger checkpoint hash mismatch at checkpoint {checkpoint.checkpoint_id}"
                )
            key_for_kid(checkpoint.signer_key_id).verify(
                self._decode_base64url(checkpoint.checkpoint_signature),
                payload_bytes,
            )
            previous_checkpoint_hash = checkpoint.checkpoint_hash
            latest_checkpoint_hash = checkpoint.checkpoint_hash

        coverage_note: str | None = None
        if sorted_checkpoints:
            latest_checkpoint = sorted_checkpoints[-1]
            if len(sequenced_entries) == latest_checkpoint.tree_size:
                root_hash = self._merkle_root([entry.leaf_hash for entry in sequenced_entries])
                if root_hash != latest_checkpoint.root_hash:
                    raise TokenVerificationError(
                        "Latest checkpoint root does not match sequenced leaf hashes"
                    )
            else:
                coverage_notes.append(
                    f"Provided {len(sequenced_entries)} sequenced entries for tree size "
                    f"{latest_checkpoint.tree_size}; full root verification requires the complete covered set."
                )
        if coverage_notes:
            coverage_note = " ".join(coverage_notes)
        return LedgerVerificationResult(
            intact=True,
            verified_entries=len(sorted_entries),
            verified_checkpoints=len(sorted_checkpoints),
            verified_manifests=len(sorted_checkpoints),
            latest_leaf_hash=latest_leaf_hash,
            latest_checkpoint_hash=latest_checkpoint_hash,
            latest_manifest_hash=latest_checkpoint_hash,
            coverage_note=coverage_note,
        )
    except Exception as exc:
        # Some failures (e.g. a signature-verification error) carry an empty
        # message; never report a blank error for a failed verification.
        return failed(str(exc) or f"{type(exc).__name__} raised during ledger verification")
+
1147
def _verify_ledger_proof_bundle(self, bundle: LedgerProofBundle) -> LedgerVerificationResult:
    """Verify a single-event proof bundle offline.

    Checks the event's hashes and Ed25519 receipt; the payload, hash and
    signature of the inclusion checkpoint (and of the consistency target
    checkpoint when present); that the inclusion proof and the consistency
    source are bound to the signed inclusion checkpoint; the Merkle inclusion
    proof; and the optional consistency proof. Keys come from ``bundle.keys``
    when usable, otherwise from the cached JWKS.

    Returns:
        ``intact=True`` on success; otherwise ``intact=False`` with ``error``
        set. Verification failures are reported rather than raised.
    """

    def failed(error: str) -> LedgerVerificationResult:
        # Uniform "not intact" result; counts and hashes are zeroed on failure.
        return LedgerVerificationResult(
            intact=False,
            verified_entries=0,
            verified_checkpoints=0,
            verified_manifests=0,
            latest_leaf_hash=None,
            latest_checkpoint_hash=None,
            latest_manifest_hash=None,
            coverage_note=None,
            error=error,
        )

    try:
        verification_keys = self._resolve_verification_keys(bundle.keys)
    except Exception as exc:
        # Embedded keys may come from an externally supplied bundle; a malformed
        # JWK must produce a failed result instead of escaping the verifier.
        return failed(f"Invalid verification keys in proof bundle: {exc}")
    if not verification_keys:
        return failed("No JWKS available from Vault")

    try:
        key_cache: dict[str, Any] = {}

        def key_for_kid(kid: str) -> Any:
            # Resolve (and memoize) the public key object for a signer kid.
            if kid in key_cache:
                return key_cache[kid]
            match = next(
                (
                    item
                    for item in verification_keys
                    if isinstance(item, dict) and item.get("kid") == kid
                ),
                None,
            )
            if match is None:
                raise TokenVerificationError(f"No public key found for kid {kid}")
            public_key = jwt.algorithms.OKPAlgorithm.from_jwk(json.dumps(match))
            key_cache[kid] = public_key
            return public_key

        event = bundle.event
        if self._has_protected_event_fields(event):
            expected_event_hash = self._build_event_hash(event)
            if expected_event_hash != event.event_hash:
                raise TokenVerificationError("Ledger event hash mismatch in proof bundle")
        expected_leaf_hash = self._hash_leaf(event.event_hash)
        if expected_leaf_hash != event.leaf_hash:
            raise TokenVerificationError("Ledger leaf hash mismatch in proof bundle")
        if event.receipt_algorithm != "Ed25519":
            raise TokenVerificationError(
                f"Unsupported ledger receipt algorithm {event.receipt_algorithm}"
            )
        if not event.receipt_payload or not event.receipt_signature or not event.receipt_key_id:
            raise TokenVerificationError("Missing receipt proof data in proof bundle")
        payload_bytes = self._decode_base64url(event.receipt_payload)
        rebuilt_payload = self._build_receipt_payload(event)
        if payload_bytes != rebuilt_payload:
            raise TokenVerificationError("Ledger receipt payload mismatch in proof bundle")
        key_for_kid(event.receipt_key_id).verify(
            self._decode_base64url(event.receipt_signature),
            payload_bytes,
        )

        inclusion = bundle.inclusion
        consistency = bundle.consistency
        anchor = inclusion.checkpoint
        checkpoints = [anchor]
        if consistency:
            from_checkpoint = consistency.from_checkpoint
            # Matching the hash alone does not bind the tree_size/root_hash that
            # are fed into the consistency proof below; require those to equal
            # the (signature-verified) inclusion checkpoint as well.
            if (
                from_checkpoint.checkpoint_hash != anchor.checkpoint_hash
                or from_checkpoint.tree_size != anchor.tree_size
                or from_checkpoint.root_hash != anchor.root_hash
            ):
                raise TokenVerificationError(
                    "Ledger consistency proof does not match the inclusion checkpoint"
                )
            checkpoints.append(consistency.to_checkpoint)

        for checkpoint in checkpoints:
            if checkpoint.signature_algorithm != "Ed25519":
                raise TokenVerificationError(
                    f"Unsupported checkpoint signature algorithm {checkpoint.signature_algorithm}"
                )
            if (
                not checkpoint.checkpoint_payload
                or not checkpoint.checkpoint_signature
                or not checkpoint.signer_key_id
            ):
                raise TokenVerificationError(
                    f"Missing checkpoint proof data at checkpoint {checkpoint.checkpoint_id}"
                )
            checkpoint_payload = self._decode_base64url(checkpoint.checkpoint_payload)
            rebuilt_checkpoint_payload = self._build_checkpoint_payload(checkpoint)
            if checkpoint_payload != rebuilt_checkpoint_payload:
                raise TokenVerificationError("Ledger checkpoint payload mismatch in proof bundle")
            if self._hash_checkpoint_payload(checkpoint_payload) != checkpoint.checkpoint_hash:
                raise TokenVerificationError("Ledger checkpoint hash mismatch in proof bundle")
            key_for_kid(checkpoint.signer_key_id).verify(
                self._decode_base64url(checkpoint.checkpoint_signature),
                checkpoint_payload,
            )

        # Bind the audit path to the signed checkpoint and to the event itself:
        # an unsigned tree_size/leaf_index must not diverge from them.
        if inclusion.tree_size != anchor.tree_size:
            raise TokenVerificationError(
                "Ledger inclusion proof tree size does not match its checkpoint"
            )
        event_leaf_index = getattr(event, "leaf_index", None)
        if event_leaf_index is not None and event_leaf_index != inclusion.leaf_index:
            raise TokenVerificationError(
                "Ledger inclusion proof leaf index does not match the event"
            )

        if not self._verify_inclusion_proof(
            event.leaf_hash,
            inclusion.leaf_index,
            inclusion.tree_size,
            inclusion.path,
            anchor.root_hash,
        ):
            raise TokenVerificationError("Ledger inclusion proof is invalid")
        if consistency and not self._verify_consistency_proof(
            consistency.from_checkpoint.tree_size,
            consistency.to_checkpoint.tree_size,
            consistency.from_checkpoint.root_hash,
            consistency.to_checkpoint.root_hash,
            consistency.path,
        ):
            raise TokenVerificationError("Ledger consistency proof is invalid")

        latest_checkpoint_hash = (
            consistency.to_checkpoint.checkpoint_hash
            if consistency
            else anchor.checkpoint_hash
        )
        return LedgerVerificationResult(
            intact=True,
            verified_entries=1,
            verified_checkpoints=2 if consistency else 1,
            verified_manifests=2 if consistency else 1,
            latest_leaf_hash=event.leaf_hash,
            latest_checkpoint_hash=latest_checkpoint_hash,
            latest_manifest_hash=latest_checkpoint_hash,
            coverage_note=None,
        )
    except Exception as exc:
        # Some failures (e.g. a signature-verification error) carry an empty
        # message; never report a blank error for a failed verification.
        return failed(str(exc) or f"{type(exc).__name__} raised during ledger verification")
+
1288
def _resolve_verification_keys(
    self,
    embedded_keys: list[LedgerKeyVersion] | None = None,
) -> list[dict[str, Any]]:
    """Return the JWK dicts used to verify ledger signatures.

    Embedded key versions (e.g. from a proof bundle) take precedence when at
    least one yields a usable JWK. Each ``public_jwk`` is base64url-decoded
    UTF-8 JSON that must be an object, and a missing or empty ``kid`` is filled
    from ``key_id``. Otherwise the dict entries of the cached JWKS ``keys`` list
    are returned, or ``[]`` when there is none.

    Raises:
        ValueError: Decoding errors for a malformed embedded ``public_jwk``
            (base64, UTF-8 and JSON decode errors are ValueError subclasses)
            propagate to the caller.
    """
    if embedded_keys:
        resolved: list[dict[str, Any]] = []
        for key_version in embedded_keys:
            if not key_version.public_jwk:
                continue
            jwk = json.loads(self._decode_base64url(key_version.public_jwk).decode("utf-8"))
            # A JWK must be a JSON object. Anything else (list, string, number)
            # previously crashed on the kid assignment below with a TypeError.
            if not isinstance(jwk, dict):
                continue
            if not jwk.get("kid"):
                jwk["kid"] = key_version.key_id
            resolved.append(jwk)
        if resolved:
            return resolved

    jwks = self._jwks_cache
    keys = jwks.get("keys") if isinstance(jwks, dict) else None
    if isinstance(keys, list):
        return [item for item in keys if isinstance(item, dict)]
    return []
+
1310
def _build_event_hash(self, entry: LedgerEntry) -> str:
    """Recompute an entry's event hash from its protected fields.

    The current canonical field set is hashed first. If that hash differs from
    ``entry.event_hash``, the hash of the legacy field set (which lacks the
    action and policy-version fields) is returned instead, so the caller can
    compare a single value. ``raw_*`` attributes are preferred over their
    parsed counterparts unless they are ``_MISSING``. A missing action
    metadata is hashed as ``{}``.
    """

    def prefer_raw(raw_value: Any, parsed_value: Any) -> Any:
        return parsed_value if raw_value is _MISSING else raw_value

    normalize = self._normalize_json_numbers_for_cbor
    tool_args = prefer_raw(entry.raw_tool_args, entry.tool_args)
    action_metadata = prefer_raw(entry.raw_action_metadata, entry.action_metadata)
    if action_metadata is _MISSING:
        action_metadata = {}
    citations = prefer_raw(entry.raw_citations, entry.citations)
    evidence_chunks = prefer_raw(entry.raw_evidence_chunks, entry.evidence_chunks)

    # Field order is irrelevant: the deterministic CBOR encoder sorts map keys.
    legacy_fields = {
        "accepted_at": entry.accepted_at,
        "agent_id": entry.agent_id,
        "approved": entry.approved,
        "canonical_version": entry.canonical_version,
        "citations": normalize(citations),
        "confidence": entry.confidence,
        "event_uuid": entry.event_uuid,
        "evidence_chunks": normalize(evidence_chunks),
        "intent_hash": entry.intent_hash,
        "policy_id": entry.policy_id,
        "reason": entry.reason,
        "request_id": entry.request_id,
        "tool_args": normalize(tool_args),
        "tool_name": entry.tool_name,
    }
    current_fields = {
        **legacy_fields,
        "action_category": entry.action_category,
        "action_metadata": normalize(action_metadata),
        "policy_version_id": entry.policy_version_id,
        "policy_content_hash": entry.policy_content_hash,
    }

    current_hash = self._hash_event_payload(self._encode_deterministic_cbor(current_fields))
    if current_hash == entry.event_hash:
        return current_hash
    return self._hash_event_payload(self._encode_deterministic_cbor(legacy_fields))
+
1368
def _has_protected_event_fields(self, entry: LedgerEntry) -> bool:
    """Return True when ``entry`` carries a non-empty string ``intent_hash``.

    The ledger verifiers in this class treat entries failing this test as
    redacted and skip event-hash recomputation for them.
    """
    intent_hash = entry.intent_hash
    if not isinstance(intent_hash, str):
        return False
    return intent_hash != ""
+
1371
def _build_receipt_payload(self, entry: LedgerEntry) -> bytes:
    """Rebuild the deterministic-CBOR receipt payload for a ledger entry.

    The verifiers compare this against the decoded ``receipt_payload`` and
    check ``receipt_signature`` over it.
    """
    receipt = {
        "type": "event_receipt",
        "version": 1,
        "request_id": entry.request_id,
        "event_uuid": entry.event_uuid,
        "event_hash": entry.event_hash,
        "leaf_hash": entry.leaf_hash,
        "accepted_at": entry.accepted_at,
        "receipt_key_id": entry.receipt_key_id,
    }
    return self._encode_deterministic_cbor(receipt)
+
1385
def _build_checkpoint_payload(self, checkpoint: LedgerCheckpoint) -> bytes:
    """Rebuild the deterministic-CBOR payload covered by a checkpoint signature.

    A truthy ``export_target`` becomes a one-element ``export_targets`` list;
    otherwise the list is empty.
    """
    export_targets: list[Any] = []
    if checkpoint.export_target:
        export_targets.append(checkpoint.export_target)
    fields = {
        "type": "checkpoint",
        "version": 1,
        "tree_size": checkpoint.tree_size,
        "root_hash": checkpoint.root_hash,
        "prev_checkpoint_hash": checkpoint.prev_checkpoint_hash,
        "signed_at": checkpoint.signed_at,
        "mmd_seconds": checkpoint.mmd_seconds,
        "key_id": checkpoint.signer_key_id,
        "export_targets": export_targets,
    }
    return self._encode_deterministic_cbor(fields)
+
1401
def _hash_event_payload(self, payload: bytes) -> str:
    """Return the hex SHA-256 of ``payload`` under the audit-event domain tag."""
    digest = hashlib.sha256()
    digest.update(b"ledgix.audit.event.v1\x00")
    digest.update(payload)
    return digest.hexdigest()
+
1404
def _hash_checkpoint_payload(self, payload: bytes) -> str:
    """Return the hex SHA-256 of ``payload`` under the audit-checkpoint domain tag."""
    digest = hashlib.sha256()
    digest.update(b"ledgix.audit.checkpoint.v1\x00")
    digest.update(payload)
    return digest.hexdigest()
+
1407
def _hash_leaf(self, event_hash: str) -> str:
    """Return the Merkle leaf hash: SHA-256 of 0x00 followed by the event-hash bytes."""
    leaf_input = bytes([0]) + bytes.fromhex(event_hash)
    return hashlib.sha256(leaf_input).hexdigest()
+
1410
def _hash_node(self, left_hash: str, right_hash: str) -> str:
    """Return the Merkle interior-node hash: SHA-256 of 0x01 || left || right (hex in, hex out)."""
    node_input = b"".join((b"\x01", bytes.fromhex(left_hash), bytes.fromhex(right_hash)))
    return hashlib.sha256(node_input).hexdigest()
+
1415
def _merkle_root(self, leaf_hashes: list[str]) -> str:
    """Return the Merkle root over ``leaf_hashes``, or ``""`` for an empty list."""
    return self._merkle_range_hash(leaf_hashes, 0, len(leaf_hashes)) if leaf_hashes else ""
+
1420
def _merkle_range_hash(self, leaf_hashes: list[str], start: int, size: int) -> str:
    """Return the Merkle tree hash of ``leaf_hashes[start:start + size]``.

    The range is split at the largest power of two below ``size`` and the two
    halves are combined with ``_hash_node``. Expects ``size >= 1``.
    """
    if size == 1:
        return leaf_hashes[start]
    left_size = self._largest_power_of_two_less_than(size)
    return self._hash_node(
        self._merkle_range_hash(leaf_hashes, start, left_size),
        self._merkle_range_hash(leaf_hashes, start + left_size, size - left_size),
    )
+
1428
@staticmethod
def _largest_power_of_two_less_than(value: int) -> int:
    """Return the largest power of two strictly below ``value`` (1 when ``value <= 2``)."""
    if value <= 2:
        return 1
    return 1 << ((value - 1).bit_length() - 1)
+
1435
def _verify_inclusion_proof(
    self,
    leaf_hash: str,
    leaf_index: int,
    tree_size: int,
    path: list[str],
    root_hash: str,
) -> bool:
    """Verify a Merkle inclusion (audit) path, following RFC 9162 §2.1.3.2.

    Args:
        leaf_hash: Hex leaf hash of the entry being proven.
        leaf_index: Zero-based position of the leaf.
        tree_size: Number of leaves in the tree the proof targets.
        path: Sibling hashes (hex) from the leaf level upward.
        root_hash: Expected hex root hash of that tree.

    Returns:
        True iff the path recomputes ``root_hash`` and exactly reaches the root.
    """
    # RFC 9162 step 1: the index must lie inside the tree. Without this guard an
    # out-of-range index could verify, e.g. leaf_index=1, tree_size=1 and an
    # empty path whenever leaf_hash == root_hash.
    if leaf_index < 0 or leaf_index >= tree_size:
        return False
    fn = leaf_index
    sn = tree_size - 1
    current_hash = leaf_hash
    for sibling in path:
        if sn == 0:
            return False  # more path elements than tree levels
        if fn % 2 == 1 or fn == sn:
            # Right child (or rightmost node on its level): sibling goes left.
            current_hash = self._hash_node(sibling, current_hash)
            while fn > 0 and fn % 2 == 0:
                fn >>= 1
                sn >>= 1
        else:
            current_hash = self._hash_node(current_hash, sibling)
        fn >>= 1
        sn >>= 1
    return current_hash == root_hash and sn == 0
+
1460
def _verify_consistency_proof(
    self,
    first_size: int,
    second_size: int,
    first_hash: str,
    second_hash: str,
    path: list[str],
) -> bool:
    """Verify a Merkle consistency proof, following RFC 9162 §2.1.4.2.

    Args:
        first_size: Leaf count of the older tree.
        second_size: Leaf count of the newer tree.
        first_hash: Hex root hash of the older tree.
        second_hash: Hex root hash of the newer tree.
        path: Consistency proof node hashes (hex), in proof order.

    Returns:
        True iff the proof shows the older tree is a prefix of the newer one.
        Equal sizes are consistent iff the roots are equal (``path`` ignored).
    """
    if first_size == second_size:
        return first_hash == second_hash
    # A proof only exists for 0 < first_size < second_size. first_size == 0 made
    # fn == -1 below, and `while fn & 1 == 1: fn >>= 1` never terminates on -1.
    if first_size <= 0 or first_size > second_size:
        return False
    if not path:
        return False
    # When first_size is a power of two its root is itself a node on the path.
    working = [first_hash, *path] if self._is_power_of_two(first_size) else list(path)
    fn = first_size - 1
    sn = second_size - 1
    while fn & 1 == 1:
        fn >>= 1
        sn >>= 1
    first_root = working[0]
    second_root = working[0]
    for candidate in working[1:]:
        if sn == 0:
            return False  # more path elements than tree levels
        if fn & 1 == 1 or fn == sn:
            first_root = self._hash_node(candidate, first_root)
            second_root = self._hash_node(candidate, second_root)
            while fn > 0 and fn & 1 == 0:
                fn >>= 1
                sn >>= 1
        else:
            second_root = self._hash_node(second_root, candidate)
        fn >>= 1
        sn >>= 1
    return first_root == first_hash and second_root == second_hash and sn == 0
+
1495
@staticmethod
def _is_power_of_two(value: int) -> bool:
    """Return True when ``value`` is a positive integer power of two (1, 2, 4, ...)."""
    return value > 0 and bin(value).count("1") == 1
+
1499
def _normalize_json_numbers_for_cbor(self, value: Any) -> Any:
    """Coerce every int to float (JSON-style numbers) before CBOR encoding.

    Recurses into lists, tuples (returned as lists) and dict values; dict keys
    are left untouched. ``None``, strings, floats and bools pass through
    unchanged (bool is tested before int because it subclasses int). Any other
    object is returned as-is.
    """

    def walk(node: Any) -> Any:
        if node is None or isinstance(node, (str, bool, float)):
            return node
        if isinstance(node, int):
            return float(node)
        if isinstance(node, (list, tuple)):
            return [walk(child) for child in node]
        if isinstance(node, dict):
            return {key: walk(child) for key, child in node.items()}
        return node

    return walk(value)
+
1510
def _encode_deterministic_cbor(self, value: Any) -> bytes:
    """Encode ``value`` as deterministic CBOR bytes.

    Supported: ``None``, bool, str (UTF-8 text), bytes, int, finite float
    (always 8-byte float64), list/tuple (arrays) and dict (maps whose keys are
    emitted as text, sorted by ``(len(key), key)``). Any other value is first
    round-tripped through ``json.dumps``/``json.loads`` and then encoded.

    Raises:
        ValueError: For NaN or infinite floats.
    """
    if value is None:
        return b"\xf6"
    if isinstance(value, bool):
        if value:
            return b"\xf5"
        return b"\xf4"
    if isinstance(value, str):
        text_bytes = value.encode("utf-8")
        return self._cbor_header(3, len(text_bytes)) + text_bytes
    if isinstance(value, bytes):
        return self._cbor_header(2, len(value)) + value
    if isinstance(value, int):
        return self._cbor_int(value)
    if isinstance(value, float):
        if not math.isfinite(value):
            raise ValueError(f"Unsupported floating-point value {value}")
        return b"\xfb" + struct.pack(">d", value)
    if isinstance(value, (list, tuple)):
        encoded_items = [self._encode_deterministic_cbor(item) for item in value]
        return self._cbor_header(4, len(value)) + b"".join(encoded_items)
    if isinstance(value, dict):
        # NOTE(review): keys are ordered by *character* length. Canonical CBOR
        # (RFC 7049 §3.9 / RFC 8949 §4.2) orders by encoded *byte* length, which
        # differs for non-ASCII keys — confirm this matches the Vault encoder.
        ordered_keys = sorted(value, key=lambda key: (len(key), key))
        encoded_pairs = b"".join(
            self._encode_deterministic_cbor(str(key)) + self._encode_deterministic_cbor(value[key])
            for key in ordered_keys
        )
        return self._cbor_header(5, len(ordered_keys)) + encoded_pairs
    return self._encode_deterministic_cbor(json.loads(json.dumps(value)))
+
1539
def _cbor_int(self, value: int) -> bytes:
    """Encode an integer as CBOR major type 0 (unsigned) or 1 (negative, stored as -1 - value)."""
    if value < 0:
        return self._cbor_header(1, -1 - value)
    return self._cbor_header(0, value)
+
1544
def _cbor_header(self, major: int, value: int) -> bytes:
    """Encode a CBOR initial byte plus argument for ``major`` type and ``value``.

    Uses the shortest form: the value inline when <= 23, otherwise a 1-, 2-,
    4- or 8-byte big-endian argument (additional info 24, 25, 26 or 27).
    """
    initial = major << 5
    if value <= 23:
        return bytes([initial | value])
    for additional_info, width in ((24, 1), (25, 2), (26, 4)):
        if value < 1 << (8 * width):
            return bytes([initial | additional_info]) + value.to_bytes(width, "big")
    return bytes([initial | 27]) + value.to_bytes(8, "big")
+
1555
@staticmethod
def _decode_base64url(value: str) -> bytes:
    """Decode unpadded (or padded) base64url text into bytes."""
    missing_padding = -len(value) % 4
    return base64.urlsafe_b64decode((value + "=" * missing_padding).encode("ascii"))
+
1560
+ # ------------------------------------------------------------------
1561
+ # Lifecycle
1562
+ # ------------------------------------------------------------------
1563
+
1564
def close(self) -> None:
    """Close the underlying HTTP clients.

    The sync client is closed immediately. The async client cannot be awaited
    from sync code. When an event loop is running in this thread, its
    ``aclose()`` is scheduled there as a task; with no running loop it is left
    for garbage collection. From async code prefer ``await aclose()``.
    """
    if self._sync_client and not self._sync_client.is_closed:
        self._sync_client.close()
    if self._async_client and not self._async_client.is_closed:
        # Can't await in sync context; schedule close if an event loop exists.
        import asyncio

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return  # No running loop; client will be GC'd
        task = loop.create_task(self._async_client.aclose())
        # The event loop holds only weak references to tasks, so an
        # unreferenced task may be garbage-collected before aclose() finishes.
        # Keep a strong reference until the task completes.
        pending = getattr(self, "_pending_close_tasks", None)
        if pending is None:
            pending = set()
            self._pending_close_tasks = pending
        pending.add(task)
        task.add_done_callback(pending.discard)
+
1578
async def aclose(self) -> None:
    """Close the underlying HTTP clients (async), awaiting the async client's shutdown."""
    sync_client = self._sync_client
    if sync_client and not sync_client.is_closed:
        sync_client.close()
    async_client = self._async_client
    if async_client and not async_client.is_closed:
        await async_client.aclose()
+
1585
def __enter__(self) -> LedgixClient:
    """Enter a ``with`` block, yielding this client; ``__exit__`` closes it."""
    return self
+
1588
def __exit__(self, *args: Any) -> None:
    """Leave a ``with`` block by calling ``close()``; exceptions are not suppressed."""
    self.close()
+
1591
async def __aenter__(self) -> LedgixClient:
    """Enter an ``async with`` block, yielding this client; ``__aexit__`` closes it."""
    return self
+
1594
async def __aexit__(self, *args: Any) -> None:
    """Leave an ``async with`` block by awaiting ``aclose()``; exceptions are not suppressed."""
    await self.aclose()