threadify-sdk 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,505 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ from collections.abc import Callable
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ if TYPE_CHECKING:
10
+ from threadify.data_retriever import ArchivedThread, DataRetriever
11
+ from threadify.notification import Notification
12
+ from threadify.thread import ThreadInstance
13
+
14
+ from threadify.models import (
15
+ ACTION_ACK_NOTIFICATION,
16
+ ACTION_CLOSE_CONNECTION,
17
+ ACTION_JOIN_THREAD,
18
+ ACTION_NOTIFICATION,
19
+ ACTION_NOTIFICATION_BATCH,
20
+ ACTION_START_THREAD,
21
+ ACTION_SUBSCRIBE,
22
+ ACTION_UNSUBSCRIBE,
23
+ FIELD_ACK_TOKEN,
24
+ FIELD_ACTION,
25
+ FIELD_CONTRACT_NAME,
26
+ FIELD_EVENT_TYPES,
27
+ FIELD_MESSAGE,
28
+ FIELD_NOTIFICATION,
29
+ FIELD_NOTIFICATION_ID,
30
+ FIELD_NOTIFICATION_ID_ACK,
31
+ FIELD_NOTIFICATIONS,
32
+ FIELD_PROCESSED,
33
+ FIELD_REFS,
34
+ FIELD_ROLE,
35
+ FIELD_SERVICE_NAME,
36
+ FIELD_STATUS,
37
+ FIELD_STEP_NAME,
38
+ FIELD_THREAD_ID,
39
+ FIELD_THREAD_ID_ACK,
40
+ FIELD_THREAD_TOKEN,
41
+ STATUS_SUCCESS,
42
+ DEFAULT_PROCESSED_MAX_SIZE,
43
+ RefQuery,
44
+ first_non_empty,
45
+ require_non_empty,
46
+ )
47
+
48
# Module-level logger registered under the "threadify" name; Connection falls
# back to a logger of the same name when the caller does not inject one.
logger = logging.getLogger("threadify")

