nullrun 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nullrun/__init__.py +282 -0
- nullrun/__version__.py +4 -0
- nullrun/actions.py +455 -0
- nullrun/breaker/__init__.py +27 -0
- nullrun/breaker/circuit_breaker.py +402 -0
- nullrun/breaker/exceptions.py +319 -0
- nullrun/context.py +208 -0
- nullrun/decorators.py +649 -0
- nullrun/instrumentation/__init__.py +23 -0
- nullrun/instrumentation/_safe_patch.py +99 -0
- nullrun/instrumentation/auto.py +1095 -0
- nullrun/instrumentation/auto_requests.py +257 -0
- nullrun/instrumentation/autogen.py +163 -0
- nullrun/instrumentation/crewai.py +140 -0
- nullrun/instrumentation/langgraph.py +412 -0
- nullrun/instrumentation/llama_index.py +110 -0
- nullrun/observability.py +160 -0
- nullrun/py.typed +0 -0
- nullrun/runtime.py +1806 -0
- nullrun/toolbox/__init__.py +20 -0
- nullrun/toolbox/langgraph.py +94 -0
- nullrun/tracing.py +155 -0
- nullrun/transport.py +1509 -0
- nullrun/transport_websocket.py +627 -0
- nullrun-0.4.0.dist-info/METADATA +194 -0
- nullrun-0.4.0.dist-info/RECORD +28 -0
- nullrun-0.4.0.dist-info/WHEEL +4 -0
- nullrun-0.4.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,627 @@
|
|
|
1
|
+
"""
|
|
2
|
+
WebSocket transport for NullRun SDK.
|
|
3
|
+
|
|
4
|
+
Provides real-time workflow state updates via WebSocket connection.
|
|
5
|
+
Replaces polling pattern: SDK connects to WS, receives push updates
|
|
6
|
+
when workflow state changes (KILL/PAUSE).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import hashlib
|
|
11
|
+
import hmac
|
|
12
|
+
import json
|
|
13
|
+
import logging
|
|
14
|
+
import time
|
|
15
|
+
from collections.abc import Callable
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
import websockets
|
|
20
|
+
WEBSOCKETS_AVAILABLE = True
|
|
21
|
+
except ImportError:
|
|
22
|
+
WEBSOCKETS_AVAILABLE = False
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def compute_hmac_signature(api_key: str, secret_key: str, timestamp: int, payload: bytes) -> str:
    """
    Compute the HMAC-SHA256 signature that authenticates a WebSocket message.

    The signing input is the string ``"{timestamp}:{api_key}:{payload_hash}"``
    where ``payload_hash`` is the lowercase hex SHA-256 of ``payload``. That
    string is MAC'd with ``secret_key``; both key and input are UTF-8 encoded.

    Args:
        api_key: Client's API key (identifier, embedded in the signing input)
        secret_key: Client's secret key (the HMAC key)
        timestamp: Unix timestamp in seconds
        payload: Raw message payload bytes

    Returns:
        Hex-encoded HMAC-SHA256 signature
    """
    body_digest = hashlib.sha256(payload).hexdigest()
    signing_input = "{}:{}:{}".format(timestamp, api_key, body_digest)
    mac = hmac.new(
        secret_key.encode("utf-8"),
        signing_input.encode("utf-8"),
        hashlib.sha256,
    )
    return mac.hexdigest()
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def verify_hmac_signature(
    api_key: str,
    secret_key: str,
    timestamp: int,
    payload: bytes,
    signature: str,
    max_age_seconds: int = 300,
) -> bool:
    """
    Verify HMAC signature for a WebSocket message.

    Args:
        api_key: Client's API key (identifier)
        secret_key: Client's secret key (used for HMAC)
        timestamp: Unix timestamp from message (seconds)
        payload: Raw message payload bytes
        signature: HMAC signature to verify (hex-encoded). This value comes
            from untrusted wire JSON, so non-``str`` values are tolerated
            and simply fail verification.
        max_age_seconds: Maximum allowed age of message (default 5 min)

    Returns:
        True if signature is valid and timestamp is fresh, False otherwise.
        Never raises for malformed ``signature`` values.
    """
    # Check timestamp freshness. abs() also rejects timestamps too far in
    # the future, not only too far in the past.
    current_time = int(time.time())
    age = abs(current_time - timestamp)

    if age > max_age_seconds:
        logger.warning(f"WS signature timestamp expired: age={age}s, max={max_age_seconds}s")
        return False

    # hmac.compare_digest raises TypeError when given a non-str/bytes value
    # (or a str/bytes mix). A signature field that decoded from JSON as a
    # number, list or null would otherwise escape as an exception and tear
    # down the caller's receive loop instead of being rejected.
    if not isinstance(signature, str):
        return False

    # Compute expected signature
    expected = compute_hmac_signature(api_key, secret_key, timestamp, payload)

    # Constant-time comparison to prevent timing attacks. compare_digest
    # also raises TypeError for str arguments containing non-ASCII
    # characters; such a signature can never be a valid hex digest.
    try:
        return hmac.compare_digest(expected, signature)
    except TypeError:
        return False
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class WebSocketConnection:
|
|
97
|
+
"""
|
|
98
|
+
WebSocket connection for real-time control plane updates.
|
|
99
|
+
|
|
100
|
+
Usage:
|
|
101
|
+
conn = await transport.connect_websocket(
|
|
102
|
+
organization_id="org-123",
|
|
103
|
+
api_key="nr_live_xxx",
|
|
104
|
+
secret_key="secret_xxx",
|
|
105
|
+
on_state_change=lambda state: print(f"State changed: {state}")
|
|
106
|
+
)
|
|
107
|
+
# Connection stays open, receiving state updates
|
|
108
|
+
await conn.close()
|
|
109
|
+
"""
|
|
110
|
+
|
|
111
|
+
# States that require acknowledgment (KILL/PAUSE).
|
|
112
|
+
# The server's WsWorkflowState enum (NULLRUN/backend/src/proxy/http/
|
|
113
|
+
# ws_control.rs) emits PascalCase ("Killed", "Paused"); the SDK
|
|
114
|
+
# must compare against the same casing, otherwise the ACK
|
|
115
|
+
# path stays dead and the server's pending-ack queue grows
|
|
116
|
+
# without ever being drained.
|
|
117
|
+
ACKNOWLEDGED_STATES = {"Killed", "Paused"}
|
|
118
|
+
|
|
119
|
+
def __init__(
|
|
120
|
+
self,
|
|
121
|
+
url: str,
|
|
122
|
+
headers: dict[str, str] | None = None,
|
|
123
|
+
api_key: str | None = None,
|
|
124
|
+
secret_key: str | None = None,
|
|
125
|
+
on_state_change: Callable[[dict[str, Any]], None] | None = None,
|
|
126
|
+
on_policy_invalidated: Callable[[str, str, int], None] | None = None,
|
|
127
|
+
on_key_rotated: Callable[[str, str, int], None] | None = None,
|
|
128
|
+
):
|
|
129
|
+
"""
|
|
130
|
+
Initialize WebSocket connection.
|
|
131
|
+
|
|
132
|
+
Args:
|
|
133
|
+
url: WebSocket URL (e.g., "wss://api.nullrun.io/ws/control/org-123")
|
|
134
|
+
headers: HTTP headers for authentication
|
|
135
|
+
api_key: API key for HMAC verification (optional but recommended)
|
|
136
|
+
secret_key: Secret key for HMAC verification (optional but recommended)
|
|
137
|
+
on_state_change: Callback when workflow state changes
|
|
138
|
+
on_policy_invalidated: Callback when policy cache should be cleared
|
|
139
|
+
Args: (organization_id, policy_id, new_version)
|
|
140
|
+
on_key_rotated: Callback when secret key should be re-fetched
|
|
141
|
+
Args: (organization_id, key_id, new_version)
|
|
142
|
+
"""
|
|
143
|
+
self.url = url
|
|
144
|
+
self.headers = headers or {}
|
|
145
|
+
self.api_key = api_key
|
|
146
|
+
self.secret_key = secret_key
|
|
147
|
+
self.on_state_change = on_state_change
|
|
148
|
+
self.on_policy_invalidated = on_policy_invalidated
|
|
149
|
+
self.on_key_rotated = on_key_rotated
|
|
150
|
+
self._conn = None
|
|
151
|
+
self._running = False
|
|
152
|
+
self._receive_task: asyncio.Task | None = None
|
|
153
|
+
self._reconnect_task: asyncio.Task | None = None
|
|
154
|
+
self._closed = False
|
|
155
|
+
# Per-workflow monotonic version dedup (ADR-007).
|
|
156
|
+
# Drop incoming state changes with ``version <= last`` to
|
|
157
|
+
# survive the at-least-once delivery semantics of the WS
|
|
158
|
+
# channel.
|
|
159
|
+
#
|
|
160
|
+
# Sprint 1.4 (B2): the previous sentinel of 0 dropped incoming
|
|
161
|
+
# ``version == 0`` on first receive because ``0 <= 0`` is
|
|
162
|
+
# True. The server uses ``version: 0`` for the very first
|
|
163
|
+
# ``initial_state`` frame after a (re)connect, so the SDK was
|
|
164
|
+
# silently discarding the server's initial view — meaning a
|
|
165
|
+
# ``Killed``/``Paused`` state delivered in that first frame
|
|
166
|
+
# was lost. Sentinel is now -1 so any non-negative version
|
|
167
|
+
# passes the guard on the first message; subsequent stale
|
|
168
|
+
# ``version == 0`` re-deliveries are still dropped because
|
|
169
|
+
# ``last_seen`` will be ``>= 1`` for that workflow.
|
|
170
|
+
self._last_version: dict[str, int] = {}
|
|
171
|
+
|
|
172
|
+
async def _reconnect_loop(self) -> None:
|
|
173
|
+
"""
|
|
174
|
+
Background reconnect loop with exponential backoff.
|
|
175
|
+
|
|
176
|
+
The receive loop sets ``self._running = False`` in its
|
|
177
|
+
``finally`` block when the connection drops. This loop waits
|
|
178
|
+
while the receive loop is healthy and reconnects on demand.
|
|
179
|
+
|
|
180
|
+
Without the ``continue`` branch, the pre-fix code exited after
|
|
181
|
+
the very first successful ``_connect()`` because the
|
|
182
|
+
``if not self._running`` guard became False the moment
|
|
183
|
+
``_connect()`` set ``_running = True``. That broke the control
|
|
184
|
+
plane: after any network blip, kill/pause commands from the
|
|
185
|
+
dashboard would never reach the client until the process was
|
|
186
|
+
restarted. For a product whose core promise is a centralised
|
|
187
|
+
kill-switch, this was a safety gap — see plan item B1.
|
|
188
|
+
"""
|
|
189
|
+
delay = 1.0
|
|
190
|
+
max_delay = 60.0
|
|
191
|
+
|
|
192
|
+
while not self._closed:
|
|
193
|
+
if self._running:
|
|
194
|
+
# Receive loop is healthy. Sleep briefly and re-check;
|
|
195
|
+
# if the connection drops the receive loop's
|
|
196
|
+
# ``finally`` block will set ``_running = False`` and
|
|
197
|
+
# we will reconnect on the next iteration.
|
|
198
|
+
await asyncio.sleep(0.5)
|
|
199
|
+
continue
|
|
200
|
+
|
|
201
|
+
# Connection is down. Try to reconnect with backoff.
|
|
202
|
+
try:
|
|
203
|
+
await self._connect()
|
|
204
|
+
delay = 1.0 # reset on success
|
|
205
|
+
logger.info(f"WebSocket reconnected successfully: {self.url}")
|
|
206
|
+
# A fresh server connection may re-deliver events the
|
|
207
|
+
# client has already seen (or has never seen) — clear
|
|
208
|
+
# the version-dedup cache so the server's current view
|
|
209
|
+
# is accepted, not deduplicated against the
|
|
210
|
+
# pre-disconnect state. Same semantic as
|
|
211
|
+
# ``resync_required``.
|
|
212
|
+
self.clear_local_state()
|
|
213
|
+
except Exception as e:
|
|
214
|
+
logger.warning(f"WebSocket reconnect failed, retrying in {delay}s: {e}")
|
|
215
|
+
await asyncio.sleep(delay)
|
|
216
|
+
delay = min(delay * 2, max_delay)
|
|
217
|
+
|
|
218
|
+
async def _connect(self) -> None:
|
|
219
|
+
"""
|
|
220
|
+
Establish WebSocket connection.
|
|
221
|
+
|
|
222
|
+
Internal method used by connect() and reconnect loop.
|
|
223
|
+
"""
|
|
224
|
+
self._conn = await websockets.connect(
|
|
225
|
+
self.url, additional_headers=self.headers
|
|
226
|
+
)
|
|
227
|
+
self._running = True
|
|
228
|
+
self._receive_task = asyncio.create_task(self._receive_loop())
|
|
229
|
+
|
|
230
|
+
async def connect(self) -> None:
|
|
231
|
+
"""
|
|
232
|
+
Establish WebSocket connection with automatic reconnect.
|
|
233
|
+
|
|
234
|
+
Raises:
|
|
235
|
+
ConnectionError: If connection fails
|
|
236
|
+
ImportError: If websockets library not available
|
|
237
|
+
"""
|
|
238
|
+
if not WEBSOCKETS_AVAILABLE:
|
|
239
|
+
raise ImportError(
|
|
240
|
+
"websockets library not available. "
|
|
241
|
+
"Install with: pip install nullrun[websocket]"
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
self._closed = False
|
|
245
|
+
|
|
246
|
+
try:
|
|
247
|
+
await self._connect()
|
|
248
|
+
self._reconnect_task = asyncio.create_task(self._reconnect_loop())
|
|
249
|
+
logger.info(f"WebSocket connected: {self.url}")
|
|
250
|
+
except Exception as e:
|
|
251
|
+
logger.error(f"WebSocket connection failed: {e}")
|
|
252
|
+
raise ConnectionError(f"Failed to connect to {self.url}: {e}") from e
|
|
253
|
+
|
|
254
|
+
async def _receive_loop(self) -> None:
|
|
255
|
+
"""
|
|
256
|
+
Receive messages from WebSocket and dispatch to handler.
|
|
257
|
+
"""
|
|
258
|
+
try:
|
|
259
|
+
async for message in self._conn:
|
|
260
|
+
await self._handle_message(message)
|
|
261
|
+
except websockets.exceptions.ConnectionClosed:
|
|
262
|
+
logger.info("WebSocket connection closed")
|
|
263
|
+
except Exception as e:
|
|
264
|
+
logger.warning(f"WebSocket receive error: {e}")
|
|
265
|
+
finally:
|
|
266
|
+
self._running = False
|
|
267
|
+
|
|
268
|
+
async def _handle_message(self, message: str) -> None:
|
|
269
|
+
"""
|
|
270
|
+
Handle incoming WebSocket message.
|
|
271
|
+
|
|
272
|
+
Args:
|
|
273
|
+
message: Raw message string (JSON)
|
|
274
|
+
"""
|
|
275
|
+
try:
|
|
276
|
+
data = json.loads(message)
|
|
277
|
+
msg_type = data.get("type", "")
|
|
278
|
+
|
|
279
|
+
# Check for HMAC signature and verify if present
|
|
280
|
+
signature = data.get("signature")
|
|
281
|
+
timestamp = data.get("timestamp")
|
|
282
|
+
if signature and timestamp and self.api_key and self.secret_key:
|
|
283
|
+
# This is a signed message - verify the signature
|
|
284
|
+
msg_timestamp = int(timestamp) if isinstance(timestamp, (int, str)) else 0
|
|
285
|
+
|
|
286
|
+
# FIX-C (counterpart of backend fix(ws-control) in
|
|
287
|
+
# NULLRUN): the server embeds the exact bytes that were
|
|
288
|
+
# HMAC-signed in `signed_payload` (hex-encoded). The
|
|
289
|
+
# receiver MUST verify against those exact bytes —
|
|
290
|
+
# never against the full wire JSON (which includes
|
|
291
|
+
# signature/timestamp/api_key_id themselves and would
|
|
292
|
+
# never match). The pre-FIX-C server builds kept the
|
|
293
|
+
# signing scheme but did not publish the canonical
|
|
294
|
+
# payload, so we fall back to the legacy behaviour
|
|
295
|
+
# (verify against the full wire bytes) only when
|
|
296
|
+
# `signed_payload` is absent.
|
|
297
|
+
#
|
|
298
|
+
# See memory/ws-signed-message-byte-mismatch for the
|
|
299
|
+
# original failure this design rule encodes.
|
|
300
|
+
signed_payload_hex = data.get("signed_payload")
|
|
301
|
+
if isinstance(signed_payload_hex, str) and signed_payload_hex:
|
|
302
|
+
try:
|
|
303
|
+
verify_payload = bytes.fromhex(signed_payload_hex)
|
|
304
|
+
except ValueError:
|
|
305
|
+
# Malformed hex from a non-conforming server.
|
|
306
|
+
# Fall through to the legacy wire-bytes path
|
|
307
|
+
# so we still have a chance to accept it; the
|
|
308
|
+
# signature check will fail in either case
|
|
309
|
+
# and we'll reject with the standard error.
|
|
310
|
+
verify_payload = message.encode('utf-8')
|
|
311
|
+
else:
|
|
312
|
+
# Pre-FIX-C server: verify against full wire
|
|
313
|
+
# bytes. Will pass only on round-trip tests where
|
|
314
|
+
# the server happens to hash the same bytes we
|
|
315
|
+
# do; in real life this is the byte-mismatch path
|
|
316
|
+
# and the message should be rejected. Kept as
|
|
317
|
+
# best-effort backwards compatibility.
|
|
318
|
+
verify_payload = message.encode('utf-8')
|
|
319
|
+
|
|
320
|
+
if not verify_hmac_signature(
|
|
321
|
+
self.api_key,
|
|
322
|
+
self.secret_key,
|
|
323
|
+
msg_timestamp,
|
|
324
|
+
verify_payload,
|
|
325
|
+
signature,
|
|
326
|
+
max_age_seconds=300,
|
|
327
|
+
):
|
|
328
|
+
# Sprint 1.5 (B13): pre-fix this logged at
|
|
329
|
+
# WARNING and dropped the message silently. For a
|
|
330
|
+
# safety layer whose core contract is "the
|
|
331
|
+
# server can always KILL a workflow", a failed
|
|
332
|
+
# signature verification on a control plane
|
|
333
|
+
# message is a first-class incident — promote to
|
|
334
|
+
# ERROR and bump the counter so an SRE can
|
|
335
|
+
# alert on ``hmac_verify_failures_total > 0``.
|
|
336
|
+
# A signed-but-invalid message means either
|
|
337
|
+
# (a) the secret_key is out of sync (server
|
|
338
|
+
# rotated, client missed the rotation event), or
|
|
339
|
+
# (b) something is forging traffic. Both are
|
|
340
|
+
# actionable and the operator needs to know.
|
|
341
|
+
logger.error(
|
|
342
|
+
f"Invalid HMAC signature for {msg_type} message - "
|
|
343
|
+
"rejecting. This usually means the secret_key is out "
|
|
344
|
+
"of sync with the server (check for a key_rotated "
|
|
345
|
+
"event you may have missed) or the control plane is "
|
|
346
|
+
"being tampered with."
|
|
347
|
+
)
|
|
348
|
+
# Local import to avoid a module-level cycle:
|
|
349
|
+
# observability imports nothing from us, so this
|
|
350
|
+
# is safe and lazy.
|
|
351
|
+
from nullrun.observability import metrics
|
|
352
|
+
metrics.inc_transport("hmac_verify_failures_total")
|
|
353
|
+
return
|
|
354
|
+
|
|
355
|
+
# FIX-C (counterpart of backend fix(ws-control) in
|
|
356
|
+
# NULLRUN): when the message is signed and carries a
|
|
357
|
+
# `signed_payload` field, dispatching from the outer
|
|
358
|
+
# body fields would let an attacker splice forged values
|
|
359
|
+
# into the outer body while reusing a captured
|
|
360
|
+
# (signed_payload, signature) pair. The signature is
|
|
361
|
+
# computed over the bytes inside signed_payload, not the
|
|
362
|
+
# outer body, so the *only* trusted source is signed_payload
|
|
363
|
+
# itself. We parse it once and use the parsed dict for all
|
|
364
|
+
# state-dispatch decisions.
|
|
365
|
+
#
|
|
366
|
+
# For non-signed messages (legacy servers, or policy
|
|
367
|
+
# events that don't need per-payload signing) we fall back
|
|
368
|
+
# to the outer body — there is no signing, no attacker
|
|
369
|
+
# model.
|
|
370
|
+
trusted: dict[str, Any] | None = None
|
|
371
|
+
if signature and timestamp and self.api_key and self.secret_key:
|
|
372
|
+
if isinstance(signed_payload_hex, str) and signed_payload_hex:
|
|
373
|
+
try:
|
|
374
|
+
trusted = json.loads(
|
|
375
|
+
bytes.fromhex(signed_payload_hex).decode("utf-8")
|
|
376
|
+
)
|
|
377
|
+
except (ValueError, json.JSONDecodeError):
|
|
378
|
+
# Malformed signed_payload — the signature
|
|
379
|
+
# check above will already have rejected this
|
|
380
|
+
# message, so this branch should be unreachable
|
|
381
|
+
# in practice. We keep the fall-through to
|
|
382
|
+
# outer body to avoid a hard crash if the
|
|
383
|
+
# two checks ever drift.
|
|
384
|
+
trusted = None
|
|
385
|
+
|
|
386
|
+
if msg_type == "initial_state":
|
|
387
|
+
# Initial state with all workflow states
|
|
388
|
+
workflows = data.get("workflows", [])
|
|
389
|
+
logger.debug(f"Received initial state: {len(workflows)} workflows")
|
|
390
|
+
for wf in workflows:
|
|
391
|
+
# Trust the inner workflows[] entries the same
|
|
392
|
+
# way we trust state_change: when the parent
|
|
393
|
+
# envelope is signed, parse each entry from its
|
|
394
|
+
# embedded signed_payload if present, else fall
|
|
395
|
+
# back to the outer dict.
|
|
396
|
+
if isinstance(wf, dict) and wf.get("signed_payload") and self.api_key and self.secret_key:
|
|
397
|
+
try:
|
|
398
|
+
inner = json.loads(
|
|
399
|
+
bytes.fromhex(wf["signed_payload"]).decode("utf-8")
|
|
400
|
+
)
|
|
401
|
+
self._dispatch_state(inner)
|
|
402
|
+
continue
|
|
403
|
+
except (ValueError, json.JSONDecodeError, KeyError):
|
|
404
|
+
pass
|
|
405
|
+
self._dispatch_state(wf)
|
|
406
|
+
|
|
407
|
+
elif msg_type == "state_change":
|
|
408
|
+
# Workflow state change notification
|
|
409
|
+
# Check if this message requires acknowledgment
|
|
410
|
+
await self._handle_state_change_with_ack(data, trusted)
|
|
411
|
+
|
|
412
|
+
elif msg_type == "policy_invalidated":
|
|
413
|
+
# Policy was updated via dashboard - SDK should clear its cache
|
|
414
|
+
organization_id = data.get("organization_id", "")
|
|
415
|
+
policy_id = data.get("policy_id", "")
|
|
416
|
+
new_version = data.get("new_version", 0)
|
|
417
|
+
logger.info(f"Policy invalidated: {policy_id} v{new_version}, org: {organization_id}")
|
|
418
|
+
if self.on_policy_invalidated:
|
|
419
|
+
try:
|
|
420
|
+
self.on_policy_invalidated(organization_id, policy_id, new_version)
|
|
421
|
+
except Exception as e:
|
|
422
|
+
logger.warning(f"Policy invalidation callback error: {e}")
|
|
423
|
+
|
|
424
|
+
elif msg_type == "key_rotated":
|
|
425
|
+
# HMAC secret key was rotated - SDK should re-fetch from /auth/verify
|
|
426
|
+
organization_id = data.get("organization_id", "")
|
|
427
|
+
key_id = data.get("key_id", "")
|
|
428
|
+
new_version = data.get("new_version", 0)
|
|
429
|
+
logger.info(f"Key rotated: {key_id} v{new_version}, org: {organization_id}")
|
|
430
|
+
if self.on_key_rotated:
|
|
431
|
+
try:
|
|
432
|
+
self.on_key_rotated(organization_id, key_id, new_version)
|
|
433
|
+
except Exception as e:
|
|
434
|
+
logger.warning(f"Key rotation callback error: {e}")
|
|
435
|
+
|
|
436
|
+
elif msg_type == "resync_required":
|
|
437
|
+
# Server overflowed its broadcast channel. Per
|
|
438
|
+
# ADR-007 the SDK MUST close, reconnect, and
|
|
439
|
+
# replace its local state from the new
|
|
440
|
+
# ``initial_state`` — there is no "catch up"
|
|
441
|
+
# semantics. We clear the version-dedup cache and
|
|
442
|
+
# let ``_reconnect_loop`` reopen the connection.
|
|
443
|
+
reason = data.get("reason", "overflow")
|
|
444
|
+
logger.warning(
|
|
445
|
+
f"Server requested resync (reason={reason}); "
|
|
446
|
+
"clearing local state and reconnecting"
|
|
447
|
+
)
|
|
448
|
+
self.clear_local_state()
|
|
449
|
+
self._running = False
|
|
450
|
+
self._closed = True
|
|
451
|
+
if self._conn is not None:
|
|
452
|
+
try:
|
|
453
|
+
await self._conn.close()
|
|
454
|
+
except Exception: # noqa: BLE001
|
|
455
|
+
pass
|
|
456
|
+
self._conn = None
|
|
457
|
+
|
|
458
|
+
elif msg_type == "pong":
|
|
459
|
+
# Pong response to ping - connection is alive
|
|
460
|
+
pass
|
|
461
|
+
|
|
462
|
+
elif msg_type == "subscribed":
|
|
463
|
+
# Subscription confirmation
|
|
464
|
+
organization_id = data.get("organization_id")
|
|
465
|
+
logger.debug(f"Subscribed to organization: {organization_id}")
|
|
466
|
+
|
|
467
|
+
elif msg_type == "error":
|
|
468
|
+
# Error message from server
|
|
469
|
+
code = data.get("code", "unknown")
|
|
470
|
+
message = data.get("message", "Unknown error")
|
|
471
|
+
logger.warning(f"WebSocket error: {code} - {message}")
|
|
472
|
+
|
|
473
|
+
except json.JSONDecodeError:
|
|
474
|
+
logger.warning(f"Invalid JSON message: {message[:100]}")
|
|
475
|
+
|
|
476
|
+
async def _handle_state_change_with_ack(
|
|
477
|
+
self,
|
|
478
|
+
data: dict[str, Any],
|
|
479
|
+
trusted: dict[str, Any] | None = None,
|
|
480
|
+
) -> None:
|
|
481
|
+
"""
|
|
482
|
+
Handle state change message that may require acknowledgment.
|
|
483
|
+
|
|
484
|
+
For killed/paused states, sends ACK immediately before dispatching.
|
|
485
|
+
|
|
486
|
+
Args:
|
|
487
|
+
data: The outer (envelope) message data — used for
|
|
488
|
+
routing metadata only.
|
|
489
|
+
trusted: The parsed bytes of `signed_payload` (when the
|
|
490
|
+
message was signed). When present, dispatch reads
|
|
491
|
+
state / workflow_id / version / message_id from this
|
|
492
|
+
dict, NOT from `data`. The signature is computed over
|
|
493
|
+
the bytes inside signed_payload, so any divergence
|
|
494
|
+
between `data` and `trusted` is a forgery attempt and
|
|
495
|
+
must not be honoured.
|
|
496
|
+
"""
|
|
497
|
+
# FIX-C: when the message is signed, the signature covers the
|
|
498
|
+
# bytes inside `signed_payload`, not the outer body. We must
|
|
499
|
+
# use `trusted` (the parsed signed_payload) for any
|
|
500
|
+
# security-sensitive decision. The outer `data` is only used
|
|
501
|
+
# for routing.
|
|
502
|
+
source = trusted if trusted is not None else data
|
|
503
|
+
state = source.get("state", "")
|
|
504
|
+
workflow_id = source.get("workflow_id", "")
|
|
505
|
+
message_id = source.get("message_id")
|
|
506
|
+
|
|
507
|
+
# Check if this state requires acknowledgment
|
|
508
|
+
if state in self.ACKNOWLEDGED_STATES and message_id:
|
|
509
|
+
# Send ACK immediately
|
|
510
|
+
await self._send_ack(message_id)
|
|
511
|
+
logger.debug(f"Sent ACK for message {message_id} ({state} for workflow {workflow_id})")
|
|
512
|
+
|
|
513
|
+
# Dispatch state to callback. Use the trusted source so
|
|
514
|
+
# callbacks (and the per-workflow version dedup in
|
|
515
|
+
# _dispatch_state) see the same values that were ACK'd.
|
|
516
|
+
self._dispatch_state(source)
|
|
517
|
+
|
|
518
|
+
async def _send_ack(self, message_id: str) -> None:
|
|
519
|
+
"""
|
|
520
|
+
Send acknowledgment message to server.
|
|
521
|
+
|
|
522
|
+
Args:
|
|
523
|
+
message_id: The message ID to acknowledge
|
|
524
|
+
"""
|
|
525
|
+
if not self._conn or not self._running:
|
|
526
|
+
logger.warning("Cannot send ACK - WebSocket not connected")
|
|
527
|
+
return
|
|
528
|
+
|
|
529
|
+
try:
|
|
530
|
+
ack = {
|
|
531
|
+
"type": "ack",
|
|
532
|
+
"message_id": message_id,
|
|
533
|
+
"received_at": int(time.time() * 1000), # milliseconds
|
|
534
|
+
}
|
|
535
|
+
await self._conn.send(json.dumps(ack))
|
|
536
|
+
logger.debug(f"ACK sent for message {message_id}")
|
|
537
|
+
except Exception as e:
|
|
538
|
+
logger.warning(f"Failed to send ACK: {e}")
|
|
539
|
+
|
|
540
|
+
def _dispatch_state(self, state: dict[str, Any]) -> None:
|
|
541
|
+
"""
|
|
542
|
+
Dispatch state to callback after per-workflow version dedup
|
|
543
|
+
(ADR-007: at-least-once delivery, drop stale events).
|
|
544
|
+
|
|
545
|
+
Args:
|
|
546
|
+
state: State dict with workflow_id, state, version, etc.
|
|
547
|
+
"""
|
|
548
|
+
workflow_id = state.get("workflow_id", "")
|
|
549
|
+
incoming_version = state.get("version", 0)
|
|
550
|
+
if workflow_id:
|
|
551
|
+
# Sprint 1.4 (B2): default -1 (not 0) so version=0 is
|
|
552
|
+
# accepted on first receive. See __init__ for rationale.
|
|
553
|
+
last = self._last_version.get(workflow_id, -1)
|
|
554
|
+
if incoming_version <= last:
|
|
555
|
+
logger.debug(
|
|
556
|
+
f"Dropping stale state event for {workflow_id}: "
|
|
557
|
+
f"incoming version={incoming_version} <= last={last}"
|
|
558
|
+
)
|
|
559
|
+
return
|
|
560
|
+
self._last_version[workflow_id] = incoming_version
|
|
561
|
+
if self.on_state_change:
|
|
562
|
+
try:
|
|
563
|
+
self.on_state_change(state)
|
|
564
|
+
except Exception as e:
|
|
565
|
+
logger.warning(f"State change callback error: {e}")
|
|
566
|
+
|
|
567
|
+
def clear_local_state(self) -> None:
|
|
568
|
+
"""
|
|
569
|
+
Clear the in-memory per-workflow version cache.
|
|
570
|
+
|
|
571
|
+
Called after a ``ResyncRequired`` event so the next
|
|
572
|
+
``initial_state`` from the server is accepted (the dedup
|
|
573
|
+
cache may otherwise drop the server's freshest state if
|
|
574
|
+
the version is unchanged from the pre-overflow value).
|
|
575
|
+
Per ADR-007 there is no "merge" — local state is fully
|
|
576
|
+
replaced by the next ``initial_state``.
|
|
577
|
+
"""
|
|
578
|
+
self._last_version.clear()
|
|
579
|
+
|
|
580
|
+
async def send(self, message: dict[str, Any]) -> None:
|
|
581
|
+
"""
|
|
582
|
+
Send message to WebSocket server.
|
|
583
|
+
|
|
584
|
+
Args:
|
|
585
|
+
message: Message dict (will be JSON serialized)
|
|
586
|
+
"""
|
|
587
|
+
if not self._conn or not self._running:
|
|
588
|
+
raise ConnectionError("WebSocket not connected")
|
|
589
|
+
|
|
590
|
+
try:
|
|
591
|
+
await self._conn.send(json.dumps(message))
|
|
592
|
+
except Exception as e:
|
|
593
|
+
logger.warning(f"WebSocket send error: {e}")
|
|
594
|
+
raise
|
|
595
|
+
|
|
596
|
+
async def close(self) -> None:
|
|
597
|
+
"""
|
|
598
|
+
Close WebSocket connection.
|
|
599
|
+
"""
|
|
600
|
+
self._closed = True
|
|
601
|
+
self._running = False
|
|
602
|
+
|
|
603
|
+
if self._reconnect_task:
|
|
604
|
+
self._reconnect_task.cancel()
|
|
605
|
+
try:
|
|
606
|
+
await self._reconnect_task
|
|
607
|
+
except asyncio.CancelledError:
|
|
608
|
+
pass
|
|
609
|
+
|
|
610
|
+
if self._receive_task:
|
|
611
|
+
self._receive_task.cancel()
|
|
612
|
+
try:
|
|
613
|
+
await self._receive_task
|
|
614
|
+
except asyncio.CancelledError:
|
|
615
|
+
pass
|
|
616
|
+
|
|
617
|
+
if self._conn:
|
|
618
|
+
await self._conn.close()
|
|
619
|
+
self._conn = None
|
|
620
|
+
|
|
621
|
+
logger.info("WebSocket connection closed")
|
|
622
|
+
|
|
623
|
+
@property
|
|
624
|
+
def is_connected(self) -> bool:
|
|
625
|
+
"""Check if connection is active."""
|
|
626
|
+
return self._running and self._conn is not None and not self._closed
|
|
627
|
+
|