langgraph-api 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of langgraph-api might be problematic.

Files changed (86)
  1. LICENSE +93 -0
  2. langgraph_api/__init__.py +0 -0
  3. langgraph_api/api/__init__.py +63 -0
  4. langgraph_api/api/assistants.py +326 -0
  5. langgraph_api/api/meta.py +71 -0
  6. langgraph_api/api/openapi.py +32 -0
  7. langgraph_api/api/runs.py +463 -0
  8. langgraph_api/api/store.py +116 -0
  9. langgraph_api/api/threads.py +263 -0
  10. langgraph_api/asyncio.py +201 -0
  11. langgraph_api/auth/__init__.py +0 -0
  12. langgraph_api/auth/langsmith/__init__.py +0 -0
  13. langgraph_api/auth/langsmith/backend.py +67 -0
  14. langgraph_api/auth/langsmith/client.py +145 -0
  15. langgraph_api/auth/middleware.py +41 -0
  16. langgraph_api/auth/noop.py +14 -0
  17. langgraph_api/cli.py +209 -0
  18. langgraph_api/config.py +70 -0
  19. langgraph_api/cron_scheduler.py +60 -0
  20. langgraph_api/errors.py +52 -0
  21. langgraph_api/graph.py +314 -0
  22. langgraph_api/http.py +168 -0
  23. langgraph_api/http_logger.py +89 -0
  24. langgraph_api/js/.gitignore +2 -0
  25. langgraph_api/js/build.mts +49 -0
  26. langgraph_api/js/client.mts +849 -0
  27. langgraph_api/js/global.d.ts +6 -0
  28. langgraph_api/js/package.json +33 -0
  29. langgraph_api/js/remote.py +673 -0
  30. langgraph_api/js/server_sent_events.py +126 -0
  31. langgraph_api/js/src/graph.mts +88 -0
  32. langgraph_api/js/src/hooks.mjs +12 -0
  33. langgraph_api/js/src/parser/parser.mts +443 -0
  34. langgraph_api/js/src/parser/parser.worker.mjs +12 -0
  35. langgraph_api/js/src/schema/types.mts +2136 -0
  36. langgraph_api/js/src/schema/types.template.mts +74 -0
  37. langgraph_api/js/src/utils/importMap.mts +85 -0
  38. langgraph_api/js/src/utils/pythonSchemas.mts +28 -0
  39. langgraph_api/js/src/utils/serde.mts +21 -0
  40. langgraph_api/js/tests/api.test.mts +1566 -0
  41. langgraph_api/js/tests/compose-postgres.yml +56 -0
  42. langgraph_api/js/tests/graphs/.gitignore +1 -0
  43. langgraph_api/js/tests/graphs/agent.mts +127 -0
  44. langgraph_api/js/tests/graphs/error.mts +17 -0
  45. langgraph_api/js/tests/graphs/langgraph.json +8 -0
  46. langgraph_api/js/tests/graphs/nested.mts +44 -0
  47. langgraph_api/js/tests/graphs/package.json +7 -0
  48. langgraph_api/js/tests/graphs/weather.mts +57 -0
  49. langgraph_api/js/tests/graphs/yarn.lock +159 -0
  50. langgraph_api/js/tests/parser.test.mts +870 -0
  51. langgraph_api/js/tests/utils.mts +17 -0
  52. langgraph_api/js/yarn.lock +1340 -0
  53. langgraph_api/lifespan.py +41 -0
  54. langgraph_api/logging.py +121 -0
  55. langgraph_api/metadata.py +101 -0
  56. langgraph_api/models/__init__.py +0 -0
  57. langgraph_api/models/run.py +229 -0
  58. langgraph_api/patch.py +42 -0
  59. langgraph_api/queue.py +245 -0
  60. langgraph_api/route.py +118 -0
  61. langgraph_api/schema.py +190 -0
  62. langgraph_api/serde.py +124 -0
  63. langgraph_api/server.py +48 -0
  64. langgraph_api/sse.py +118 -0
  65. langgraph_api/state.py +67 -0
  66. langgraph_api/stream.py +289 -0
  67. langgraph_api/utils.py +60 -0
  68. langgraph_api/validation.py +141 -0
  69. langgraph_api-0.0.1.dist-info/LICENSE +93 -0
  70. langgraph_api-0.0.1.dist-info/METADATA +26 -0
  71. langgraph_api-0.0.1.dist-info/RECORD +86 -0
  72. langgraph_api-0.0.1.dist-info/WHEEL +4 -0
  73. langgraph_api-0.0.1.dist-info/entry_points.txt +3 -0
  74. langgraph_license/__init__.py +0 -0
  75. langgraph_license/middleware.py +21 -0
  76. langgraph_license/validation.py +11 -0
  77. langgraph_storage/__init__.py +0 -0
  78. langgraph_storage/checkpoint.py +94 -0
  79. langgraph_storage/database.py +190 -0
  80. langgraph_storage/ops.py +1523 -0
  81. langgraph_storage/queue.py +108 -0
  82. langgraph_storage/retry.py +27 -0
  83. langgraph_storage/store.py +28 -0
  84. langgraph_storage/ttl_dict.py +54 -0
  85. logging.json +22 -0
  86. openapi.json +4304 -0
@@ -0,0 +1,1523 @@
+ """Implementation of the LangGraph API using in-memory checkpointer & store."""
+
+ import asyncio
+ import base64
+ import copy
+ import json
+ import logging
+ import uuid
+ from collections import defaultdict
+ from collections.abc import AsyncIterator, Sequence
+ from contextlib import asynccontextmanager
+ from copy import deepcopy
+ from datetime import UTC, datetime, timedelta
+ from typing import Any, Literal
+ from uuid import UUID, uuid4
+
+ import structlog
+ from langgraph.pregel.debug import CheckpointPayload
+ from langgraph.pregel.types import StateSnapshot
+ from starlette.exceptions import HTTPException
+
+ from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent, create_task
+ from langgraph_api.errors import UserInterrupt, UserRollback
+ from langgraph_api.graph import get_graph
+ from langgraph_api.schema import (
+     Assistant,
+     Checkpoint,
+     Config,
+     Cron,
+     IfNotExists,
+     MetadataInput,
+     MetadataValue,
+     MultitaskStrategy,
+     OnConflictBehavior,
+     QueueStats,
+     Run,
+     RunStatus,
+     StreamMode,
+     Thread,
+     ThreadStatus,
+     ThreadUpdateResponse,
+ )
+ from langgraph_api.serde import Fragment
+ from langgraph_api.utils import fetchone
+ from langgraph_storage.checkpoint import Checkpointer
+ from langgraph_storage.database import InMemConnectionProto, connect
+ from langgraph_storage.queue import Message, get_stream_manager
+
+ logger = structlog.stdlib.get_logger(__name__)
+
+
+ def _ensure_uuid(id_: str | uuid.UUID | None) -> uuid.UUID:
+     if isinstance(id_, str):
+         return uuid.UUID(id_)
+     if id_ is None:
+         return uuid4()
+     return id_
+
+
+ # Right now the whole API types IDs as UUID but frequently passes a str.
+ # We ensure UUIDs for everything EXCEPT the checkpoint storage/writes,
+ # which we leave as strings. This is because I'm too lazy to subclass fully
+ # and we use non-UUID examples in the OSS version.
+
+
+ class Assistants:
+     @staticmethod
+     async def search(
+         conn: InMemConnectionProto,
+         *,
+         graph_id: str | None,
+         metadata: MetadataInput,
+         limit: int,
+         offset: int,
+     ) -> AsyncIterator[Assistant]:
+         async def filter_and_yield() -> AsyncIterator[Assistant]:
+             assistants = conn.store["assistants"]
+             filtered_assistants = [
+                 assistant
+                 for assistant in assistants
+                 if (not graph_id or assistant["graph_id"] == graph_id)
+                 and (
+                     not metadata or is_jsonb_contained(assistant["metadata"], metadata)
+                 )
+             ]
+             filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
+             for assistant in filtered_assistants[offset : offset + limit]:
+                 yield assistant
+
+         return filter_and_yield()
+
+     @staticmethod
+     async def get(
+         conn: InMemConnectionProto, assistant_id: UUID
+     ) -> AsyncIterator[Assistant]:
+         """Get an assistant by ID."""
+         assistant_id = _ensure_uuid(assistant_id)
+
+         async def _yield_result():
+             for assistant in conn.store["assistants"]:
+                 if assistant["assistant_id"] == assistant_id:
+                     yield assistant
+
+         return _yield_result()
+
+     @staticmethod
+     async def put(
+         conn: InMemConnectionProto,
+         assistant_id: UUID,
+         *,
+         graph_id: str,
+         config: Config,
+         metadata: MetadataInput,
+         if_exists: OnConflictBehavior,
+         name: str,
+     ) -> AsyncIterator[Assistant]:
+         """Insert an assistant."""
+         assistant_id = _ensure_uuid(assistant_id)
+         existing_assistant = next(
+             (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
+             None,
+         )
+         if existing_assistant:
+             if if_exists == "raise":
+                 raise HTTPException(
+                     status_code=409, detail=f"Assistant {assistant_id} already exists"
+                 )
+             elif if_exists == "do_nothing":
+
+                 async def _yield_existing():
+                     yield existing_assistant
+
+                 return _yield_existing()
+
+         now = datetime.now(UTC)
+         new_assistant: Assistant = {
+             "assistant_id": assistant_id,
+             "graph_id": graph_id,
+             "config": config or {},
+             "metadata": metadata or {},
+             "name": name,
+             "created_at": now,
+             "updated_at": now,
+             "version": 1,
+         }
+         new_version = {
+             "assistant_id": assistant_id,
+             "version": 1,
+             "graph_id": graph_id,
+             "config": config or {},
+             "metadata": metadata or {},
+             "created_at": now,
+         }
+         conn.store["assistants"].append(new_assistant)
+         conn.store["assistant_versions"].append(new_version)
+
+         async def _yield_new():
+             yield new_assistant
+
+         return _yield_new()
+
+     @staticmethod
+     async def patch(
+         conn: InMemConnectionProto,
+         assistant_id: UUID,
+         *,
+         config: dict | None = None,
+         graph_id: str | None = None,
+         metadata: MetadataInput | None = None,
+         name: str | None = None,
+     ) -> AsyncIterator[Assistant]:
+         """Update an assistant.
+
+         Args:
+             assistant_id: The assistant ID.
+             graph_id: The graph ID.
+             config: The assistant config.
+             metadata: The assistant metadata.
+             name: The assistant name.
+
+         Returns:
+             The updated assistant model.
+         """
+         assistant_id = _ensure_uuid(assistant_id)
+         assistant = next(
+             (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
+             None,
+         )
+         if not assistant:
+             raise HTTPException(
+                 status_code=404, detail=f"Assistant {assistant_id} not found"
+             )
+
+         now = datetime.now(UTC)
+         new_version = (
+             max(
+                 v["version"]
+                 for v in conn.store["assistant_versions"]
+                 if v["assistant_id"] == assistant_id
+             )
+             + 1
+             if conn.store["assistant_versions"]
+             else 1
+         )
+
+         # Update assistant_versions table
+         new_version_entry = {
+             "assistant_id": assistant_id,
+             "version": new_version,
+             "graph_id": graph_id if graph_id is not None else assistant["graph_id"],
+             "config": config if config is not None else assistant["config"],
+             "metadata": metadata if metadata is not None else assistant["metadata"],
+             "created_at": now,
+         }
+         conn.store["assistant_versions"].append(new_version_entry)
+
+         # Update assistants table
+         assistant.update(
+             {
+                 "graph_id": new_version_entry["graph_id"],
+                 "config": new_version_entry["config"],
+                 "metadata": new_version_entry["metadata"],
+                 "name": name if name is not None else assistant["name"],
+                 "updated_at": now,
+                 "version": new_version,
+             }
+         )
+
+         async def _yield_updated():
+             yield assistant
+
+         return _yield_updated()
+
+     @staticmethod
+     async def delete(
+         conn: InMemConnectionProto, assistant_id: UUID
+     ) -> AsyncIterator[UUID]:
+         """Delete an assistant by ID."""
+         assistant_id = _ensure_uuid(assistant_id)
+         conn.store["assistants"] = [
+             a for a in conn.store["assistants"] if a["assistant_id"] != assistant_id
+         ]
+         # Cascade delete assistant versions, crons, & runs on this assistant
+         conn.store["assistant_versions"] = [
+             v
+             for v in conn.store["assistant_versions"]
+             if v["assistant_id"] != assistant_id
+         ]
+         # Iterate over a snapshot: Runs.delete mutates conn.store["runs"] in place.
+         for run in list(conn.store["runs"]):
+             if run["assistant_id"] == assistant_id:
+                 res = await Runs.delete(conn, run["run_id"], thread_id=run["thread_id"])
+                 await anext(res)
+
+         async def _yield_deleted():
+             yield assistant_id
+
+         return _yield_deleted()
+
+     @staticmethod
+     async def set_latest(
+         conn: InMemConnectionProto, assistant_id: UUID, version: int
+     ) -> AsyncIterator[Assistant]:
+         """Change the version of an assistant."""
+         assistant_id = _ensure_uuid(assistant_id)
+         assistant = next(
+             (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
+             None,
+         )
+         if not assistant:
+             raise HTTPException(
+                 status_code=404, detail=f"Assistant {assistant_id} not found"
+             )
+
+         version_data = next(
+             (
+                 v
+                 for v in conn.store["assistant_versions"]
+                 if v["assistant_id"] == assistant_id and v["version"] == version
+             ),
+             None,
+         )
+         if not version_data:
+             raise HTTPException(
+                 status_code=404,
+                 detail=f"Version {version} not found for assistant {assistant_id}",
+             )
+
+         assistant.update(
+             {
+                 "config": version_data["config"],
+                 "metadata": version_data["metadata"],
+                 "version": version_data["version"],
+                 "updated_at": datetime.now(UTC),
+             }
+         )
+
+         async def _yield_updated():
+             yield assistant
+
+         return _yield_updated()
+
+     @staticmethod
+     async def get_versions(
+         conn: InMemConnectionProto,
+         assistant_id: UUID,
+         metadata: MetadataInput,
+         limit: int,
+         offset: int,
+     ) -> AsyncIterator[Assistant]:
+         """Get all versions of an assistant."""
+         assistant_id = _ensure_uuid(assistant_id)
+         versions = [
+             v
+             for v in conn.store["assistant_versions"]
+             if v["assistant_id"] == assistant_id
+             and (not metadata or is_jsonb_contained(v["metadata"], metadata))
+         ]
+         versions.sort(key=lambda x: x["version"], reverse=True)
+
+         async def _yield_versions():
+             for version in versions[offset : offset + limit]:
+                 yield version
+
+         return _yield_versions()
+
+
+ def is_jsonb_contained(superset: dict[str, Any], subset: dict[str, Any]) -> bool:
+     """
+     Implements Postgres' @> (containment) operator for dictionaries.
+     Returns True if superset contains all key/value pairs from subset.
+     """
+     for key, value in subset.items():
+         if key not in superset:
+             return False
+         if isinstance(value, dict) and isinstance(superset[key], dict):
+             if not is_jsonb_contained(superset[key], value):
+                 return False
+         elif superset[key] != value:
+             return False
+     return True
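
The recursion gives the same nesting semantics as Postgres' `@>`: a subset matches when every key is present and every nested dict is itself contained, while non-dict values must compare equal. A quick sketch of those semantics (sample dicts invented for illustration):

    meta = {"graph_id": "agent", "owner": {"team": "search", "role": "admin"}}
    assert is_jsonb_contained(meta, {"owner": {"team": "search"}})   # nested subset
    assert not is_jsonb_contained(meta, {"owner": {"team": "ops"}})  # value mismatch
    assert not is_jsonb_contained(meta, {"missing": 1})              # absent key
    # Unlike Postgres @> on arrays, lists compare by equality, not element containment.
    assert not is_jsonb_contained({"tags": ["a", "b"]}, {"tags": ["a"]})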
+
+
+ def bytes_decoder(obj):
+     """Custom JSON decoder that converts base64 back to bytes."""
+     if "__type__" in obj and obj["__type__"] == "bytes":
+         return base64.b64decode(obj["value"].encode("utf-8"))
+     return obj
+
+
+ def _replace_thread_id(data, new_thread_id, thread_id):
+     class BytesEncoder(json.JSONEncoder):
+         """Custom JSON encoder that handles bytes by converting them to base64."""
+
+         def default(self, obj):
+             if isinstance(obj, bytes | bytearray):
+                 return {
+                     "__type__": "bytes",
+                     "value": base64.b64encode(
+                         obj.replace(
+                             str(thread_id).encode(), str(new_thread_id).encode()
+                         )
+                     ).decode("utf-8"),
+                 }
+
+             return super().default(obj)
+
+     try:
+         json_str = json.dumps(data, cls=BytesEncoder, indent=2)
+     except Exception as e:
+         raise ValueError(data) from e
+     json_str = json_str.replace(str(thread_id), str(new_thread_id))
+
+     # Decoding back from JSON
+     d = json.loads(json_str, object_hook=bytes_decoder)
+     return d
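
The round trip works in three steps: `BytesEncoder` rewrites the old thread id inside any bytes values and tags them as base64, a plain `str.replace` rewrites the remaining string occurrences in the JSON text, and `bytes_decoder` restores the tagged values to bytes. A minimal sketch of that round trip (IDs and payload invented):

    from uuid import uuid4

    old, new = uuid4(), uuid4()
    data = {
        "config": {"thread_id": str(old)},         # rewritten on the JSON text
        "blob": f"checkpoint for {old}".encode(),  # rewritten inside bytes, then base64-tagged
    }
    copied = _replace_thread_id(data, new, old)
    assert copied["config"]["thread_id"] == str(new)
    assert isinstance(copied["blob"], bytes) and str(new).encode() in copied["blob"]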
+
+
+ class Threads:
+     @staticmethod
+     async def search(
+         conn: InMemConnectionProto,
+         *,
+         metadata: MetadataInput,
+         values: MetadataInput,
+         status: ThreadStatus | None,
+         limit: int,
+         offset: int,
+     ) -> AsyncIterator[Thread]:
+         threads = conn.store["threads"]
+         filtered_threads: list[Thread] = []
+
+         # Apply filters
+         for thread in threads:
+             matches = True
+
+             if metadata and not is_jsonb_contained(thread["metadata"], metadata):
+                 matches = False
+
+             if (
+                 values
+                 and "values" in thread
+                 and not is_jsonb_contained(thread["values"], values)
+             ):
+                 matches = False
+
+             if status and thread.get("status") != status:
+                 matches = False
+
+             if matches:
+                 filtered_threads.append(thread)
+
+         # Sort by created_at in descending order
+         sorted_threads = sorted(
+             filtered_threads, key=lambda x: x["created_at"], reverse=True
+         )
+
+         # Apply limit and offset
+         paginated_threads = sorted_threads[offset : offset + limit]
+
+         async def thread_iterator() -> AsyncIterator[Thread]:
+             for thread in paginated_threads:
+                 yield thread
+
+         return thread_iterator()
+
+     @staticmethod
+     async def get(conn: InMemConnectionProto, thread_id: UUID) -> AsyncIterator[Thread]:
+         """Get a thread by ID."""
+         thread_id = _ensure_uuid(thread_id)
+         matching_thread = next(
+             (
+                 thread
+                 for thread in conn.store["threads"]
+                 if thread["thread_id"] == thread_id
+             ),
+             None,
+         )
+         if not matching_thread:
+             raise HTTPException(
+                 status_code=404, detail=f"Thread with ID {thread_id} not found"
+             )
+
+         async def _yield_result():
+             if matching_thread:
+                 yield matching_thread
+
+         return _yield_result()
+
+     @staticmethod
+     async def put(
+         conn: InMemConnectionProto,
+         thread_id: UUID,
+         *,
+         metadata: MetadataInput,
+         if_exists: OnConflictBehavior,
+     ) -> AsyncIterator[Thread]:
+         """Insert or update a thread."""
+         thread_id = _ensure_uuid(thread_id)
+         if metadata is None:
+             metadata = {}
+
+         # Check if thread already exists
+         existing_thread = next(
+             (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
+         )
+
+         if existing_thread:
+             if if_exists == "raise":
+                 raise HTTPException(
+                     status_code=409, detail=f"Thread with ID {thread_id} already exists"
+                 )
+             elif if_exists == "do_nothing":
+
+                 async def _yield_existing():
+                     yield existing_thread
+
+                 return _yield_existing()
+
+         # Create new thread
+         new_thread: Thread = {
+             "thread_id": thread_id,
+             "created_at": datetime.now(UTC),
+             "updated_at": datetime.now(UTC),
+             "metadata": copy.deepcopy(metadata),
+             "status": "idle",
+             "config": {},
+             "values": None,
+         }
+
+         # Add to store
+         conn.store["threads"].append(new_thread)
+
+         async def _yield_new():
+             yield new_thread
+
+         return _yield_new()
+
+     @staticmethod
+     async def patch(
+         conn: InMemConnectionProto, thread_id: UUID, *, metadata: MetadataValue
+     ) -> AsyncIterator[Thread]:
+         """Update a thread."""
+         thread_list = conn.store["threads"]
+         thread_idx = None
+         thread_id = _ensure_uuid(thread_id)
+
+         for idx, thread in enumerate(thread_list):
+             if thread["thread_id"] == thread_id:
+                 thread_idx = idx
+                 break
+
+         if thread_idx is not None:
+             thread = copy.deepcopy(thread_list[thread_idx])
+             thread["metadata"] = {**thread["metadata"], **metadata}
+             thread["updated_at"] = datetime.now(UTC)
+             thread_list[thread_idx] = thread
+
+             async def thread_iterator() -> AsyncIterator[Thread]:
+                 yield thread
+
+             return thread_iterator()
+
+         async def empty_iterator() -> AsyncIterator[Thread]:
+             if False:  # This ensures the iterator is empty
+                 yield
+
+         return empty_iterator()
+
+     @staticmethod
+     async def set_status(
+         conn: InMemConnectionProto,
+         thread_id: UUID,
+         checkpoint: CheckpointPayload | None,
+         exception: BaseException | None,
+     ) -> None:
+         """Set the status of a thread."""
+         thread_id = _ensure_uuid(thread_id)
+
+         async def has_pending_runs(conn_: InMemConnectionProto, tid: UUID) -> bool:
+             """Check if thread has any pending runs."""
+             return any(
+                 run["status"] == "pending" and run["thread_id"] == tid
+                 for run in conn_.store["runs"]
+             )
+
+         # Find the thread
+         thread = next(
+             (
+                 thread
+                 for thread in conn.store["threads"]
+                 if thread["thread_id"] == thread_id
+             ),
+             None,
+         )
+
+         if not thread:
+             raise HTTPException(
+                 status_code=404, detail=f"Thread {thread_id} not found."
+             )
+
+         # Determine has_next from checkpoint
+         has_next = False if checkpoint is None else bool(checkpoint["next"])
+
+         # Determine base status
+         if exception:
+             status = "error"
+         elif has_next:
+             status = "interrupted"
+         else:
+             status = "idle"
+
+         # Check for pending runs and update to busy if found
+         if await has_pending_runs(conn, thread_id):
+             status = "busy"
+
+         # Update thread
+         thread.update(
+             {
+                 "updated_at": datetime.now(UTC),
+                 "values": checkpoint["values"] if checkpoint else None,
+                 "status": status,
+             }
+         )
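
The resulting status is derived in priority order: any pending run forces `busy`, otherwise an exception means `error`, a non-empty `next` means `interrupted`, and anything else is `idle`. A compact restatement of that precedence (a sketch, not part of the module):

    def derive_status(pending: bool, exc: BaseException | None, has_next: bool) -> str:
        # Mirrors Threads.set_status: pending runs win, then errors, then interrupts.
        if pending:
            return "busy"
        if exc:
            return "error"
        return "interrupted" if has_next else "idle"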
+
+     @staticmethod
+     async def delete(
+         conn: InMemConnectionProto, thread_id: UUID
+     ) -> AsyncIterator[UUID]:
+         """Delete a thread by ID and cascade delete all associated runs."""
+         thread_list = conn.store["threads"]
+         thread_idx = None
+         thread_id = _ensure_uuid(thread_id)
+         conn.locks.pop(thread_id, None)
+
+         # Find the thread to delete
+         for idx, thread in enumerate(thread_list):
+             if thread["thread_id"] == thread_id:
+                 thread_idx = idx
+                 break
+         # Cascade delete all runs associated with this thread
+         conn.store["runs"] = [
+             run for run in conn.store["runs"] if run["thread_id"] != thread_id
+         ]
+         _delete_checkpoints_for_thread(thread_id, conn)
+
+         if thread_idx is not None:
+             # Remove the thread from the store
+             deleted_thread = thread_list.pop(thread_idx)
+
+             # Return an async iterator with the deleted thread_id
+             async def id_iterator() -> AsyncIterator[UUID]:
+                 yield deleted_thread["thread_id"]
+
+             return id_iterator()
+
+         # If thread not found, return empty iterator
+         async def empty_iterator() -> AsyncIterator[UUID]:
+             if False:  # This ensures the iterator is empty
+                 yield
+
+         return empty_iterator()
+
+     @staticmethod
+     async def copy(
+         conn: InMemConnectionProto, thread_id: UUID
+     ) -> AsyncIterator[Thread]:
+         """Create a copy of an existing thread."""
+         thread_id = _ensure_uuid(thread_id)
+         new_thread_id = uuid4()
+
+         async with conn.pipeline():
+             # Find the original thread in our store
+             original_thread = next(
+                 (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
+             )
+
+             if not original_thread:
+                 return
+
+             # Create new thread with copied metadata
+             new_thread: Thread = {
+                 "thread_id": new_thread_id,
+                 "created_at": datetime.now(tz=UTC),
+                 "updated_at": datetime.now(tz=UTC),
+                 "metadata": deepcopy(original_thread["metadata"]),
+                 "status": "idle",
+                 "config": {},
+             }
+
+             # Add new thread to store
+             conn.store["threads"].append(new_thread)
+
+             checkpointer = Checkpointer(conn)
+             copied_storage = _replace_thread_id(
+                 checkpointer.storage[str(thread_id)], new_thread_id, thread_id
+             )
+             checkpointer.storage[str(new_thread_id)] = copied_storage
+             # Copy the writes over (if any)
+             outer_keys = []
+             for k in checkpointer.writes:
+                 if k[0] == str(thread_id):
+                     outer_keys.append(k)
+             for tid, checkpoint_ns, checkpoint_id in outer_keys:
+                 mapped = {
+                     k: _replace_thread_id(v, new_thread_id, thread_id)
+                     for k, v in checkpointer.writes[
+                         (str(tid), checkpoint_ns, checkpoint_id)
+                     ].items()
+                 }
+
+                 checkpointer.writes[
+                     (str(new_thread_id), checkpoint_ns, checkpoint_id)
+                 ] = mapped
+
+         async def row_generator() -> AsyncIterator[Thread]:
+             yield new_thread
+
+         return row_generator()
+
+     class State:
+         @staticmethod
+         async def get(
+             conn: InMemConnectionProto, config: Config, subgraphs: bool = False
+         ) -> StateSnapshot:
+             """Get state for a thread."""
+             checkpointer = Checkpointer(conn)
+             thread_id = _ensure_uuid(config["configurable"]["thread_id"])
+             thread_iter = await Threads.get(conn, thread_id)
+             thread = await anext(thread_iter)
+             checkpoint = await checkpointer.aget(config)
+
+             if not thread:
+                 return StateSnapshot(
+                     values={},
+                     next=[],
+                     config=None,
+                     metadata=None,
+                     created_at=None,
+                     parent_config=None,
+                     tasks=tuple(),
+                 )
+
+             metadata = thread.get("metadata", {})
+             thread_config = thread.get("config", {})
+
+             if graph_id := metadata.get("graph_id"):
+                 # format latest checkpoint for response
+                 checkpointer.latest_iter = checkpoint
+                 graph = get_graph(graph_id, thread_config, checkpointer=checkpointer)
+                 result = await graph.aget_state(config, subgraphs=subgraphs)
+                 if (
+                     result.metadata is not None
+                     and "checkpoint_ns" in result.metadata
+                     and result.metadata["checkpoint_ns"] == ""
+                 ):
+                     result.metadata.pop("checkpoint_ns")
+                 return result
+             else:
+                 return StateSnapshot(
+                     values={},
+                     next=[],
+                     config=None,
+                     metadata=None,
+                     created_at=None,
+                     parent_config=None,
+                     tasks=tuple(),
+                 )
+
+         @staticmethod
+         async def post(
+             conn: InMemConnectionProto,
+             config: Config,
+             values: Sequence[dict] | dict[str, Any] | None,
+             as_node: str | None = None,
+         ) -> ThreadUpdateResponse:
+             """Add state to a thread."""
+
+             checkpointer = Checkpointer(conn)
+             thread_id = _ensure_uuid(config["configurable"]["thread_id"])
+             thread_iter = await Threads.get(conn, thread_id)
+             thread = await fetchone(
+                 thread_iter, not_found_detail=f"Thread {thread_id} not found."
+             )
+             checkpoint = await checkpointer.aget(config)
+
+             if not thread:
+                 raise HTTPException(status_code=404, detail="Thread not found")
+
+             metadata = thread["metadata"]
+             thread_config = thread["config"]
+
+             if graph_id := metadata.get("graph_id"):
+                 config["configurable"].setdefault("graph_id", graph_id)
+
+                 checkpointer.latest_iter = checkpoint
+                 graph = get_graph(graph_id, thread_config, checkpointer=checkpointer)
+                 update_config = config.copy()
+                 update_config["configurable"] = {
+                     **config["configurable"],
+                     "checkpoint_ns": config["configurable"].get("checkpoint_ns", ""),
+                 }
+                 next_config = await graph.aupdate_state(
+                     update_config, values, as_node=as_node
+                 )
+
+                 # Get current state
+                 state = await Threads.State.get(conn, config, subgraphs=False)
+                 # Update thread values
+                 for thread in conn.store["threads"]:
+                     if thread["thread_id"] == thread_id:
+                         thread["values"] = state.values
+                         break
+
+                 return ThreadUpdateResponse(
+                     checkpoint=next_config["configurable"],
+                     # Including deprecated fields
+                     configurable=next_config["configurable"],
+                     checkpoint_id=next_config["configurable"]["checkpoint_id"],
+                 )
+             else:
+                 raise HTTPException(status_code=400, detail="Thread has no graph ID.")
+
+         @staticmethod
+         async def list(
+             conn: InMemConnectionProto,
+             *,
+             config: Config,
+             limit: int = 10,
+             before: str | Checkpoint | None = None,
+             metadata: MetadataInput = None,
+         ) -> list[StateSnapshot]:
+             """Get the history of a thread."""
+
+             thread_id = _ensure_uuid(config["configurable"]["thread_id"])
+             thread = None
+
+             for t in conn.store["threads"]:
+                 if t["thread_id"] == thread_id:
+                     thread = t
+                     break
+
+             if not thread:
+                 return []
+
+             # Parse thread metadata and config
+             thread_metadata = thread["metadata"]
+             thread_config = thread["config"]
+             # If graph_id exists, get state history
+             if graph_id := thread_metadata.get("graph_id"):
+                 graph = get_graph(
+                     graph_id, thread_config, checkpointer=Checkpointer(conn)
+                 )
+
+                 # Convert before parameter if it's a string
+                 before_param = (
+                     {"configurable": {"checkpoint_id": before}}
+                     if isinstance(before, str)
+                     else before
+                 )
+
+                 states = [
+                     state
+                     async for state in graph.aget_state_history(
+                         config, limit=limit, filter=metadata, before=before_param
+                     )
+                 ]
+
+                 return states
+
+             return []
+
+
+ class Runs:
+     @staticmethod
+     async def stats(conn: InMemConnectionProto) -> QueueStats:
+         """Get stats about the queue."""
+         now = datetime.now(UTC)
+         pending_runs = [run for run in conn.store["runs"] if run["status"] == "pending"]
+
+         if not pending_runs:
+             return {"n_pending": 0, "max_age_secs": None, "med_age_secs": None}
+
+         # Get all creation timestamps
+         created_times = [run.get("created_at") for run in pending_runs]
+         created_times = [
+             t for t in created_times if t is not None
+         ]  # Filter out None values
+
+         if not created_times:
+             return {
+                 "n_pending": len(pending_runs),
+                 "max_age_secs": None,
+                 "med_age_secs": None,
+             }
+
+         # Find oldest (max age)
+         oldest_time = min(created_times)  # Earliest timestamp = oldest run
+
+         # Find median age
+         sorted_times = sorted(created_times)
+         median_idx = len(sorted_times) // 2
+         median_time = sorted_times[median_idx]
+
+         return {
+             "n_pending": len(pending_runs),
+             "max_age_secs": (now - oldest_time).total_seconds(),
+             "med_age_secs": (now - median_time).total_seconds(),
+         }
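
Ages are computed relative to `now` at call time, and the median of an even-sized set takes the upper-middle element rather than interpolating. A small illustration with invented timestamps:

    from datetime import UTC, datetime, timedelta

    now = datetime.now(UTC)
    created = sorted(now - timedelta(seconds=s) for s in (90, 30))
    oldest, median = created[0], created[len(created) // 2]
    assert round((now - oldest).total_seconds()) == 90   # max_age_secs
    assert round((now - median).total_seconds()) == 30   # med_age_secs (upper middle)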
+
+     @asynccontextmanager
+     @staticmethod
+     async def next(conn: InMemConnectionProto) -> AsyncIterator[tuple[Run, int] | None]:
+         """Get the next run from the queue, and the attempt number.
+
+         1 is the first attempt, 2 is the first retry, etc."""
+         now = datetime.now(UTC)
+
+         pending_runs = sorted(
+             [
+                 run
+                 for run in conn.store["runs"]
+                 if run["status"] == "pending" and run.get("created_at", now) < now
+             ],
+             key=lambda x: x.get("created_at", datetime.min),
+         )
+
+         if not pending_runs:
+             yield None
+             return
+
+         # Try to lock and get the first available run
+         for run in pending_runs:
+             run_id = run["run_id"]
+             thread_id = run["thread_id"]
+             lock = conn.locks[thread_id]
+             acquired = lock.acquire(blocking=False)
+             if not acquired:
+                 continue
+             try:
+                 if run["status"] != "pending":
+                     continue
+
+                 thread = next(
+                     (
+                         t
+                         for t in conn.store["threads"]
+                         if t["thread_id"] == run["thread_id"]
+                     ),
+                     None,
+                 )
+
+                 if thread is None:
+                     await logger.awarning(
+                         "Unexpected missing thread in Runs.next",
+                         thread_id=run["thread_id"],
+                     )
+                     continue
+
+                 # Increment attempt counter
+                 attempt = await conn.retry_counter.increment(run_id)
+                 enriched_run = {
+                     **run,
+                     "thread_created_at": thread.get("created_at", now),
+                 }
+                 yield enriched_run, attempt
+             finally:
+                 lock.release()
+             return
+         yield None
+
+     @asynccontextmanager
+     @staticmethod
+     async def enter(run_id: UUID) -> AsyncIterator[ValueEvent]:
+         """Enter a run, listen for cancellation while running, signal when done.
+
+         This method should be called as a context manager by a worker executing a run.
+         """
+         stream_manager = get_stream_manager()
+         # Get queue for this run
+         queue = await Runs.Stream.subscribe(run_id)
+
+         async with SimpleTaskGroup(cancel=True) as tg:
+             done = ValueEvent()
+             tg.create_task(listen_for_cancellation(queue, run_id, done))
+
+             try:
+                 # Give done event to caller
+                 yield done
+             finally:
+                 # Signal done to all subscribers
+                 control_message = Message(
+                     topic=f"run:{run_id}:control".encode(), data=b"done"
+                 )
+
+                 # Store the control message for late subscribers
+                 await stream_manager.put(run_id, control_message)
+                 stream_manager.control_queues[run_id].append(control_message)
+                 # Clean up this queue
+                 await stream_manager.remove_queue(run_id, queue)
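
Together, `Runs.next` and `Runs.enter` give a worker its claim-execute-signal loop: `next` hands out at most one pending run while holding that thread's lock, and `enter` exposes a `done` event that interrupt/rollback messages can trip mid-run. A sketch of the consumption pattern implied by the decorators (`execute_graph` is a hypothetical stand-in for the real worker body):

    async def worker_tick(conn) -> None:
        # Claim at most one pending run; the thread lock is held until this block exits.
        async with Runs.next(conn) as item:
            if item is None:
                return  # nothing pending
            run, attempt = item
            async with Runs.enter(run["run_id"]) as done:
                # `done` is set early if a cancellation control message arrives.
                await execute_graph(run, attempt, done)  # hypothetical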
+
+     @staticmethod
+     def _merge_jsonb(*objects: dict) -> dict:
+         """Mimics PostgreSQL's JSONB merge behavior"""
+         result = {}
+         for obj in objects:
+             if obj is not None:
+                 result.update(copy.deepcopy(obj))
+         return result
+
+     @staticmethod
+     def _get_configurable(config: dict) -> dict:
+         """Extract configurable from config, mimicking PostgreSQL's coalesce"""
+         return config.get("configurable", {})
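
Like JSONB concatenation in Postgres, the merge is shallow: later objects overwrite earlier keys wholesale, which is why `Runs.put` below re-merges the nested `configurable` dict explicitly. For instance:

    a = {"configurable": {"model": "x"}, "tags": ["a"]}
    b = {"configurable": {"temperature": 0}}
    merged = Runs._merge_jsonb(a, b)
    # Shallow merge: b's "configurable" replaces a's entirely.
    assert merged == {"configurable": {"temperature": 0}, "tags": ["a"]}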
+
+     @staticmethod
+     async def put(
+         conn: InMemConnectionProto,
+         assistant_id: UUID,
+         kwargs: dict,
+         *,
+         thread_id: UUID | None = None,
+         user_id: str | None = None,
+         run_id: UUID | None = None,
+         status: RunStatus | None = "pending",
+         metadata: MetadataInput,
+         prevent_insert_if_inflight: bool,
+         multitask_strategy: MultitaskStrategy = "reject",
+         if_not_exists: IfNotExists = "reject",
+         after_seconds: int = 0,
+     ) -> AsyncIterator[Run]:
+         """Create a run."""
+         assistant_id = _ensure_uuid(assistant_id)
+         assistant = next(
+             (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
+             None,
+         )
+
+         async def empty_generator():
+             if False:
+                 yield
+
+         if not assistant:
+             return empty_generator()
+
+         thread_id = _ensure_uuid(thread_id) if thread_id else None
+         run_id = _ensure_uuid(run_id) if run_id else None
+         metadata = metadata or {}
+         config = kwargs.get("config", {})
+
+         # Handle thread creation/update
+         existing_thread = next(
+             (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
+         )
+
+         if not existing_thread and (thread_id is None or if_not_exists == "create"):
+             # Create new thread
+             if thread_id is None:
+                 thread_id = uuid4()
+             thread = Thread(
+                 thread_id=thread_id,
+                 status="busy",
+                 metadata={"graph_id": assistant["graph_id"]},
+                 config=Runs._merge_jsonb(
+                     assistant["config"],
+                     config,
+                     {
+                         "configurable": Runs._merge_jsonb(
+                             Runs._get_configurable(assistant["config"]),
+                             Runs._get_configurable(config),
+                         )
+                     },
+                 ),
+                 created_at=datetime.now(UTC),
+                 updated_at=datetime.now(UTC),
+             )
+             conn.store["threads"].append(thread)
+         elif existing_thread:
+             # Update existing thread
+             if existing_thread["status"] != "busy":
+                 existing_thread["status"] = "busy"
+                 existing_thread["metadata"] = Runs._merge_jsonb(
+                     existing_thread["metadata"], {"graph_id": assistant["graph_id"]}
+                 )
+                 existing_thread["config"] = Runs._merge_jsonb(
+                     assistant["config"],
+                     existing_thread["config"],
+                     config,
+                     {
+                         "configurable": Runs._merge_jsonb(
+                             Runs._get_configurable(assistant["config"]),
+                             Runs._get_configurable(existing_thread["config"]),
+                             Runs._get_configurable(config),
+                         )
+                     },
+                 )
+                 existing_thread["updated_at"] = datetime.now(UTC)
+         else:
+             return empty_generator()
+
+         # Check for inflight runs if needed
+         inflight_runs = [
+             r
+             for r in conn.store["runs"]
+             if r["thread_id"] == thread_id and r["status"] == "pending"
+         ]
+         if prevent_insert_if_inflight:
+             if inflight_runs:
+
+                 async def _return_inflight():
+                     for run in inflight_runs:
+                         yield run
+
+                 return _return_inflight()
+
+         # Create new run
+         configurable = Runs._merge_jsonb(
+             Runs._get_configurable(assistant["config"]),
+             Runs._get_configurable(config),
+             {
+                 "run_id": str(run_id),
+                 "thread_id": str(thread_id),
+                 "graph_id": assistant["graph_id"],
+                 "assistant_id": str(assistant_id),
+                 "user_id": (
+                     config.get("configurable", {}).get("user_id")
+                     or assistant["config"].get("configurable", {}).get("user_id")
+                     or user_id
+                 ),
+             },
+         )
+         merged_metadata = Runs._merge_jsonb(
+             assistant["metadata"],
+             existing_thread["metadata"] if existing_thread else {},
+             metadata,
+         )
+         new_run = Run(
+             run_id=run_id,
+             thread_id=thread_id,
+             assistant_id=assistant_id,
+             metadata=merged_metadata,
+             status=status,
+             kwargs=Runs._merge_jsonb(
+                 kwargs,
+                 {
+                     "config": Runs._merge_jsonb(
+                         assistant["config"],
+                         config,
+                         {"configurable": configurable},
+                         {
+                             "metadata": merged_metadata,
+                         },
+                     )
+                 },
+             ),
+             multitask_strategy=multitask_strategy,
+             created_at=datetime.now(UTC) + timedelta(seconds=after_seconds),
+             updated_at=datetime.now(UTC),
+         )
+         conn.store["runs"].append(new_run)
+
+         async def _yield_new():
+             yield new_run
+             for r in inflight_runs:
+                 yield r
+
+         return _yield_new()
+
+     @staticmethod
+     async def get(
+         conn: InMemConnectionProto, run_id: UUID, *, thread_id: UUID
+     ) -> AsyncIterator[Run]:
+         """Get a run by ID."""
+
+         run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
+
+         async def _yield_result():
+             matching_run = None
+             for run in conn.store["runs"]:
+                 if run["run_id"] == run_id and run["thread_id"] == thread_id:
+                     matching_run = run
+                     break
+             if matching_run:
+                 yield matching_run
+
+         return _yield_result()
+
+     @staticmethod
+     async def delete(
+         conn: InMemConnectionProto, run_id: UUID, *, thread_id: UUID
+     ) -> AsyncIterator[UUID]:
+         """Delete a run by ID."""
+         run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
+         _delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
+         found = False
+         for i, run in enumerate(conn.store["runs"]):
+             if run["run_id"] == run_id and run["thread_id"] == thread_id:
+                 del conn.store["runs"][i]
+                 found = True
+                 break
+         if not found:
+             raise HTTPException(status_code=404, detail="Run not found")
+
+         async def _yield_deleted():
+             yield run_id
+
+         return _yield_deleted()
+
+     @staticmethod
+     async def join(
+         run_id: UUID,
+         *,
+         thread_id: UUID,
+     ) -> Fragment:
+         """Wait for a run to complete. If already done, return immediately.
+
+         Returns:
+             the final state of the run.
+         """
+         last_chunk: bytes | None = None
+         # wait for the run to complete
+         async for mode, chunk in Runs.Stream.join(
+             run_id, thread_id=thread_id, stream_mode="values"
+         ):
+             if mode == b"values":
+                 last_chunk = chunk
+         # if we received a final chunk, return it
+         if last_chunk is not None:
+             # ie. if the run completed while we were waiting for it
+             return Fragment(last_chunk)
+         else:
+             # otherwise, the run had already finished, so fetch the state from thread
+             async with connect() as conn:
+                 thread_iter = await Threads.get(conn, thread_id)
+                 thread = await fetchone(thread_iter)
+                 return thread["values"]
+
+     @staticmethod
+     async def cancel(
+         conn: InMemConnectionProto,
+         run_ids: Sequence[UUID],
+         *,
+         action: Literal["interrupt", "rollback"] = "interrupt",
+         thread_id: UUID,
+     ) -> None:
+         """Cancel a run."""
+         # Cancellation tries to take two actions, to cover runs in different states:
+         # - For any run, send a cancellation message through the stream manager
+         # - For queued runs not yet picked up by a worker, update their status if
+         #   interrupt, delete if rollback.
+         # - For runs currently being worked on, the worker will handle cancellation
+         # - For runs in any other state, we raise a 404
+         run_ids = [_ensure_uuid(run_id) for run_id in run_ids]
+         thread_id = _ensure_uuid(thread_id)
+
+         stream_manager = get_stream_manager()
+         found_runs = []
+         coros = []
+         for run_id in run_ids:
+             run = next(
+                 (
+                     r
+                     for r in conn.store["runs"]
+                     if r["run_id"] == run_id and r["thread_id"] == thread_id
+                 ),
+                 None,
+             )
+             if run:
+                 found_runs.append(run)
+                 # Send cancellation message through stream manager
+                 control_message = Message(
+                     topic=f"run:{run_id}:control".encode(),
+                     data=action.encode(),
+                 )
+                 queues = stream_manager.get_queues(run_id)
+                 coros.append(stream_manager.put(run_id, control_message))
+
+                 # Update status for pending runs
+                 if run["status"] == "pending":
+                     if queues or action != "rollback":
+                         run["status"] = "interrupted"
+                         run["updated_at"] = datetime.now(tz=UTC)
+                     else:
+                         await logger.ainfo(
+                             "Eagerly deleting unscheduled run with rollback action",
+                             run_id=run_id,
+                             thread_id=thread_id,
+                         )
+                         coros.append(Runs.delete(conn, run_id, thread_id=thread_id))
+                 else:
+                     await logger.awarning(
+                         "Attempted to cancel non-pending run.",
+                         run_id=run_id,
+                         status=run["status"],
+                     )
+         if coros:
+             await asyncio.gather(*coros)
+         if len(found_runs) == len(run_ids):
+             await logger.ainfo(
+                 "Cancelled runs",
+                 run_ids=[str(run_id) for run_id in run_ids],
+                 thread_id=str(thread_id),
+                 action=action,
+             )
+         else:
+             raise HTTPException(status_code=404, detail="Run not found")
+
+     @staticmethod
+     async def search(
+         conn: InMemConnectionProto,
+         thread_id: UUID,
+         *,
+         limit: int = 10,
+         offset: int = 0,
+         metadata: MetadataInput,
+     ) -> AsyncIterator[Run]:
+         """List all runs by thread."""
+         runs = conn.store["runs"]
+         metadata = metadata or {}
+         thread_id = _ensure_uuid(thread_id)
+         filtered_runs = [
+             run
+             for run in runs
+             if run["thread_id"] == thread_id
+             and is_jsonb_contained(run["metadata"], metadata)
+         ]
+         sorted_runs = sorted(filtered_runs, key=lambda x: x["created_at"], reverse=True)
+         sliced_runs = sorted_runs[offset : offset + limit]
+
+         async def _return():
+             for run in sliced_runs:
+                 yield run
+
+         return _return()
+
+     @staticmethod
+     async def set_status(
+         conn: InMemConnectionProto, run_id: UUID, status: RunStatus
+     ) -> Run | None:
+         """Set the status of a run."""
+         # Find the run in the store
+         run_id = _ensure_uuid(run_id)
+         run = next((run for run in conn.store["runs"] if run["run_id"] == run_id), None)
+
+         if run:
+             # Update the status and updated_at timestamp
+             run["status"] = status
+             run["updated_at"] = datetime.now(tz=UTC)
+             return run
+         return None
+
+     class Stream:
+         @staticmethod
+         async def subscribe(
+             run_id: UUID,
+             *,
+             stream_mode: "StreamMode | None" = None,
+         ) -> asyncio.Queue:
+             """Subscribe to the run stream, returning a queue."""
+             run_id = _ensure_uuid(run_id)
+             stream_manager = get_stream_manager()
+             queue = await stream_manager.add_queue(run_id)
+
+             # If there's a control message already stored, send it to the new subscriber
+             if control_messages := stream_manager.control_queues.get(run_id):
+                 for control_msg in control_messages:
+                     await queue.put(control_msg)
+             return queue
+
+         @staticmethod
+         async def join(
+             run_id: UUID,
+             *,
+             thread_id: UUID,
+             ignore_404: bool = False,
+             cancel_on_disconnect: bool = False,
+             stream_mode: "StreamMode | asyncio.Queue | None" = None,
+         ) -> AsyncIterator[tuple[bytes, bytes]]:
+             """Stream the run output."""
+             log = logger.isEnabledFor(logging.DEBUG)
+             queue = (
+                 stream_mode
+                 if isinstance(stream_mode, asyncio.Queue)
+                 else await Runs.Stream.subscribe(run_id)
+             )
+
+             try:
+                 async with connect() as conn:
+                     channel_prefix = f"run:{run_id}:stream:"
+                     len_prefix = len(channel_prefix.encode())
+
+                     while True:
+                         try:
+                             # Wait for messages with a timeout
+                             message = await asyncio.wait_for(queue.get(), timeout=0.5)
+                             topic, data = message.topic, message.data
+
+                             if topic.decode() == f"run:{run_id}:control":
+                                 if data == b"done":
+                                     break
+                             else:
+                                 # Extract mode from topic
+                                 yield topic[len_prefix:], data
+                                 if log:
+                                     await logger.adebug(
+                                         "Streamed run event",
+                                         run_id=str(run_id),
+                                         stream_mode=topic[len_prefix:],
+                                         data=data,
+                                     )
+                         except TimeoutError:
+                             # Check if the run is still pending
+                             run_iter = await Runs.get(conn, run_id, thread_id=thread_id)
+                             run = await anext(run_iter, None)
+
+                             if ignore_404 and run is None:
+                                 break
+                             elif run is None:
+                                 yield (
+                                     b"error",
+                                     HTTPException(
+                                         status_code=404, detail="Run not found"
+                                     ),
+                                 )
+                                 break
+                             elif run["status"] != "pending":
+                                 break
+             except:
+                 if cancel_on_disconnect:
+                     create_task(cancel_run(thread_id, run_id))
+                 raise
+             finally:
+                 stream_manager = get_stream_manager()
+                 await stream_manager.remove_queue(run_id, queue)
+
+         @staticmethod
+         async def publish(
+             run_id: UUID,
+             event: str,
+             message: bytes,
+         ) -> None:
+             """Publish a message to all subscribers of the run stream."""
+             topic = f"run:{run_id}:stream:{event}".encode()
+
+             stream_manager = get_stream_manager()
+             # Send to all queues subscribed to this run_id
+             await stream_manager.put(run_id, Message(topic=topic, data=message))
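
Topics encode both the run and the stream mode (`run:{run_id}:stream:{event}`), so `join` recovers the mode by slicing off the prefix, while the `control` topic is reserved for lifecycle signals. A sketch of the flow between a worker and a client (the `run_id` and `thread_id` variables are assumed to exist):

    # Worker side: push a "values" event to every subscriber of this run.
    await Runs.Stream.publish(run_id, "values", b'{"messages": []}')

    # Client side: consume events until the worker's b"done" control message lands.
    async for mode, data in Runs.Stream.join(run_id, thread_id=thread_id):
        print(mode, data)  # e.g. b"values", b'{"messages": []}'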
+
+
+ async def listen_for_cancellation(
+     queue: asyncio.Queue, run_id: UUID, done: "ValueEvent"
+ ):
+     """Listen for cancellation messages and set the done event accordingly."""
+     stream_manager = get_stream_manager()
+     control_key = f"run:{run_id}:control"
+
+     if existing_queue := stream_manager.control_queues.get(run_id):
+         for message in existing_queue:
+             payload = message.data
+             if payload == b"rollback":
+                 done.set(UserRollback())
+             elif payload == b"interrupt":
+                 done.set(UserInterrupt())
+
+     while not done.is_set():
+         try:
+             # This task gets cancelled when Runs.enter exits anyway,
+             # so we can have a pretty lengthy timeout here
+             message = await asyncio.wait_for(queue.get(), timeout=240)
+             payload = message.data
+             if payload == b"rollback":
+                 done.set(UserRollback())
+             elif payload == b"interrupt":
+                 done.set(UserInterrupt())
+             elif payload == b"done":
+                 done.set()
+                 break
+
+             # Store control messages for late subscribers
+             if message.topic.decode() == control_key:
+                 stream_manager.control_queues[run_id].append(message)
+         except TimeoutError:
+             break
+
+
+ class Crons:
+     @staticmethod
+     async def put(
+         conn: InMemConnectionProto,
+         *,
+         payload: dict,
+         schedule: str,
+         cron_id: UUID | None = None,
+         thread_id: UUID | None = None,
+         user_id: str | None = None,
+         end_time: datetime | None = None,
+     ) -> AsyncIterator[Cron]:
+         raise NotImplementedError
+
+     @staticmethod
+     async def delete(conn: InMemConnectionProto, cron_id: UUID) -> AsyncIterator[UUID]:
+         raise NotImplementedError
+
+     @staticmethod
+     async def next(conn: InMemConnectionProto) -> AsyncIterator[Cron]:
+         raise NotImplementedError
+
+     @staticmethod
+     async def set_next_run_date(
+         conn: InMemConnectionProto, cron_id: UUID, next_run_date: datetime
+     ) -> None:
+         raise NotImplementedError
+
+     @staticmethod
+     async def search(
+         conn: InMemConnectionProto,
+         *,
+         assistant_id: UUID | None,
+         thread_id: UUID | None,
+         limit: int,
+         offset: int,
+     ) -> AsyncIterator[Cron]:
+         raise NotImplementedError
+
+
+ async def cancel_run(thread_id: UUID, run_id: UUID) -> None:
+     async with connect() as conn:
+         await Runs.cancel(conn, [run_id], thread_id=thread_id)
+
+
+ def _delete_checkpoints_for_thread(
+     thread_id: str | UUID,
+     conn: InMemConnectionProto,
+     run_id: str | UUID | None = None,
+ ):
+     checkpointer = Checkpointer(conn)
+     thread_id = str(thread_id)
+     if thread_id not in checkpointer.storage:
+         return
+     if run_id:
+         # Look through metadata
+         run_id = str(run_id)
+         for checkpoint_ns, checkpoints in list(checkpointer.storage[thread_id].items()):
+             for checkpoint_id, (_, metadata_b, _) in list(checkpoints.items()):
+                 metadata = checkpointer.serde.loads_typed(metadata_b)
+                 if metadata.get("run_id") == run_id:
+                     del checkpointer.storage[thread_id][checkpoint_ns][checkpoint_id]
+             if not checkpointer.storage[thread_id][checkpoint_ns]:
+                 del checkpointer.storage[thread_id][checkpoint_ns]
+     else:
+         del checkpointer.storage[thread_id]
+         # Keys are (thread_id, checkpoint_ns, checkpoint_id)
+         checkpointer.writes = defaultdict(
+             dict, {k: v for k, v in checkpointer.writes.items() if k[0] != thread_id}
+         )
+
+
+ __all__ = [
+     "Assistants",
+     "Crons",
+     "Runs",
+     "Threads",
+ ]
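
Taken together, these classes mirror the Postgres-backed semantics (containment filters, JSONB merges, cascading deletes) over plain list-of-dict tables. A hedged end-to-end sketch in this module's namespace (graph id and payload invented; every method returns an async iterator that must be drained):

    from uuid import uuid4

    from langgraph_storage.database import connect

    async def demo() -> None:
        async with connect() as conn:
            assistants = await Assistants.put(
                conn, uuid4(), graph_id="agent", config={}, metadata={},
                if_exists="raise", name="demo",
            )
            assistant = await anext(assistants)
            runs = await Runs.put(
                conn, assistant["assistant_id"], {"config": {}}, run_id=uuid4(),
                metadata={}, prevent_insert_if_inflight=True,
            )
            run = await anext(runs)  # the newly queued run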