# Signature of a subscription callback: it receives one Notification and its
# return value is unconstrained (Any).
NotificationHandler = Callable[["Notification"], Any]
51
+
52
+
53
class Connection:
    """One Threadify session running over an already-open WebSocket.

    A background task (``_read_loop``) consumes every frame from the socket.
    Notification frames are dispatched to subscribed handlers straight away;
    every other frame is placed on an internal queue where request/response
    helpers such as ``start``, ``join`` and ``close`` pick up their replies.
    """

    # Pause between polls while the only queued replies belong to other waiters.
    _REQUEUE_BACKOFF = 0.01

    def __init__(
        self,
        ws: Any,
        api_key: str,
        service_name: str,
        graphql_url: str,
        debug: bool = False,
        max_in_flight: int = 10,
        logger: logging.Logger | None = None,
    ):
        """Wrap *ws* and start the background reader.

        Must be called while an event loop is available because the reader is
        scheduled immediately.

        Args:
            ws: Connected WebSocket; it is iterated asynchronously and must
                provide ``send`` and ``close`` coroutines.
            api_key: Credential handed to the data retriever.
            service_name: Default service name used by ``start``.
            graphql_url: Endpoint for the archive queries (``get_thread`` etc.).
            debug: Stored on the instance; not read anywhere in this class.
            max_in_flight: Stored on the instance; not read anywhere in this class.
            logger: Logger to use; defaults to the "threadify" logger.
        """
        # Local import, matching the file's habit of importing inside methods.
        from collections import deque

        self._ws = ws
        self._api_key = api_key
        self._service_name = service_name
        self._graphql_url = graphql_url
        self._debug = debug
        self._max_in_flight = max_in_flight
        self._logger = logger or logging.getLogger("threadify")
        self._connected = True

        # Live ThreadInstance objects keyed by thread ID.
        self._threads: dict[str, Any] = {}

        # Handlers keyed by "<event>:<step name>".
        self._notification_handlers: dict[str, list[NotificationHandler]] = {}
        # Event types requested from the server, keyed by step name.
        self._active_subscriptions: dict[str, list[str]] = {}
        # IDs already delivered, plus their arrival order so the oldest can be evicted.
        self._processed_notifications: set[str] = set()
        self._processed_order: deque[str] = deque()
        self._processed_notifications_max_size: int = DEFAULT_PROCESSED_MAX_SIZE

        # Non-notification frames waiting for a _wait_response() caller.
        self._recv_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()

        self._data_retriever: DataRetriever | None = None

        # Strong references to fire-and-forget tasks (the event loop itself only
        # keeps weak references to tasks).
        self._background_tasks: set[asyncio.Future[Any]] = set()

        self._listener_task = asyncio.ensure_future(self._read_loop())

    @property
    def service_name(self) -> str:
        """Default service name given at construction."""
        return self._service_name

    @property
    def is_connected(self) -> bool:
        """False once the reader has stopped or ``close`` has run."""
        return self._connected

    def _spawn(self, coro: Any, what: str) -> None:
        """Run *coro* in the background, keeping a reference and logging failures."""
        task = asyncio.ensure_future(coro)
        self._background_tasks.add(task)
        task.add_done_callback(lambda done: self._on_background_done(done, what))

    def _on_background_done(self, task: asyncio.Future[Any], what: str) -> None:
        """Drop the reference to a finished background task and report its error."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self._logger.error(f"{what} failed: {exc}")

    async def _read_loop(self) -> None:
        """Consume frames from the socket until it closes or fails.

        Frames that are not JSON objects are skipped. Whatever the reason the
        loop ends, the connection is marked as disconnected.
        """
        try:
            async for raw in self._ws:
                try:
                    msg = json.loads(raw)
                except (ValueError, TypeError):
                    # ValueError covers JSONDecodeError and undecodable bytes.
                    continue

                if not isinstance(msg, dict):
                    # Valid JSON but not an object: nothing to route on.
                    continue

                action = msg.get(FIELD_ACTION, "")

                if action == ACTION_NOTIFICATION:
                    self._dispatch_notification(
                        msg.get(FIELD_NOTIFICATION, {}),
                        msg.get(FIELD_ACK_TOKEN, ""),
                    )
                elif action == ACTION_NOTIFICATION_BATCH:
                    batch = msg.get(FIELD_NOTIFICATIONS, [])
                    if isinstance(batch, list):
                        for item in batch:
                            if isinstance(item, dict):
                                self._dispatch_notification(item, "")
                else:
                    await self._recv_queue.put(msg)
        except Exception as exc:
            self._logger.error(f"readLoop error: {exc}")
        finally:
            self._connected = False

    def _dispatch_notification(self, data: Any, ack_token: str) -> None:
        """Process one notification without letting a failure stop the reader."""
        try:
            self._handle_notification(data, ack_token)
        except Exception as exc:
            self._logger.error(f"Notification processing error: {exc}")

    async def _wait_response(
        self, match: Callable[[dict], bool], timeout: float = 10.0
    ) -> dict[str, Any]:
        """Wait for a queued message accepted by *match*.

        Messages that do not match are put back on the queue so another waiter
        can claim them.

        Args:
            match: Predicate applied to each queued message.
            timeout: Overall budget in seconds.

        Returns:
            The first message for which *match* returns true.

        Raises:
            asyncio.TimeoutError: No matching message arrived within *timeout*.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                raise asyncio.TimeoutError("response timeout")

            msg = await asyncio.wait_for(self._recv_queue.get(), timeout=remaining)
            if match(msg):
                return msg
            await self._recv_queue.put(msg)
            # Yield before polling again; otherwise this coroutine can take the
            # same message back immediately and starve the reader and the
            # waiter the message is meant for.
            pause = min(self._REQUEUE_BACKOFF, max(deadline - loop.time(), 0.0))
            await asyncio.sleep(pause)

    async def _send(self, msg: dict[str, Any]) -> None:
        """Send a JSON message over the WebSocket.

        Raises:
            ConnectionError: The connection is marked as disconnected.
        """
        if not self._connected:
            raise ConnectionError("WebSocket is not connected")
        await self._ws.send(json.dumps(msg))

    async def start(
        self,
        label: str = "",
        contract_name: str = "",
        service_name: str = "",
        refs: dict[str, Any] | None = None,
    ) -> ThreadInstance:
        """Ask the server to start a thread and return a handle for it.

        Args:
            label: Optional label, sent inside the refs under the key "label".
            contract_name: Optional contract; when given, a role is sent too.
            service_name: Overrides the connection's default service name.
            refs: Extra references; the caller's dict is copied, not mutated.

        Returns:
            The new ThreadInstance, also registered on this connection.

        Raises:
            ConnectionError: The connection is not open.
            RuntimeError: The server answered with a non-success status.
            asyncio.TimeoutError: No reply arrived in time.
        """
        from threadify.thread import ThreadInstance

        if not self._connected:
            raise ConnectionError("Not connected. Call Threadify.connect() first.")

        effective_service = first_non_empty(service_name, self._service_name)

        message_refs = (refs or {}).copy()
        message_refs[FIELD_SERVICE_NAME] = effective_service
        if label:
            message_refs["label"] = label

        msg: dict[str, Any] = {
            FIELD_ACTION: ACTION_START_THREAD,
            FIELD_REFS: message_refs,
        }

        if contract_name:
            msg[FIELD_CONTRACT_NAME] = contract_name
            # The role is the service name without a trailing "-service";
            # with no service name at all it falls back to "participant".
            if effective_service:
                msg[FIELD_ROLE] = effective_service.removesuffix("-service")
            else:
                msg[FIELD_ROLE] = "participant"

        await self._send(msg)

        resp = await self._wait_response(lambda m: m.get(FIELD_ACTION) == ACTION_START_THREAD)

        if resp.get(FIELD_STATUS) != STATUS_SUCCESS:
            raise RuntimeError(resp.get(FIELD_MESSAGE, "failed to start thread"))

        thread_id = resp[FIELD_THREAD_ID]
        thread = ThreadInstance(self, thread_id, contract_name, "", None)
        self._threads[thread_id] = thread
        self._logger.debug(f"Thread started: {thread_id}")
        return thread

    async def join(
        self,
        token_or_thread_id: str | None = None,
        role: str = "",
        *,
        token: str | None = None,
        thread_id: str | None = None,
    ) -> ThreadInstance:
        """Join an existing thread by token or by thread ID plus role.

        Accepted call shapes:

        - ``join(token=...)`` joins with a thread token.
        - ``join(thread_id=..., role=...)`` joins by ID with an explicit role.
        - ``join(value, role)`` treats *value* as a thread ID.
        - ``join(value)`` treats *value* as a token.

        Returns:
            The joined ThreadInstance, also registered on this connection.

        Raises:
            ConnectionError: The connection is not open.
            ValueError: Conflicting or missing identifiers.
            RuntimeError: The server answered with a non-success status.
            asyncio.TimeoutError: No reply arrived in time.
        """
        from threadify.thread import ThreadInstance

        if not self._connected:
            raise ConnectionError("Not connected. Call Threadify.connect() first.")

        msg: dict[str, Any] = {FIELD_ACTION: ACTION_JOIN_THREAD}

        if token is not None:
            if token_or_thread_id is not None or thread_id is not None:
                raise ValueError("token cannot be combined with token_or_thread_id/thread_id")
            require_non_empty("token", token)
            msg[FIELD_THREAD_TOKEN] = token
        elif thread_id is not None:
            require_non_empty("thread_id", thread_id)
            require_non_empty("role", role)
            msg[FIELD_THREAD_ID] = thread_id
            msg[FIELD_ROLE] = role
        elif token_or_thread_id is not None:
            require_non_empty("token_or_thread_id", token_or_thread_id)
            if role:
                msg[FIELD_THREAD_ID] = token_or_thread_id
                msg[FIELD_ROLE] = role
            else:
                msg[FIELD_THREAD_TOKEN] = token_or_thread_id
        else:
            raise ValueError("provide token, thread_id+role, or token_or_thread_id")

        await self._send(msg)

        resp = await self._wait_response(lambda m: m.get(FIELD_ACTION) == ACTION_JOIN_THREAD)

        if resp.get(FIELD_STATUS) != STATUS_SUCCESS:
            raise RuntimeError(resp.get(FIELD_MESSAGE, "failed to join thread"))

        joined_id = resp[FIELD_THREAD_ID]
        thread_role = resp.get(FIELD_ROLE, "")
        thread = ThreadInstance(self, joined_id, resp.get("contractId", ""), thread_role, None)
        self._threads[joined_id] = thread
        self._logger.debug(f"Joined thread: {joined_id}, Role: {thread_role}")
        return thread

    async def close(self) -> None:
        """Close the session.

        When still connected, a close request is sent and its acknowledgement
        is awaited for up to five seconds. That handshake is best-effort: the
        reader is stopped and the socket closed whether or not it succeeds.
        """
        if not self._connected:
            await self._ws.close()
            return

        try:
            await self._send({FIELD_ACTION: ACTION_CLOSE_CONNECTION})
            await self._wait_response(
                lambda m: m.get(FIELD_ACTION) == ACTION_CLOSE_CONNECTION,
                timeout=5.0,
            )
        except Exception as exc:
            # Deliberately non-fatal: the socket is torn down below regardless.
            self._logger.debug(f"Close handshake not completed: {exc}")
        finally:
            self._connected = False
            self._listener_task.cancel()
            try:
                await self._listener_task
            except asyncio.CancelledError:
                pass
            await self._ws.close()

    def subscribe(
        self,
        event: str,
        step_name_or_handler: str | NotificationHandler | None = None,
        handler: NotificationHandler | None = None,
    ) -> "Connection":
        """Subscribe to notifications for a step or thread-level event.

        Supports two signatures:

        - ``subscribe(event, handler)`` — thread-level subscription.
        - ``subscribe(event, step_name, handler)`` — step-level subscription.

        The handler is registered locally at once; the subscribe request is
        sent to the server in the background.

        Returns:
            This connection, so calls can be chained.

        Raises:
            ValueError: No handler was supplied.
        """
        step_name: str
        actual_handler: NotificationHandler

        if callable(step_name_or_handler):
            # 2-param form: the second argument is the handler itself.
            step_name = "global"
            actual_handler = step_name_or_handler
        else:
            # 3-param form: an empty or missing step name means "global".
            step_name = step_name_or_handler or "global"
            actual_handler = handler  # type: ignore[assignment]

        if actual_handler is None:
            raise ValueError("handler cannot be None")

        source, event_type = _parse_event(event)
        event_types = _build_event_types(source, event_type)
        self._spawn(self._send_subscription(step_name, event_types), "Subscription request")

        key = f"{event}:{step_name}"
        self._notification_handlers.setdefault(key, []).append(actual_handler)
        return self

    def unsubscribe(self, event: str, step_name: str = "") -> "Connection":
        """Unsubscribe from notifications.

        Removes every handler registered for *event* on the step. The server is
        told to unsubscribe only when no handler for any event remains on it.

        Args:
            event: Event pattern to unsubscribe.
            step_name: Step name (default: global / thread-level).

        Returns:
            This connection, so calls can be chained.
        """
        target_step = step_name or "global"
        self._notification_handlers.pop(f"{event}:{target_step}", None)

        suffix = f":{target_step}"
        if not any(key.endswith(suffix) for key in self._notification_handlers):
            self._spawn(self._send_unsubscription(target_step), "Unsubscribe request")
        return self

    async def _send_subscription(
        self, step_name: str, event_types: list[str], *, force: bool = False
    ) -> None:
        """Request *event_types* for *step_name*, merged with those already requested.

        Args:
            step_name: Step the subscription applies to.
            event_types: Event types to add.
            force: Send even when nothing new would be requested; needed when
                re-establishing subscriptions after a reconnection.
        """
        if not self._connected:
            return

        existing = self._active_subscriptions.get(step_name, [])
        merged = _merge_unique(existing, event_types)
        if not force and set(existing) == set(merged):
            return

        try:
            await self._send(
                {
                    FIELD_ACTION: ACTION_SUBSCRIBE,
                    FIELD_STEP_NAME: step_name,
                    FIELD_EVENT_TYPES: merged,
                }
            )
        except Exception as exc:
            # Best-effort: the request is still recorded below so that
            # reconnect() can replay it.
            self._logger.warning(f"Subscribe request for step '{step_name}' failed: {exc}")

        self._active_subscriptions[step_name] = merged

    async def _send_unsubscription(self, step_name: str) -> None:
        """Tell the server to drop the subscription for *step_name* and forget it locally."""
        if not self._connected:
            return
        try:
            await self._send(
                {
                    FIELD_ACTION: ACTION_UNSUBSCRIBE,
                    FIELD_STEP_NAME: step_name,
                }
            )
        except Exception as exc:
            # Best-effort: the local record is removed either way.
            self._logger.warning(f"Unsubscribe request for step '{step_name}' failed: {exc}")
        self._active_subscriptions.pop(step_name, None)

    def _handle_notification(self, data: dict, ack_token: str) -> None:
        """Deduplicate one notification and deliver it to handlers and its thread.

        A notification whose ID was already seen is acknowledged again and
        dropped. Notifications without an ID cannot be deduplicated and are
        always delivered.
        """
        from threadify.notification import Notification

        if not data or not isinstance(data, dict):
            return

        notif_id = data.get(FIELD_NOTIFICATION_ID, "")

        if notif_id:
            if notif_id in self._processed_notifications:
                self._logger.debug(f"Duplicate notification ignored: {notif_id}")
                self._send_ack(notif_id, data.get(FIELD_THREAD_ID, ""), ack_token)
                return
            self._remember_notification(notif_id)

        notif = Notification(data, self, ack_token)

        event_pattern = self._get_event_pattern(notif)
        self._trigger_handlers(event_pattern, notif)

        thread = self._threads.get(notif.thread_id)
        if thread:
            thread._handle_notification(notif)

    def _remember_notification(self, notif_id: str) -> None:
        """Record *notif_id* as processed, evicting the oldest IDs beyond the cap."""
        self._processed_notifications.add(notif_id)
        self._processed_order.append(notif_id)
        limit = self._processed_notifications_max_size
        while len(self._processed_notifications) > limit and self._processed_order:
            self._processed_notifications.discard(self._processed_order.popleft())

    def _get_event_pattern(self, notif: Notification) -> str:
        """Build the "<source>.<event type>" pattern used to look up handlers.

        The source defaults to "execution" and is translated to the names used
        in ``subscribe`` ("execution" -> "step", "validation" -> "rule"). The
        event type is the part of ``notification_type`` after its first dot,
        defaulting to STATUS_SUCCESS.
        """
        source = notif.source or "execution"
        event_type = STATUS_SUCCESS
        if notif.notification_type:
            _, dot, tail = notif.notification_type.partition(".")
            if dot:
                event_type = tail

        source_map = {
            "execution": "step",
            "validation": "rule",
            "thread": "thread",
        }
        return f"{source_map.get(source, source)}.{event_type}"

    def _trigger_handlers(self, event_pattern: str, notif: Notification) -> None:
        """Call every handler whose key matches the notification.

        Keys are tried from most to least specific: contract-qualified step,
        exact event on the step, any event of the source on the step, any
        event on the step. A failing handler is logged and does not prevent
        the remaining handlers from running.
        """
        step_name = notif.step_name
        contract_name = notif.contract_name
        source = event_pattern.split(".", 1)[0]

        # NOTE(review): handlers registered through subscribe(event, handler)
        # are stored under the step name "global", which is only looked up here
        # when notif.step_name is "global" — confirm that the server reports
        # that step name for thread-level subscriptions.
        keys_to_check = []
        if contract_name:
            keys_to_check.append(f"{event_pattern}:{contract_name}@{step_name}")
        keys_to_check.append(f"{event_pattern}:{step_name}")
        keys_to_check.append(f"{source}.*:{step_name}")
        keys_to_check.append(f"*:{step_name}")

        for key in keys_to_check:
            # Iterate over a copy: a handler may subscribe or unsubscribe.
            for handler in list(self._notification_handlers.get(key, [])):
                try:
                    result = handler(notif)
                    if asyncio.iscoroutine(result):
                        # An async handler returns a coroutine that would
                        # otherwise never run.
                        self._spawn(result, "Notification handler")
                except Exception as exc:
                    self._logger.error(f"Notification handler error: {exc}")

    def _send_ack(self, notification_id: str, thread_id: str, ack_token: str) -> None:
        """Acknowledge a notification in the background; a no-op without a token."""
        if not ack_token:
            return
        ack = self._send(
            {
                FIELD_ACTION: ACTION_ACK_NOTIFICATION,
                FIELD_NOTIFICATION_ID_ACK: notification_id,
                FIELD_THREAD_ID_ACK: thread_id,
                FIELD_ACK_TOKEN: ack_token,
                FIELD_PROCESSED: True,
            }
        )
        try:
            self._spawn(ack, "Notification ack")
        except RuntimeError as exc:
            # No usable event loop: acknowledgements stay best-effort.
            ack.close()
            self._logger.debug(f"Notification ack not scheduled: {exc}")

    def _get_data_retriever(self) -> DataRetriever:
        """Return the DataRetriever, creating it on first use.

        Raises:
            RuntimeError: No GraphQL URL was configured.
        """
        from threadify.data_retriever import DataRetriever

        if self._data_retriever is None:
            if not self._graphql_url:
                raise RuntimeError("GraphQL URL not configured")
            self._data_retriever = DataRetriever(self._graphql_url, self._api_key)
        return self._data_retriever

    async def get_thread(self, thread_id: str) -> ArchivedThread:
        """Fetch one thread through the data retriever."""
        return await self._get_data_retriever().get_thread(thread_id)

    async def get_thread_by_ref(self, ref_key: str, ref_value: str) -> ArchivedThread | None:
        """Return the first thread matching the reference, or None when there is none."""
        threads = await self._get_data_retriever().get_threads_by_ref(
            RefQuery(ref_key=ref_key, ref_value=ref_value, limit=1)
        )
        return threads[0] if threads else None

    async def get_threads_by_ref(self, query: RefQuery) -> list[ArchivedThread]:
        """Return the threads matching *query* through the data retriever."""
        return await self._get_data_retriever().get_threads_by_ref(query)

    async def get_validation_results(self, thread_id: str, step_name: str = "") -> list[dict[str, Any]]:
        """Return validation results for a thread; *step_name* is passed through as given."""
        return await self._get_data_retriever().get_validation_results(thread_id, step_name)

    async def get_thread_chain(self, root_id: str, max_depth: int = 3) -> list[ArchivedThread]:
        """Return the thread chain starting at *root_id*; *max_depth* is passed through."""
        return await self._get_data_retriever().get_thread_chain(root_id, max_depth)

    def create_span_exporter(self, options: dict[str, Any] | None = None) -> Any:
        """Create an OpenTelemetry SpanExporter wired to this connection.

        Args:
            options: Optional configuration dict. Supported keys:
                - ``refs``: list of attribute keys to map to Threadify refs.

        Returns:
            A :class:`~threadify.otel_exporter.ThreadifySpanExporter` instance.
        """
        from threadify.otel_exporter import ThreadifySpanExporter

        return ThreadifySpanExporter(self, options or {})

    async def reconnect(self) -> None:
        """Resubscribe to all active subscriptions after a reconnection.

        Raises:
            RuntimeError: The connection is not marked as connected.
        """
        if not self._connected:
            raise RuntimeError("Not connected")
        # Snapshot first, then force each send: the event types are already
        # recorded locally, so an unforced call would be skipped as a no-op.
        for step_name, event_types in list(self._active_subscriptions.items()):
            await self._send_subscription(step_name, event_types, force=True)

    def _remove_thread(self, thread_id: str) -> None:
        """Forget a thread handle; unknown IDs are ignored."""
        self._threads.pop(thread_id, None)
484
+
485
+
486
def _parse_event(event: str) -> tuple[str, str]:
    """Split an event pattern into ``(source, event_type)``.

    The pattern is split at its first dot. The source aliases "step" and
    "rule" are translated to "execution" and "validation"; only the source
    segment is translated, so an event type that merely contains one of those
    words is left untouched. A missing or empty segment becomes "*".

    Args:
        event: Pattern such as "step.success", "rule.*", "thread.completed" or "*".

    Returns:
        A ``(source, event_type)`` tuple.
    """
    source_aliases = {"step": "execution", "rule": "validation"}
    source, _, event_type = event.partition(".")
    source = source_aliases.get(source, source)
    return source or "*", event_type or "*"
492
+
493
+
494
def _build_event_types(source: str, event_type: str) -> list[str]:
    """Expand a ``(source, event_type)`` pair into concrete event type names.

    A "*" event type is expanded for the sources "*", "execution" and
    "validation"; any other combination yields the single name
    "<source>.<event_type>".
    """
    literal = [f"{source}.{event_type}"]
    if event_type != "*":
        return literal

    execution_events = ["execution.success", "execution.failed"]
    validation_events = ["validation.passed", "validation.violated"]
    wildcard_expansions = {
        "*": execution_events + validation_events,
        "execution": execution_events,
        "validation": validation_events,
    }
    return wildcard_expansions.get(source, literal)
502
+
503
+
504
def _merge_unique(a: list[str], b: list[str]) -> list[str]:
    """Return the items of *a* followed by those of *b*, without duplicates.

    First-occurrence order is kept, so the result is the same on every run.
    """
    # dict preserves insertion order, which a set union does not.
    return list(dict.fromkeys([*a, *b]))