langgraph-api 0.0.48__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langgraph-api might be problematic. Click here for more details.

Files changed (50)
  1. langgraph_api/__init__.py +1 -1
  2. langgraph_api/api/__init__.py +2 -2
  3. langgraph_api/api/assistants.py +3 -3
  4. langgraph_api/api/meta.py +9 -11
  5. langgraph_api/api/runs.py +3 -3
  6. langgraph_api/api/store.py +2 -2
  7. langgraph_api/api/threads.py +3 -3
  8. langgraph_api/auth/custom.py +25 -4
  9. langgraph_api/cli.py +3 -1
  10. langgraph_api/config.py +3 -0
  11. langgraph_api/cron_scheduler.py +3 -3
  12. langgraph_api/graph.py +6 -14
  13. langgraph_api/js/base.py +17 -0
  14. langgraph_api/js/build.mts +3 -3
  15. langgraph_api/js/client.mts +64 -3
  16. langgraph_api/js/global.d.ts +1 -0
  17. langgraph_api/js/package.json +4 -3
  18. langgraph_api/js/remote.py +96 -5
  19. langgraph_api/js/src/graph.mts +0 -6
  20. langgraph_api/js/src/utils/files.mts +4 -0
  21. langgraph_api/js/tests/api.test.mts +80 -80
  22. langgraph_api/js/tests/auth.test.mts +648 -0
  23. langgraph_api/js/tests/compose-postgres.auth.yml +59 -0
  24. langgraph_api/js/tests/graphs/agent_simple.mts +79 -0
  25. langgraph_api/js/tests/graphs/auth.mts +106 -0
  26. langgraph_api/js/tests/graphs/package.json +3 -1
  27. langgraph_api/js/tests/graphs/yarn.lock +9 -4
  28. langgraph_api/js/yarn.lock +18 -23
  29. langgraph_api/metadata.py +7 -0
  30. langgraph_api/models/run.py +10 -1
  31. langgraph_api/queue_entrypoint.py +1 -1
  32. langgraph_api/server.py +2 -2
  33. langgraph_api/stream.py +5 -4
  34. langgraph_api/thread_ttl.py +2 -2
  35. langgraph_api/worker.py +4 -25
  36. {langgraph_api-0.0.48.dist-info → langgraph_api-0.1.2.dist-info}/METADATA +1 -2
  37. {langgraph_api-0.0.48.dist-info → langgraph_api-0.1.2.dist-info}/RECORD +42 -44
  38. langgraph_runtime/__init__.py +39 -0
  39. langgraph_api/lifespan.py +0 -74
  40. langgraph_storage/checkpoint.py +0 -123
  41. langgraph_storage/database.py +0 -200
  42. langgraph_storage/inmem_stream.py +0 -109
  43. langgraph_storage/ops.py +0 -2172
  44. langgraph_storage/queue.py +0 -186
  45. langgraph_storage/retry.py +0 -31
  46. langgraph_storage/store.py +0 -100
  47. {langgraph_storage → langgraph_api/js}/__init__.py +0 -0
  48. {langgraph_api-0.0.48.dist-info → langgraph_api-0.1.2.dist-info}/LICENSE +0 -0
  49. {langgraph_api-0.0.48.dist-info → langgraph_api-0.1.2.dist-info}/WHEEL +0 -0
  50. {langgraph_api-0.0.48.dist-info → langgraph_api-0.1.2.dist-info}/entry_points.txt +0 -0
langgraph_storage/ops.py DELETED
@@ -1,2172 +0,0 @@
1
- """Implementation of the LangGraph API using in-memory checkpointer & store."""
2
-
3
- import asyncio
4
- import base64
5
- import copy
6
- import json
7
- import uuid
8
- from collections import defaultdict
9
- from collections.abc import AsyncIterator, Sequence
10
- from contextlib import asynccontextmanager
11
- from copy import deepcopy
12
- from datetime import UTC, datetime, timedelta
13
- from typing import Any, Literal, cast
14
- from uuid import UUID, uuid4
15
-
16
- import structlog
17
- from langgraph.checkpoint.serde.jsonplus import _msgpack_ext_hook_to_json
18
- from langgraph.pregel.debug import CheckpointPayload
19
- from langgraph.pregel.types import StateSnapshot
20
- from langgraph_sdk import Auth
21
- from starlette.exceptions import HTTPException
22
-
23
- from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent, create_task
24
- from langgraph_api.auth.custom import handle_event
25
- from langgraph_api.command import map_cmd
26
- from langgraph_api.config import ThreadTTLConfig
27
- from langgraph_api.errors import UserInterrupt, UserRollback
28
- from langgraph_api.graph import get_graph
29
- from langgraph_api.schema import (
30
- Assistant,
31
- Checkpoint,
32
- Config,
33
- Cron,
34
- IfNotExists,
35
- MetadataInput,
36
- MetadataValue,
37
- MultitaskStrategy,
38
- OnConflictBehavior,
39
- QueueStats,
40
- Run,
41
- RunStatus,
42
- StreamMode,
43
- Thread,
44
- ThreadStatus,
45
- ThreadUpdateResponse,
46
- )
47
- from langgraph_api.serde import Fragment
48
- from langgraph_api.utils import fetchone, get_auth_ctx
49
- from langgraph_storage.checkpoint import Checkpointer
50
- from langgraph_storage.database import InMemConnectionProto, connect
51
- from langgraph_storage.inmem_stream import Message, get_stream_manager
52
-
53
- logger = structlog.stdlib.get_logger(__name__)
54
-
55
-
56
- def _ensure_uuid(id_: str | uuid.UUID | None) -> uuid.UUID:
57
- if isinstance(id_, str):
58
- return uuid.UUID(id_)
59
- if id_ is None:
60
- return uuid4()
61
- return id_
62
-
63
-
64
- class WrappedHTTPException(Exception):
65
- def __init__(self, http_exception: HTTPException):
66
- self.http_exception = http_exception
67
-
68
-
69
- # Right now the whole API types as UUID but frequently passes a str
70
- # We ensure UUIDs for eveerything EXCEPT the checkpoint storage/writes,
71
- # which we leave as strings. This is because I'm too lazy to subclass fully
72
- # and we use non-UUID examples in the OSS version
73
-
74
-
75
- class Authenticated:
76
- resource: Literal["threads", "crons", "assistants"]
77
-
78
- @classmethod
79
- def _context(
80
- cls,
81
- ctx: Auth.types.BaseAuthContext | None,
82
- action: Literal["create", "read", "update", "delete", "create_run"],
83
- ) -> Auth.types.AuthContext | None:
84
- if not ctx:
85
- return
86
- return Auth.types.AuthContext(
87
- user=ctx.user,
88
- permissions=ctx.permissions,
89
- resource=cls.resource,
90
- action=action,
91
- )
92
-
93
- @classmethod
94
- async def handle_event(
95
- cls,
96
- ctx: Auth.types.BaseAuthContext | None,
97
- action: Literal["create", "read", "update", "delete", "search", "create_run"],
98
- value: Any,
99
- ) -> Auth.types.FilterType | None:
100
- ctx = ctx or get_auth_ctx()
101
- if not ctx:
102
- return
103
- return await handle_event(cls._context(ctx, action), value)
104
-
105
-
106
- class Assistants(Authenticated):
107
- resource = "assistants"
108
-
109
- @staticmethod
110
- async def search(
111
- conn: InMemConnectionProto,
112
- *,
113
- graph_id: str | None,
114
- metadata: MetadataInput,
115
- limit: int,
116
- offset: int,
117
- ctx: Auth.types.BaseAuthContext | None = None,
118
- ) -> AsyncIterator[Assistant]:
119
- metadata = metadata if metadata is not None else {}
120
- filters = await Assistants.handle_event(
121
- ctx,
122
- "search",
123
- Auth.types.AssistantsSearch(
124
- graph_id=graph_id, metadata=metadata, limit=limit, offset=offset
125
- ),
126
- )
127
-
128
- async def filter_and_yield() -> AsyncIterator[Assistant]:
129
- assistants = conn.store["assistants"]
130
- filtered_assistants = [
131
- assistant
132
- for assistant in assistants
133
- if (not graph_id or assistant["graph_id"] == graph_id)
134
- and (
135
- not metadata or is_jsonb_contained(assistant["metadata"], metadata)
136
- )
137
- and (not filters or _check_filter_match(assistant["metadata"], filters))
138
- ]
139
- filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
140
- for assistant in filtered_assistants[offset : offset + limit]:
141
- yield assistant
142
-
143
- return filter_and_yield()
144
-
145
- @staticmethod
146
- async def get(
147
- conn: InMemConnectionProto,
148
- assistant_id: UUID,
149
- ctx: Auth.types.BaseAuthContext | None = None,
150
- ) -> AsyncIterator[Assistant]:
151
- """Get an assistant by ID."""
152
- assistant_id = _ensure_uuid(assistant_id)
153
- filters = await Assistants.handle_event(
154
- ctx,
155
- "read",
156
- Auth.types.AssistantsRead(assistant_id=assistant_id),
157
- )
158
-
159
- async def _yield_result():
160
- for assistant in conn.store["assistants"]:
161
- if assistant["assistant_id"] == assistant_id and (
162
- not filters or _check_filter_match(assistant["metadata"], filters)
163
- ):
164
- yield assistant
165
-
166
- return _yield_result()
167
-
168
- @staticmethod
169
- async def put(
170
- conn: InMemConnectionProto,
171
- assistant_id: UUID,
172
- *,
173
- graph_id: str,
174
- config: Config,
175
- metadata: MetadataInput,
176
- if_exists: OnConflictBehavior,
177
- name: str,
178
- ctx: Auth.types.BaseAuthContext | None = None,
179
- description: str | None = None,
180
- ) -> AsyncIterator[Assistant]:
181
- """Insert an assistant."""
182
- assistant_id = _ensure_uuid(assistant_id)
183
- metadata = metadata if metadata is not None else {}
184
- filters = await Assistants.handle_event(
185
- ctx,
186
- "create",
187
- Auth.types.AssistantsCreate(
188
- assistant_id=assistant_id,
189
- graph_id=graph_id,
190
- config=config,
191
- metadata=metadata,
192
- name=name,
193
- ),
194
- )
195
- existing_assistant = next(
196
- (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
197
- None,
198
- )
199
- if existing_assistant:
200
- if filters and not _check_filter_match(
201
- existing_assistant["metadata"], filters
202
- ):
203
- raise HTTPException(
204
- status_code=409, detail=f"Assistant {assistant_id} already exists"
205
- )
206
- if if_exists == "raise":
207
- raise HTTPException(
208
- status_code=409, detail=f"Assistant {assistant_id} already exists"
209
- )
210
- elif if_exists == "do_nothing":
211
-
212
- async def _yield_existing():
213
- yield existing_assistant
214
-
215
- return _yield_existing()
216
-
217
- now = datetime.now(UTC)
218
- new_assistant: Assistant = {
219
- "assistant_id": assistant_id,
220
- "graph_id": graph_id,
221
- "config": config or {},
222
- "metadata": metadata or {},
223
- "name": name,
224
- "created_at": now,
225
- "updated_at": now,
226
- "version": 1,
227
- "description": description,
228
- }
229
- new_version = {
230
- "assistant_id": assistant_id,
231
- "version": 1,
232
- "graph_id": graph_id,
233
- "config": config or {},
234
- "metadata": metadata or {},
235
- "created_at": now,
236
- "name": name,
237
- }
238
- conn.store["assistants"].append(new_assistant)
239
- conn.store["assistant_versions"].append(new_version)
240
-
241
- async def _yield_new():
242
- yield new_assistant
243
-
244
- return _yield_new()
245
-
246
- @staticmethod
247
- async def patch(
248
- conn: InMemConnectionProto,
249
- assistant_id: UUID,
250
- *,
251
- config: dict | None = None,
252
- graph_id: str | None = None,
253
- metadata: MetadataInput | None = None,
254
- name: str | None = None,
255
- description: str | None = None,
256
- ctx: Auth.types.BaseAuthContext | None = None,
257
- ) -> AsyncIterator[Assistant]:
258
- """Update an assistant.
259
-
260
- Args:
261
- conn: The connection to the in-memory store.
262
- assistant_id: The assistant ID.
263
- graph_id: The graph ID.
264
- config: The assistant config.
265
- metadata: The assistant metadata.
266
- name: The assistant name.
267
- description: The assistant description.
268
- ctx: The auth context.
269
-
270
- Returns:
271
- return the updated assistant model.
272
- """
273
- assistant_id = _ensure_uuid(assistant_id)
274
- metadata = metadata if metadata is not None else {}
275
- filters = await Assistants.handle_event(
276
- ctx,
277
- "update",
278
- Auth.types.AssistantsUpdate(
279
- assistant_id=assistant_id,
280
- graph_id=graph_id,
281
- config=config,
282
- metadata=metadata,
283
- name=name,
284
- ),
285
- )
286
- assistant = next(
287
- (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
288
- None,
289
- )
290
- if not assistant:
291
- raise HTTPException(
292
- status_code=404, detail=f"Assistant {assistant_id} not found"
293
- )
294
- elif filters and not _check_filter_match(assistant["metadata"], filters):
295
- raise HTTPException(
296
- status_code=404, detail=f"Assistant {assistant_id} not found"
297
- )
298
-
299
- now = datetime.now(UTC)
300
- new_version = (
301
- max(
302
- v["version"]
303
- for v in conn.store["assistant_versions"]
304
- if v["assistant_id"] == assistant_id
305
- )
306
- + 1
307
- if conn.store["assistant_versions"]
308
- else 1
309
- )
310
-
311
- # Update assistant_versions table
312
- if metadata:
313
- metadata = {
314
- **assistant["metadata"],
315
- **metadata,
316
- }
317
- new_version_entry = {
318
- "assistant_id": assistant_id,
319
- "version": new_version,
320
- "graph_id": graph_id if graph_id is not None else assistant["graph_id"],
321
- "config": config if config is not None else assistant["config"],
322
- "metadata": metadata if metadata is not None else assistant["metadata"],
323
- "created_at": now,
324
- "name": name if name is not None else assistant["name"],
325
- "description": description
326
- if description is not None
327
- else assistant["description"],
328
- }
329
- conn.store["assistant_versions"].append(new_version_entry)
330
-
331
- # Update assistants table
332
- assistant.update(
333
- {
334
- "graph_id": new_version_entry["graph_id"],
335
- "config": new_version_entry["config"],
336
- "metadata": new_version_entry["metadata"],
337
- "name": name if name is not None else assistant["name"],
338
- "description": description
339
- if description is not None
340
- else assistant["description"],
341
- "updated_at": now,
342
- "version": new_version,
343
- }
344
- )
345
-
346
- async def _yield_updated():
347
- yield assistant
348
-
349
- return _yield_updated()
350
-
351
- @staticmethod
352
- async def delete(
353
- conn: InMemConnectionProto,
354
- assistant_id: UUID,
355
- ctx: Auth.types.BaseAuthContext | None = None,
356
- ) -> AsyncIterator[UUID]:
357
- """Delete an assistant by ID."""
358
- assistant_id = _ensure_uuid(assistant_id)
359
- filters = await Assistants.handle_event(
360
- ctx,
361
- "delete",
362
- Auth.types.AssistantsDelete(
363
- assistant_id=assistant_id,
364
- ),
365
- )
366
- assistant = next(
367
- (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
368
- None,
369
- )
370
-
371
- if not assistant:
372
- raise HTTPException(
373
- status_code=404, detail=f"Assistant with ID {assistant_id} not found"
374
- )
375
- elif filters and not _check_filter_match(assistant["metadata"], filters):
376
- raise HTTPException(
377
- status_code=404, detail=f"Assistant with ID {assistant_id} not found"
378
- )
379
-
380
- conn.store["assistants"] = [
381
- a for a in conn.store["assistants"] if a["assistant_id"] != assistant_id
382
- ]
383
- # Cascade delete assistant versions, crons, & runs on this assistant
384
- conn.store["assistant_versions"] = [
385
- v
386
- for v in conn.store["assistant_versions"]
387
- if v["assistant_id"] != assistant_id
388
- ]
389
- retained = []
390
- for run in conn.store["runs"]:
391
- if run["assistant_id"] == assistant_id:
392
- res = await Runs.delete(
393
- conn, run["run_id"], thread_id=run["thread_id"], ctx=ctx
394
- )
395
- await anext(res)
396
- else:
397
- retained.append(run)
398
-
399
- async def _yield_deleted():
400
- yield assistant_id
401
-
402
- return _yield_deleted()
403
-
404
- @staticmethod
405
- async def set_latest(
406
- conn: InMemConnectionProto,
407
- assistant_id: UUID,
408
- version: int,
409
- ctx: Auth.types.BaseAuthContext | None = None,
410
- ) -> AsyncIterator[Assistant]:
411
- """Change the version of an assistant."""
412
- assistant_id = _ensure_uuid(assistant_id)
413
- filters = await Assistants.handle_event(
414
- ctx,
415
- "update",
416
- Auth.types.AssistantsUpdate(
417
- assistant_id=assistant_id,
418
- version=version,
419
- ),
420
- )
421
- assistant = next(
422
- (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
423
- None,
424
- )
425
- if not assistant:
426
- raise HTTPException(
427
- status_code=404, detail=f"Assistant {assistant_id} not found"
428
- )
429
- elif filters and not _check_filter_match(assistant["metadata"], filters):
430
- raise HTTPException(
431
- status_code=404, detail=f"Assistant {assistant_id} not found"
432
- )
433
-
434
- version_data = next(
435
- (
436
- v
437
- for v in conn.store["assistant_versions"]
438
- if v["assistant_id"] == assistant_id and v["version"] == version
439
- ),
440
- None,
441
- )
442
- if not version_data:
443
- raise HTTPException(
444
- status_code=404,
445
- detail=f"Version {version} not found for assistant {assistant_id}",
446
- )
447
-
448
- assistant.update(
449
- {
450
- "config": version_data["config"],
451
- "metadata": version_data["metadata"],
452
- "version": version_data["version"],
453
- "updated_at": datetime.now(UTC),
454
- }
455
- )
456
-
457
- async def _yield_updated():
458
- yield assistant
459
-
460
- return _yield_updated()
461
-
462
- @staticmethod
463
- async def get_versions(
464
- conn: InMemConnectionProto,
465
- assistant_id: UUID,
466
- metadata: MetadataInput,
467
- limit: int,
468
- offset: int,
469
- ctx: Auth.types.BaseAuthContext | None = None,
470
- ) -> AsyncIterator[Assistant]:
471
- """Get all versions of an assistant."""
472
- assistant_id = _ensure_uuid(assistant_id)
473
- filters = await Assistants.handle_event(
474
- ctx,
475
- "read",
476
- Auth.types.AssistantsRead(assistant_id=assistant_id),
477
- )
478
- assistant = next(
479
- (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
480
- None,
481
- )
482
- if not assistant:
483
- raise HTTPException(
484
- status_code=404, detail=f"Assistant {assistant_id} not found"
485
- )
486
- versions = [
487
- v
488
- for v in conn.store["assistant_versions"]
489
- if v["assistant_id"] == assistant_id
490
- and (not metadata or is_jsonb_contained(v["metadata"], metadata))
491
- and (not filters or _check_filter_match(v["metadata"], filters))
492
- ]
493
-
494
- # Previously, the name was not included in the assistant_versions table. So we should add it here.
495
- for v in versions:
496
- if "name" not in v:
497
- v["name"] = assistant["name"]
498
-
499
- versions.sort(key=lambda x: x["version"], reverse=True)
500
-
501
- async def _yield_versions():
502
- for version in versions[offset : offset + limit]:
503
- yield version
504
-
505
- return _yield_versions()
506
-
507
-
508
- def is_jsonb_contained(superset: dict[str, Any], subset: dict[str, Any]) -> bool:
509
- """
510
- Implements Postgres' @> (containment) operator for dictionaries.
511
- Returns True if superset contains all key/value pairs from subset.
512
- """
513
- for key, value in subset.items():
514
- if key not in superset:
515
- return False
516
- if isinstance(value, dict) and isinstance(superset[key], dict):
517
- if not is_jsonb_contained(superset[key], value):
518
- return False
519
- elif superset[key] != value:
520
- return False
521
- return True
522
-
523
-
524
- def bytes_decoder(obj):
525
- """Custom JSON decoder that converts base64 back to bytes."""
526
- if "__type__" in obj and obj["__type__"] == "bytes":
527
- return base64.b64decode(obj["value"].encode("utf-8"))
528
- return obj
529
-
530
-
531
- def _replace_thread_id(data, new_thread_id, thread_id):
532
- class BytesEncoder(json.JSONEncoder):
533
- """Custom JSON encoder that handles bytes by converting them to base64."""
534
-
535
- def default(self, obj):
536
- if isinstance(obj, bytes | bytearray):
537
- return {
538
- "__type__": "bytes",
539
- "value": base64.b64encode(
540
- obj.replace(
541
- str(thread_id).encode(), str(new_thread_id).encode()
542
- )
543
- ).decode("utf-8"),
544
- }
545
-
546
- return super().default(obj)
547
-
548
- try:
549
- json_str = json.dumps(data, cls=BytesEncoder, indent=2)
550
- except Exception as e:
551
- raise ValueError(data) from e
552
- json_str = json_str.replace(str(thread_id), str(new_thread_id))
553
-
554
- # Decoding back from JSON
555
- d = json.loads(json_str, object_hook=bytes_decoder)
556
- return d
557
-
558
-
559
- class Threads(Authenticated):
560
- resource = "threads"
561
-
562
- @staticmethod
563
- async def search(
564
- conn: InMemConnectionProto,
565
- *,
566
- metadata: MetadataInput,
567
- values: MetadataInput,
568
- status: ThreadStatus | None,
569
- limit: int,
570
- offset: int,
571
- ctx: Auth.types.BaseAuthContext | None = None,
572
- ) -> AsyncIterator[Thread]:
573
- threads = conn.store["threads"]
574
- filtered_threads: list[Thread] = []
575
- metadata = metadata if metadata is not None else {}
576
- values = values if values is not None else {}
577
- filters = await Threads.handle_event(
578
- ctx,
579
- "search",
580
- Auth.types.ThreadsSearch(
581
- metadata=metadata,
582
- values=values,
583
- status=status,
584
- limit=limit,
585
- offset=offset,
586
- ),
587
- )
588
-
589
- # Apply filters
590
- for thread in threads:
591
- if filters and not _check_filter_match(thread["metadata"], filters):
592
- continue
593
-
594
- if metadata and not is_jsonb_contained(thread["metadata"], metadata):
595
- continue
596
-
597
- if (
598
- values
599
- and "values" in thread
600
- and not is_jsonb_contained(thread["values"], values)
601
- ):
602
- continue
603
-
604
- if status and thread.get("status") != status:
605
- continue
606
-
607
- filtered_threads.append(thread)
608
-
609
- # Sort by created_at in descending order
610
- sorted_threads = sorted(
611
- filtered_threads, key=lambda x: x["created_at"], reverse=True
612
- )
613
-
614
- # Apply limit and offset
615
- paginated_threads = sorted_threads[offset : offset + limit]
616
-
617
- async def thread_iterator() -> AsyncIterator[Thread]:
618
- for thread in paginated_threads:
619
- yield thread
620
-
621
- return thread_iterator()
622
-
623
- @staticmethod
624
- async def _get_with_filters(
625
- conn: InMemConnectionProto,
626
- thread_id: UUID,
627
- filters: Auth.types.FilterType | None,
628
- ) -> Thread | None:
629
- thread_id = _ensure_uuid(thread_id)
630
- matching_thread = next(
631
- (
632
- thread
633
- for thread in conn.store["threads"]
634
- if thread["thread_id"] == thread_id
635
- ),
636
- None,
637
- )
638
- if not matching_thread or (
639
- filters and not _check_filter_match(matching_thread["metadata"], filters)
640
- ):
641
- return
642
-
643
- return matching_thread
644
-
645
- @staticmethod
646
- async def _get(
647
- conn: InMemConnectionProto,
648
- thread_id: UUID,
649
- ctx: Auth.types.BaseAuthContext | None = None,
650
- ) -> Thread | None:
651
- """Get a thread by ID."""
652
- thread_id = _ensure_uuid(thread_id)
653
- filters = await Threads.handle_event(
654
- ctx,
655
- "read",
656
- Auth.types.ThreadsRead(thread_id=thread_id),
657
- )
658
- return await Threads._get_with_filters(conn, thread_id, filters)
659
-
660
- @staticmethod
661
- async def get(
662
- conn: InMemConnectionProto,
663
- thread_id: UUID,
664
- ctx: Auth.types.BaseAuthContext | None = None,
665
- ) -> AsyncIterator[Thread]:
666
- """Get a thread by ID."""
667
- matching_thread = await Threads._get(conn, thread_id, ctx)
668
-
669
- if not matching_thread:
670
- raise HTTPException(
671
- status_code=404, detail=f"Thread with ID {thread_id} not found"
672
- )
673
-
674
- async def _yield_result():
675
- if matching_thread:
676
- yield matching_thread
677
-
678
- return _yield_result()
679
-
680
- @staticmethod
681
- async def put(
682
- conn: InMemConnectionProto,
683
- thread_id: UUID,
684
- *,
685
- metadata: MetadataInput,
686
- if_exists: OnConflictBehavior,
687
- ttl: ThreadTTLConfig | None = None,
688
- ctx: Auth.types.BaseAuthContext | None = None,
689
- ) -> AsyncIterator[Thread]:
690
- """Insert or update a thread."""
691
- thread_id = _ensure_uuid(thread_id)
692
- if metadata is None:
693
- metadata = {}
694
-
695
- # Check if thread already exists
696
- existing_thread = next(
697
- (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
698
- )
699
- filters = await Threads.handle_event(
700
- ctx,
701
- "create",
702
- Auth.types.ThreadsCreate(
703
- thread_id=thread_id, metadata=metadata, if_exists=if_exists
704
- ),
705
- )
706
-
707
- if existing_thread:
708
- if filters and not _check_filter_match(
709
- existing_thread["metadata"], filters
710
- ):
711
- # Should we use a different status code here?
712
- raise HTTPException(
713
- status_code=409, detail=f"Thread with ID {thread_id} already exists"
714
- )
715
- if if_exists == "raise":
716
- raise HTTPException(
717
- status_code=409, detail=f"Thread with ID {thread_id} already exists"
718
- )
719
- elif if_exists == "do_nothing":
720
-
721
- async def _yield_existing():
722
- yield existing_thread
723
-
724
- return _yield_existing()
725
- # Create new thread
726
- new_thread: Thread = {
727
- "thread_id": thread_id,
728
- "created_at": datetime.now(UTC),
729
- "updated_at": datetime.now(UTC),
730
- "metadata": copy.deepcopy(metadata),
731
- "status": "idle",
732
- "config": {},
733
- "values": None,
734
- }
735
-
736
- # Add to store
737
- conn.store["threads"].append(new_thread)
738
-
739
- async def _yield_new():
740
- yield new_thread
741
-
742
- return _yield_new()
743
-
744
- @staticmethod
745
- async def patch(
746
- conn: InMemConnectionProto,
747
- thread_id: UUID,
748
- *,
749
- metadata: MetadataValue,
750
- ctx: Auth.types.BaseAuthContext | None = None,
751
- ) -> AsyncIterator[Thread]:
752
- """Update a thread."""
753
- thread_list = conn.store["threads"]
754
- thread_idx = None
755
- thread_id = _ensure_uuid(thread_id)
756
-
757
- for idx, thread in enumerate(thread_list):
758
- if thread["thread_id"] == thread_id:
759
- thread_idx = idx
760
- break
761
-
762
- if thread_idx is not None:
763
- filters = await Threads.handle_event(
764
- ctx,
765
- "update",
766
- Auth.types.ThreadsUpdate(thread_id=thread_id, metadata=metadata),
767
- )
768
- if not filters or _check_filter_match(
769
- thread_list[thread_idx]["metadata"], filters
770
- ):
771
- thread = copy.deepcopy(thread_list[thread_idx])
772
- thread["metadata"] = {**thread["metadata"], **metadata}
773
- thread["updated_at"] = datetime.now(UTC)
774
- thread_list[thread_idx] = thread
775
-
776
- async def thread_iterator() -> AsyncIterator[Thread]:
777
- yield thread
778
-
779
- return thread_iterator()
780
-
781
- async def empty_iterator() -> AsyncIterator[Thread]:
782
- if False: # This ensures the iterator is empty
783
- yield
784
-
785
- return empty_iterator()
786
-
787
- @staticmethod
788
- async def set_status(
789
- conn: InMemConnectionProto,
790
- thread_id: UUID,
791
- checkpoint: CheckpointPayload | None,
792
- exception: BaseException | None,
793
- # This does not accept the auth context since it's only used internally
794
- ) -> None:
795
- """Set the status of a thread."""
796
- thread_id = _ensure_uuid(thread_id)
797
-
798
- async def has_pending_runs(conn_: InMemConnectionProto, tid: UUID) -> bool:
799
- """Check if thread has any pending runs."""
800
- return any(
801
- run["status"] in ("pending", "running") and run["thread_id"] == tid
802
- for run in conn_.store["runs"]
803
- )
804
-
805
- # Find the thread
806
- thread = next(
807
- (
808
- thread
809
- for thread in conn.store["threads"]
810
- if thread["thread_id"] == thread_id
811
- ),
812
- None,
813
- )
814
-
815
- if not thread:
816
- raise HTTPException(
817
- status_code=404, detail=f"Thread {thread_id} not found."
818
- )
819
-
820
- # Determine has_next from checkpoint
821
- has_next = False if checkpoint is None else bool(checkpoint["next"])
822
-
823
- # Determine base status
824
- if exception:
825
- status = "error"
826
- elif has_next:
827
- status = "interrupted"
828
- else:
829
- status = "idle"
830
-
831
- # Check for pending runs and update to busy if found
832
- if await has_pending_runs(conn, thread_id):
833
- status = "busy"
834
-
835
- # Update thread
836
- thread.update(
837
- {
838
- "updated_at": datetime.now(UTC),
839
- "values": checkpoint["values"] if checkpoint else None,
840
- "status": status,
841
- "interrupts": (
842
- {
843
- t["id"]: t["interrupts"]
844
- for t in checkpoint["tasks"]
845
- if t.get("interrupts")
846
- }
847
- if checkpoint
848
- else {}
849
- ),
850
- }
851
- )
852
-
853
- @staticmethod
854
- async def delete(
855
- conn: InMemConnectionProto,
856
- thread_id: UUID,
857
- ctx: Auth.types.BaseAuthContext | None = None,
858
- ) -> AsyncIterator[UUID]:
859
- """Delete a thread by ID and cascade delete all associated runs."""
860
- thread_list = conn.store["threads"]
861
- thread_idx = None
862
- thread_id = _ensure_uuid(thread_id)
863
-
864
- # Find the thread to delete
865
- for idx, thread in enumerate(thread_list):
866
- if thread["thread_id"] == thread_id:
867
- thread_idx = idx
868
- break
869
- filters = await Threads.handle_event(
870
- ctx,
871
- "delete",
872
- Auth.types.ThreadsDelete(thread_id=thread_id),
873
- )
874
- if (filters and not _check_filter_match(thread["metadata"], filters)) or (
875
- thread_idx is None
876
- ):
877
- raise HTTPException(
878
- status_code=404, detail=f"Thread with ID {thread_id} not found"
879
- )
880
- # Cascade delete all runs associated with this thread
881
- conn.store["runs"] = [
882
- run for run in conn.store["runs"] if run["thread_id"] != thread_id
883
- ]
884
- _delete_checkpoints_for_thread(thread_id, conn)
885
-
886
- if thread_idx is not None:
887
- # Remove the thread from the store
888
- deleted_thread = thread_list.pop(thread_idx)
889
-
890
- # Return an async iterator with the deleted thread_id
891
- async def id_iterator() -> AsyncIterator[UUID]:
892
- yield deleted_thread["thread_id"]
893
-
894
- return id_iterator()
895
-
896
- # If thread not found, return empty iterator
897
- async def empty_iterator() -> AsyncIterator[UUID]:
898
- if False: # This ensures the iterator is empty
899
- yield
900
-
901
- return empty_iterator()
902
-
903
- @staticmethod
904
- async def copy(
905
- conn: InMemConnectionProto,
906
- thread_id: UUID,
907
- ctx: Auth.types.BaseAuthContext | None = None,
908
- ) -> AsyncIterator[Thread]:
909
- """Create a copy of an existing thread."""
910
- thread_id = _ensure_uuid(thread_id)
911
- new_thread_id = uuid4()
912
- filters = await Threads.handle_event(
913
- ctx,
914
- "read",
915
- Auth.types.ThreadsRead(
916
- thread_id=new_thread_id,
917
- ),
918
- )
919
- async with conn.pipeline():
920
- # Find the original thread in our store
921
- original_thread = next(
922
- (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
923
- )
924
-
925
- if not original_thread:
926
- return _empty_generator()
927
- if filters and not _check_filter_match(
928
- original_thread["metadata"], filters
929
- ):
930
- return _empty_generator()
931
-
932
- # Create new thread with copied metadata
933
- new_thread: Thread = {
934
- "thread_id": new_thread_id,
935
- "created_at": datetime.now(tz=UTC),
936
- "updated_at": datetime.now(tz=UTC),
937
- "metadata": deepcopy(original_thread["metadata"]),
938
- "status": "idle",
939
- "config": {},
940
- }
941
-
942
- # Add new thread to store
943
- conn.store["threads"].append(new_thread)
944
-
945
- checkpointer = Checkpointer(conn)
946
- copied_storage = _replace_thread_id(
947
- checkpointer.storage[str(thread_id)], new_thread_id, thread_id
948
- )
949
- checkpointer.storage[str(new_thread_id)] = copied_storage
950
- # Copy the writes over (if any)
951
- outer_keys = []
952
- for k in checkpointer.writes:
953
- if k[0] == str(thread_id):
954
- outer_keys.append(k)
955
- for tid, checkpoint_ns, checkpoint_id in outer_keys:
956
- mapped = {
957
- k: _replace_thread_id(v, new_thread_id, thread_id)
958
- for k, v in checkpointer.writes[
959
- (str(tid), checkpoint_ns, checkpoint_id)
960
- ].items()
961
- }
962
-
963
- checkpointer.writes[
964
- (str(new_thread_id), checkpoint_ns, checkpoint_id)
965
- ] = mapped
966
- # Copy the blobs
967
- for k in list(checkpointer.blobs):
968
- if str(k[0]) == str(thread_id):
969
- new_key = (str(new_thread_id), *k[1:])
970
- checkpointer.blobs[new_key] = checkpointer.blobs[k]
971
-
972
- async def row_generator() -> AsyncIterator[Thread]:
973
- yield new_thread
974
-
975
- return row_generator()
976
-
977
- @staticmethod
978
- async def sweep_ttl(
979
- conn: InMemConnectionProto,
980
- *,
981
- limit: int | None = None,
982
- batch_size: int = 100,
983
- ) -> tuple[int, int]:
984
- # Not implemented for inmem server
985
- return (0, 0)
986
-
987
- class State(Authenticated):
988
- # We will treat this like a runs resource for now.
989
- resource = "threads"
990
-
991
- @staticmethod
992
- async def get(
993
- conn: InMemConnectionProto,
994
- config: Config,
995
- subgraphs: bool = False,
996
- ctx: Auth.types.BaseAuthContext | None = None,
997
- ) -> StateSnapshot:
998
- """Get state for a thread."""
999
- checkpointer = await asyncio.to_thread(
1000
- Checkpointer, conn, unpack_hook=_msgpack_ext_hook_to_json
1001
- )
1002
- thread_id = _ensure_uuid(config["configurable"]["thread_id"])
1003
- # Auth will be applied here so no need to use filters downstream
1004
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1005
- thread = await anext(thread_iter)
1006
- checkpoint = await checkpointer.aget(config)
1007
-
1008
- if not thread:
1009
- return StateSnapshot(
1010
- values={},
1011
- next=[],
1012
- config=None,
1013
- metadata=None,
1014
- created_at=None,
1015
- parent_config=None,
1016
- tasks=tuple(),
1017
- )
1018
-
1019
- metadata = thread.get("metadata", {})
1020
- thread_config = thread.get("config", {})
1021
-
1022
- if graph_id := metadata.get("graph_id"):
1023
- # format latest checkpoint for response
1024
- checkpointer.latest_iter = checkpoint
1025
- async with get_graph(
1026
- graph_id, thread_config, checkpointer=checkpointer
1027
- ) as graph:
1028
- result = await graph.aget_state(config, subgraphs=subgraphs)
1029
- if (
1030
- result.metadata is not None
1031
- and "checkpoint_ns" in result.metadata
1032
- and result.metadata["checkpoint_ns"] == ""
1033
- ):
1034
- result.metadata.pop("checkpoint_ns")
1035
- return result
1036
- else:
1037
- return StateSnapshot(
1038
- values={},
1039
- next=[],
1040
- config=None,
1041
- metadata=None,
1042
- created_at=None,
1043
- parent_config=None,
1044
- tasks=tuple(),
1045
- )
1046
-
1047
- @staticmethod
1048
- async def post(
1049
- conn: InMemConnectionProto,
1050
- config: Config,
1051
- values: Sequence[dict] | dict[str, Any] | None,
1052
- as_node: str | None = None,
1053
- ctx: Auth.types.BaseAuthContext | None = None,
1054
- ) -> ThreadUpdateResponse:
1055
- """Add state to a thread."""
1056
- thread_id = _ensure_uuid(config["configurable"]["thread_id"])
1057
- filters = await Threads.handle_event(
1058
- ctx,
1059
- "update",
1060
- Auth.types.ThreadsUpdate(thread_id=thread_id),
1061
- )
1062
-
1063
- checkpointer = Checkpointer(conn)
1064
-
1065
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1066
- thread = await fetchone(
1067
- thread_iter, not_found_detail=f"Thread {thread_id} not found."
1068
- )
1069
- checkpoint = await checkpointer.aget(config)
1070
-
1071
- if not thread:
1072
- raise HTTPException(status_code=404, detail="Thread not found")
1073
- if not _check_filter_match(thread["metadata"], filters):
1074
- raise HTTPException(status_code=403, detail="Forbidden")
1075
-
1076
- metadata = thread["metadata"]
1077
- thread_config = thread["config"]
1078
-
1079
- if graph_id := metadata.get("graph_id"):
1080
- config["configurable"].setdefault("graph_id", graph_id)
1081
-
1082
- checkpointer.latest_iter = checkpoint
1083
- async with get_graph(
1084
- graph_id, thread_config, checkpointer=checkpointer
1085
- ) as graph:
1086
- update_config = config.copy()
1087
- update_config["configurable"] = {
1088
- **config["configurable"],
1089
- "checkpoint_ns": config["configurable"].get(
1090
- "checkpoint_ns", ""
1091
- ),
1092
- }
1093
- next_config = await graph.aupdate_state(
1094
- update_config, values, as_node=as_node
1095
- )
1096
-
1097
- # Get current state
1098
- state = await Threads.State.get(
1099
- conn, config, subgraphs=False, ctx=ctx
1100
- )
1101
- # Update thread values
1102
- for thread in conn.store["threads"]:
1103
- if thread["thread_id"] == thread_id:
1104
- thread["values"] = state.values
1105
- break
1106
-
1107
- return ThreadUpdateResponse(
1108
- checkpoint=next_config["configurable"],
1109
- # Including deprecated fields
1110
- configurable=next_config["configurable"],
1111
- checkpoint_id=next_config["configurable"]["checkpoint_id"],
1112
- )
1113
- else:
1114
- raise HTTPException(status_code=400, detail="Thread has no graph ID.")
1115
-
1116
- @staticmethod
1117
- async def bulk(
1118
- conn: InMemConnectionProto,
1119
- *,
1120
- config: Config,
1121
- supersteps: Sequence[dict],
1122
- ctx: Auth.types.BaseAuthContext | None = None,
1123
- ) -> ThreadUpdateResponse:
1124
- """Update a thread with a batch of state updates."""
1125
-
1126
- from langgraph.pregel.types import StateUpdate
1127
-
1128
- thread_id = _ensure_uuid(config["configurable"]["thread_id"])
1129
- filters = await Threads.handle_event(
1130
- ctx,
1131
- "update",
1132
- Auth.types.ThreadsUpdate(thread_id=thread_id),
1133
- )
1134
-
1135
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1136
- thread = await fetchone(
1137
- thread_iter, not_found_detail=f"Thread {thread_id} not found."
1138
- )
1139
-
1140
- thread_config = thread["config"]
1141
- metadata = thread["metadata"]
1142
-
1143
- if not thread:
1144
- raise HTTPException(status_code=404, detail="Thread not found")
1145
-
1146
- if not _check_filter_match(metadata, filters):
1147
- raise HTTPException(status_code=403, detail="Forbidden")
1148
-
1149
- if graph_id := metadata.get("graph_id"):
1150
- config["configurable"].setdefault("graph_id", graph_id)
1151
- config["configurable"].setdefault("checkpoint_ns", "")
1152
-
1153
- async with get_graph(
1154
- graph_id, thread_config, checkpointer=Checkpointer(conn)
1155
- ) as graph:
1156
- next_config = await graph.abulk_update_state(
1157
- config,
1158
- [
1159
- [
1160
- StateUpdate(
1161
- map_cmd(update.get("command"))
1162
- if update.get("command")
1163
- else update.get("values"),
1164
- update.get("as_node"),
1165
- )
1166
- for update in superstep.get("updates", [])
1167
- ]
1168
- for superstep in supersteps
1169
- ],
1170
- )
1171
-
1172
- state = await Threads.State.get(
1173
- conn, config, subgraphs=False, ctx=ctx
1174
- )
1175
-
1176
- # update thread values
1177
- for thread in conn.store["threads"]:
1178
- if thread["thread_id"] == thread_id:
1179
- thread["values"] = state.values
1180
- break
1181
-
1182
- return ThreadUpdateResponse(
1183
- checkpoint=next_config["configurable"],
1184
- )
1185
- else:
1186
- raise HTTPException(status_code=400, detail="Thread has no graph ID")
1187
-
1188
- @staticmethod
1189
- async def list(
1190
- conn: InMemConnectionProto,
1191
- *,
1192
- config: Config,
1193
- limit: int = 10,
1194
- before: str | Checkpoint | None = None,
1195
- metadata: MetadataInput = None,
1196
- ctx: Auth.types.BaseAuthContext | None = None,
1197
- ) -> list[StateSnapshot]:
1198
- """Get the history of a thread."""
1199
-
1200
- thread_id = _ensure_uuid(config["configurable"]["thread_id"])
1201
- thread = None
1202
- filters = await Threads.handle_event(
1203
- ctx,
1204
- "read",
1205
- Auth.types.ThreadsRead(thread_id=thread_id),
1206
- )
1207
- thread = await fetchone(
1208
- await Threads.get(conn, config["configurable"]["thread_id"], ctx=ctx)
1209
- )
1210
-
1211
- # Parse thread metadata and config
1212
- thread_metadata = thread["metadata"]
1213
- if not _check_filter_match(thread_metadata, filters):
1214
- return []
1215
-
1216
- thread_config = thread["config"]
1217
- # If graph_id exists, get state history
1218
- if graph_id := thread_metadata.get("graph_id"):
1219
- async with get_graph(
1220
- graph_id,
1221
- thread_config,
1222
- checkpointer=await asyncio.to_thread(
1223
- Checkpointer, conn, unpack_hook=_msgpack_ext_hook_to_json
1224
- ),
1225
- ) as graph:
1226
- # Convert before parameter if it's a string
1227
- before_param = (
1228
- {"configurable": {"checkpoint_id": before}}
1229
- if isinstance(before, str)
1230
- else before
1231
- )
1232
-
1233
- states = [
1234
- state
1235
- async for state in graph.aget_state_history(
1236
- config, limit=limit, filter=metadata, before=before_param
1237
- )
1238
- ]
1239
-
1240
- return states
1241
-
1242
- return []
1243
-
1244
-
1245
- RUN_LOCK = asyncio.Lock()
1246
-
1247
-
1248
- class Runs(Authenticated):
1249
- resource = "threads"
1250
-
1251
- @staticmethod
1252
- async def stats(conn: InMemConnectionProto) -> QueueStats:
1253
- """Get stats about the queue."""
1254
- pending_runs = [run for run in conn.store["runs"] if run["status"] == "pending"]
1255
- running_runs = [run for run in conn.store["runs"] if run["status"] == "running"]
1256
-
1257
- if not pending_runs and not running_runs:
1258
- return {
1259
- "n_pending": 0,
1260
- "max_age_secs": None,
1261
- "med_age_secs": None,
1262
- "n_running": 0,
1263
- }
1264
-
1265
- # Get all creation timestamps
1266
- created_times = [run.get("created_at") for run in (pending_runs + running_runs)]
1267
- created_times = [
1268
- t for t in created_times if t is not None
1269
- ] # Filter out None values
1270
-
1271
- if not created_times:
1272
- return {
1273
- "n_pending": len(pending_runs),
1274
- "n_running": len(running_runs),
1275
- "max_age_secs": None,
1276
- "med_age_secs": None,
1277
- }
1278
-
1279
- # Find oldest (max age)
1280
- oldest_time = min(created_times) # Earliest timestamp = oldest run
1281
-
1282
- # Find median age
1283
- sorted_times = sorted(created_times)
1284
- median_idx = len(sorted_times) // 2
1285
- median_time = sorted_times[median_idx]
1286
-
1287
- return {
1288
- "n_pending": len(pending_runs),
1289
- "n_running": len(running_runs),
1290
- "max_age_secs": oldest_time,
1291
- "med_age_secs": median_time,
1292
- }
1293
-
1294
- @staticmethod
1295
- async def next(wait: bool, limit: int = 1) -> AsyncIterator[tuple[Run, int]]:
1296
- """Get the next run from the queue, and the attempt number.
1297
- 1 is the first attempt, 2 is the first retry, etc."""
1298
- now = datetime.now(UTC)
1299
-
1300
- if wait:
1301
- await asyncio.sleep(0.5)
1302
- else:
1303
- await asyncio.sleep(0)
1304
-
1305
- async with connect() as conn, RUN_LOCK:
1306
- pending_runs = sorted(
1307
- [
1308
- run
1309
- for run in conn.store["runs"]
1310
- if run["status"] == "pending" and run.get("created_at", now) < now
1311
- ],
1312
- key=lambda x: x.get("created_at", datetime.min),
1313
- )
1314
-
1315
- if not pending_runs:
1316
- return
1317
-
1318
- # Try to lock and get the first available run
1319
- for _, run in zip(range(limit), pending_runs, strict=False):
1320
- if run["status"] != "pending":
1321
- continue
1322
-
1323
- run_id = run["run_id"]
1324
- thread_id = run["thread_id"]
1325
- thread = next(
1326
- (t for t in conn.store["threads"] if t["thread_id"] == thread_id),
1327
- None,
1328
- )
1329
-
1330
- if thread is None:
1331
- await logger.awarning(
1332
- "Unexpected missing thread in Runs.next",
1333
- thread_id=run["thread_id"],
1334
- )
1335
- continue
1336
-
1337
- if run["status"] != "pending":
1338
- continue
1339
-
1340
- if any(
1341
- run["status"] == "running"
1342
- for run in conn.store["runs"]
1343
- if run["thread_id"] == thread_id
1344
- ):
1345
- continue
1346
- # Increment attempt counter
1347
- attempt = await conn.retry_counter.increment(run_id)
1348
- # Set run as "running"
1349
- run["status"] = "running"
1350
- yield run, attempt
1351
-
1352
- @asynccontextmanager
1353
- @staticmethod
1354
- async def enter(
1355
- run_id: UUID, loop: asyncio.AbstractEventLoop
1356
- ) -> AsyncIterator[ValueEvent]:
1357
- """Enter a run, listen for cancellation while running, signal when done."
1358
- This method should be called as a context manager by a worker executing a run.
1359
- """
1360
- stream_manager = get_stream_manager()
1361
- # Get queue for this run
1362
- queue = await Runs.Stream.subscribe(run_id)
1363
-
1364
- async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
1365
- done = ValueEvent()
1366
- tg.create_task(listen_for_cancellation(queue, run_id, done))
1367
-
1368
- # Give done event to caller
1369
- yield done
1370
- # Signal done to all subscribers
1371
- control_message = Message(
1372
- topic=f"run:{run_id}:control".encode(), data=b"done"
1373
- )
1374
-
1375
- # Store the control message for late subscribers
1376
- await stream_manager.put(run_id, control_message)
1377
- stream_manager.control_queues[run_id].append(control_message)
1378
- # Clean up this queue
1379
- await stream_manager.remove_queue(run_id, queue)
1380
-
1381
- @staticmethod
1382
- async def sweep(conn: InMemConnectionProto) -> list[UUID]:
1383
- """Sweep runs that are no longer running"""
1384
- return []
1385
-
1386
- @staticmethod
1387
- def _merge_jsonb(*objects: dict) -> dict:
1388
- """Mimics PostgreSQL's JSONB merge behavior"""
1389
- result = {}
1390
- for obj in objects:
1391
- if obj is not None:
1392
- result.update(copy.deepcopy(obj))
1393
- return result
1394
-
1395
- @staticmethod
1396
- def _get_configurable(config: dict) -> dict:
1397
- """Extract configurable from config, mimicking PostgreSQL's coalesce"""
1398
- return config.get("configurable", {})
1399
-
1400
- @staticmethod
1401
- async def put(
1402
- conn: InMemConnectionProto,
1403
- assistant_id: UUID,
1404
- kwargs: dict,
1405
- *,
1406
- thread_id: UUID | None = None,
1407
- user_id: str | None = None,
1408
- run_id: UUID | None = None,
1409
- status: RunStatus | None = "pending",
1410
- metadata: MetadataInput,
1411
- prevent_insert_if_inflight: bool,
1412
- multitask_strategy: MultitaskStrategy = "reject",
1413
- if_not_exists: IfNotExists = "reject",
1414
- after_seconds: int = 0,
1415
- ctx: Auth.types.BaseAuthContext | None = None,
1416
- ) -> AsyncIterator[Run]:
1417
- """Create a run."""
1418
- assistant_id = _ensure_uuid(assistant_id)
1419
- assistant = next(
1420
- (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
1421
- None,
1422
- )
1423
-
1424
- if not assistant:
1425
- return _empty_generator()
1426
-
1427
- thread_id = _ensure_uuid(thread_id) if thread_id else None
1428
- run_id = _ensure_uuid(run_id) if run_id else None
1429
- metadata = metadata if metadata is not None else {}
1430
- config = kwargs.get("config", {})
1431
-
1432
- # Handle thread creation/update
1433
- existing_thread = next(
1434
- (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
1435
- )
1436
- filters = await Runs.handle_event(
1437
- ctx,
1438
- "create_run",
1439
- Auth.types.RunsCreate(
1440
- thread_id=thread_id,
1441
- assistant_id=assistant_id,
1442
- run_id=run_id,
1443
- status=status,
1444
- metadata=metadata,
1445
- prevent_insert_if_inflight=prevent_insert_if_inflight,
1446
- multitask_strategy=multitask_strategy,
1447
- if_not_exists=if_not_exists,
1448
- after_seconds=after_seconds,
1449
- kwargs=kwargs,
1450
- ),
1451
- )
1452
- if existing_thread and filters:
1453
- # Reject if the user doesn't own the thread
1454
- if not _check_filter_match(existing_thread["metadata"], filters):
1455
- return _empty_generator()
1456
-
1457
- if not existing_thread and (thread_id is None or if_not_exists == "create"):
1458
- # Create new thread
1459
- if thread_id is None:
1460
- thread_id = uuid4()
1461
- thread = Thread(
1462
- thread_id=thread_id,
1463
- status="busy",
1464
- metadata={
1465
- "graph_id": assistant["graph_id"],
1466
- "assistant_id": str(assistant_id),
1467
- **metadata,
1468
- },
1469
- config=Runs._merge_jsonb(
1470
- assistant["config"],
1471
- config,
1472
- {
1473
- "configurable": Runs._merge_jsonb(
1474
- Runs._get_configurable(assistant["config"]),
1475
- Runs._get_configurable(config),
1476
- )
1477
- },
1478
- ),
1479
- created_at=datetime.now(UTC),
1480
- updated_at=datetime.now(UTC),
1481
- values=b"",
1482
- )
1483
- await logger.ainfo("Creating thread", thread_id=thread_id)
1484
- conn.store["threads"].append(thread)
1485
- elif existing_thread:
1486
- # Update existing thread
1487
- if existing_thread["status"] != "busy":
1488
- existing_thread["status"] = "busy"
1489
- existing_thread["metadata"] = Runs._merge_jsonb(
1490
- existing_thread["metadata"],
1491
- {
1492
- "graph_id": assistant["graph_id"],
1493
- "assistant_id": str(assistant_id),
1494
- },
1495
- )
1496
- existing_thread["config"] = Runs._merge_jsonb(
1497
- assistant["config"],
1498
- existing_thread["config"],
1499
- config,
1500
- {
1501
- "configurable": Runs._merge_jsonb(
1502
- Runs._get_configurable(assistant["config"]),
1503
- Runs._get_configurable(existing_thread["config"]),
1504
- Runs._get_configurable(config),
1505
- )
1506
- },
1507
- )
1508
- existing_thread["updated_at"] = datetime.now(UTC)
1509
- else:
1510
- return _empty_generator()
1511
-
1512
- # Check for inflight runs if needed
1513
- inflight_runs = [
1514
- r
1515
- for r in conn.store["runs"]
1516
- if r["thread_id"] == thread_id and r["status"] in ("pending", "running")
1517
- ]
1518
- if prevent_insert_if_inflight:
1519
- if inflight_runs:
1520
-
1521
- async def _return_inflight():
1522
- for run in inflight_runs:
1523
- yield run
1524
-
1525
- return _return_inflight()
1526
-
1527
- # Create new run
1528
- configurable = Runs._merge_jsonb(
1529
- Runs._get_configurable(assistant["config"]),
1530
- (
1531
- Runs._get_configurable(existing_thread["config"])
1532
- if existing_thread
1533
- else {}
1534
- ),
1535
- Runs._get_configurable(config),
1536
- {
1537
- "run_id": str(run_id),
1538
- "thread_id": str(thread_id),
1539
- "graph_id": assistant["graph_id"],
1540
- "assistant_id": str(assistant_id),
1541
- "user_id": (
1542
- config.get("configurable", {}).get("user_id")
1543
- or (
1544
- existing_thread["config"].get("configurable", {}).get("user_id")
1545
- if existing_thread
1546
- else None
1547
- )
1548
- or assistant["config"].get("configurable", {}).get("user_id")
1549
- or user_id
1550
- ),
1551
- },
1552
- )
1553
- merged_metadata = Runs._merge_jsonb(
1554
- assistant["metadata"],
1555
- existing_thread["metadata"] if existing_thread else {},
1556
- metadata,
1557
- )
1558
- new_run = Run(
1559
- run_id=run_id,
1560
- thread_id=thread_id,
1561
- assistant_id=assistant_id,
1562
- metadata=merged_metadata,
1563
- status=status,
1564
- kwargs=Runs._merge_jsonb(
1565
- kwargs,
1566
- {
1567
- "config": Runs._merge_jsonb(
1568
- assistant["config"],
1569
- config,
1570
- {"configurable": configurable},
1571
- {
1572
- "metadata": merged_metadata,
1573
- },
1574
- )
1575
- },
1576
- ),
1577
- multitask_strategy=multitask_strategy,
1578
- created_at=datetime.now(UTC) + timedelta(seconds=after_seconds),
1579
- updated_at=datetime.now(UTC),
1580
- )
1581
- conn.store["runs"].append(new_run)
1582
-
1583
- async def _yield_new():
1584
- yield new_run
1585
- for r in inflight_runs:
1586
- yield r
1587
-
1588
- return _yield_new()
1589
-
1590
- @staticmethod
1591
- async def get(
1592
- conn: InMemConnectionProto,
1593
- run_id: UUID,
1594
- *,
1595
- thread_id: UUID,
1596
- ctx: Auth.types.BaseAuthContext | None = None,
1597
- ) -> AsyncIterator[Run]:
1598
- """Get a run by ID."""
1599
-
1600
- run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
1601
- filters = await Runs.handle_event(
1602
- ctx,
1603
- "read",
1604
- Auth.types.ThreadsRead(thread_id=thread_id),
1605
- )
1606
-
1607
- async def _yield_result():
1608
- matching_run = None
1609
- for run in conn.store["runs"]:
1610
- if run["run_id"] == run_id and run["thread_id"] == thread_id:
1611
- matching_run = run
1612
- break
1613
- if matching_run:
1614
- if filters:
1615
- thread = await Threads._get_with_filters(
1616
- conn, matching_run["thread_id"], filters
1617
- )
1618
- if not thread:
1619
- return
1620
- yield matching_run
1621
-
1622
- return _yield_result()
1623
-
1624
- @staticmethod
1625
- async def delete(
1626
- conn: InMemConnectionProto,
1627
- run_id: UUID,
1628
- *,
1629
- thread_id: UUID,
1630
- ctx: Auth.types.BaseAuthContext | None = None,
1631
- ) -> AsyncIterator[UUID]:
1632
- """Delete a run by ID."""
1633
- run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
1634
- filters = await Runs.handle_event(
1635
- ctx,
1636
- "delete",
1637
- Auth.types.ThreadsDelete(run_id=run_id, thread_id=thread_id),
1638
- )
1639
-
1640
- if filters:
1641
- thread = await Threads._get_with_filters(conn, thread_id, filters)
1642
- if not thread:
1643
- return _empty_generator()
1644
- _delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
1645
- found = False
1646
- for i, run in enumerate(conn.store["runs"]):
1647
- if run["run_id"] == run_id and run["thread_id"] == thread_id:
1648
- del conn.store["runs"][i]
1649
- found = True
1650
- break
1651
- if not found:
1652
- raise HTTPException(status_code=404, detail="Run not found")
1653
-
1654
- async def _yield_deleted():
1655
- await logger.ainfo("Run deleted", run_id=run_id)
1656
- yield run_id
1657
-
1658
- return _yield_deleted()
1659
-
1660
- @staticmethod
1661
- async def join(
1662
- run_id: UUID,
1663
- *,
1664
- thread_id: UUID,
1665
- ctx: Auth.types.BaseAuthContext | None = None,
1666
- ) -> Fragment:
1667
- """Wait for a run to complete. If already done, return immediately.
1668
-
1669
- Returns:
1670
- the final state of the run.
1671
- """
1672
- async with connect() as conn:
1673
- # Validate ownership
1674
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1675
- await fetchone(thread_iter)
1676
- last_chunk: bytes | None = None
1677
- # wait for the run to complete
1678
- # Rely on this join's auth
1679
- async for mode, chunk in Runs.Stream.join(
1680
- run_id, thread_id=thread_id, stream_mode="values", ctx=ctx, ignore_404=True
1681
- ):
1682
- if mode == b"values":
1683
- last_chunk = chunk
1684
- # if we received a final chunk, return it
1685
- if last_chunk is not None:
1686
- # ie. if the run completed while we were waiting for it
1687
- return Fragment(last_chunk)
1688
- else:
1689
- # otherwise, the run had already finished, so fetch the state from thread
1690
- async with connect() as conn:
1691
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1692
- thread = await fetchone(thread_iter)
1693
- return thread["values"]
1694
-
1695
- @staticmethod
1696
- async def cancel(
1697
- conn: InMemConnectionProto,
1698
- run_ids: Sequence[UUID] | None = None,
1699
- *,
1700
- action: Literal["interrupt", "rollback"] = "interrupt",
1701
- thread_id: UUID | None = None,
1702
- status: Literal["pending", "running", "all"] | None = None,
1703
- ctx: Auth.types.BaseAuthContext | None = None,
1704
- ) -> None:
1705
- """
1706
- Cancel runs in memory. Must provide either:
1707
- 1) thread_id + run_ids, or
1708
- 2) status in {"pending", "running", "all"}.
1709
-
1710
- Steps:
1711
- - Validate arguments (one usage pattern or the other).
1712
- - Auth check: 'update' event via handle_event().
1713
- - Gather runs matching either the (thread_id, run_ids) set or the given status.
1714
- - For each run found:
1715
- * Send a cancellation message through the stream manager.
1716
- * If 'pending', set to 'interrupted' or delete (if action='rollback' and not actively queued).
1717
- * If 'running', the worker will pick up the message.
1718
- * Otherwise, log a warning for non-cancelable states.
1719
- - 404 if no runs are found or authorized.
1720
- """
1721
- # 1. Validate arguments
1722
- if status is not None:
1723
- # If status is set, user must NOT specify thread_id or run_ids
1724
- if thread_id is not None or run_ids is not None:
1725
- raise HTTPException(
1726
- status_code=422,
1727
- detail="Cannot specify 'thread_id' or 'run_ids' when using 'status'",
1728
- )
1729
- else:
1730
- # If status is not set, user must specify both thread_id and run_ids
1731
- if thread_id is None or run_ids is None:
1732
- raise HTTPException(
1733
- status_code=422,
1734
- detail="Must provide either a status or both 'thread_id' and 'run_ids'",
1735
- )
1736
-
1737
- # Convert and normalize inputs
1738
- if run_ids is not None:
1739
- run_ids = [_ensure_uuid(rid) for rid in run_ids]
1740
- if thread_id is not None:
1741
- thread_id = _ensure_uuid(thread_id)
1742
-
1743
- filters = await Runs.handle_event(
1744
- ctx,
1745
- "update",
1746
- Auth.types.ThreadsUpdate(
1747
- thread_id=thread_id, # type: ignore
1748
- action=action,
1749
- metadata={
1750
- "run_ids": run_ids,
1751
- "status": status,
1752
- },
1753
- ),
1754
- )
1755
-
1756
- status_list: tuple[str, ...] = ()
1757
- if status is not None:
1758
- if status == "all":
1759
- status_list = ("pending", "running")
1760
- elif status in ("pending", "running"):
1761
- status_list = (status,)
1762
- else:
1763
- raise ValueError(f"Unsupported status: {status}")
1764
-
1765
- def is_run_match(r: dict) -> bool:
1766
- """
1767
- Check whether a run in `conn.store["runs"]` meets the selection criteria.
1768
- """
1769
- if status_list:
1770
- return r["status"] in status_list
1771
- else:
1772
- return r["thread_id"] == thread_id and r["run_id"] in run_ids # type: ignore
1773
-
1774
- candidate_runs = [r for r in conn.store["runs"] if is_run_match(r)]
1775
-
1776
- if filters:
1777
- # If a run is found but not authorized by the thread filters, skip it
1778
- thread = (
1779
- await Threads._get_with_filters(conn, thread_id, filters)
1780
- if thread_id
1781
- else None
1782
- )
1783
- # If there's no matching thread, no runs are authorized.
1784
- if thread_id and not thread:
1785
- candidate_runs = []
1786
- # Otherwise, we might trust that `_get_with_filters` is the only constraint
1787
- # on thread. If your filters also apply to runs, you might do more checks here.
1788
-
1789
- if not candidate_runs:
1790
- raise HTTPException(status_code=404, detail="No runs found to cancel.")
1791
-
1792
- stream_manager = get_stream_manager()
1793
- coros = []
1794
- for run in candidate_runs:
1795
- run_id = run["run_id"]
1796
- control_message = Message(
1797
- topic=f"run:{run_id}:control".encode(),
1798
- data=action.encode(),
1799
- )
1800
- coros.append(stream_manager.put(run_id, control_message))
1801
-
1802
- queues = stream_manager.get_queues(run_id)
1803
-
1804
- if run["status"] in ("pending", "running"):
1805
- if queues or action != "rollback":
1806
- run["status"] = "interrupted"
1807
- run["updated_at"] = datetime.now(tz=UTC)
1808
- else:
1809
- await logger.ainfo(
1810
- "Eagerly deleting pending run with rollback action",
1811
- run_id=str(run_id),
1812
- status=run["status"],
1813
- )
1814
- coros.append(Runs.delete(conn, run_id, thread_id=run["thread_id"]))
1815
- else:
1816
- await logger.awarning(
1817
- "Attempted to cancel non-pending run.",
1818
- run_id=str(run_id),
1819
- status=run["status"],
1820
- )
1821
-
1822
- if coros:
1823
- await asyncio.gather(*coros)
1824
-
1825
- await logger.ainfo(
1826
- "Cancelled runs",
1827
- run_ids=[str(r["run_id"]) for r in candidate_runs],
1828
- thread_id=str(thread_id) if thread_id else None,
1829
- status=status,
1830
- action=action,
1831
- )
1832
-
1833
- @staticmethod
1834
- async def search(
1835
- conn: InMemConnectionProto,
1836
- thread_id: UUID,
1837
- *,
1838
- limit: int = 10,
1839
- offset: int = 0,
1840
- metadata: MetadataInput,
1841
- status: RunStatus | None = None,
1842
- ctx: Auth.types.BaseAuthContext | None = None,
1843
- ) -> AsyncIterator[Run]:
1844
- """List all runs by thread."""
1845
- runs = conn.store["runs"]
1846
- metadata = metadata if metadata is not None else {}
1847
- thread_id = _ensure_uuid(thread_id)
1848
- filters = await Runs.handle_event(
1849
- ctx,
1850
- "search",
1851
- Auth.types.ThreadsSearch(thread_id=thread_id, metadata=metadata),
1852
- )
1853
- filtered_runs = [
1854
- run
1855
- for run in runs
1856
- if run["thread_id"] == thread_id
1857
- and is_jsonb_contained(run["metadata"], metadata)
1858
- and (
1859
- not filters
1860
- or (await Threads._get_with_filters(conn, thread_id, filters))
1861
- )
1862
- and (status is None or run["status"] == status)
1863
- ]
1864
- sorted_runs = sorted(filtered_runs, key=lambda x: x["created_at"], reverse=True)
1865
- sliced_runs = sorted_runs[offset : offset + limit]
1866
-
1867
- async def _return():
1868
- for run in sliced_runs:
1869
- yield run
1870
-
1871
- return _return()
1872
-
1873
- @staticmethod
1874
- async def set_status(
1875
- conn: InMemConnectionProto, run_id: UUID, status: RunStatus
1876
- ) -> None:
1877
- """Set the status of a run."""
1878
- # Find the run in the store
1879
- run_id = _ensure_uuid(run_id)
1880
- run = next((run for run in conn.store["runs"] if run["run_id"] == run_id), None)
1881
-
1882
- if run:
1883
- # Update the status and updated_at timestamp
1884
- run["status"] = status
1885
- run["updated_at"] = datetime.now(tz=UTC)
1886
- return run
1887
- return None
1888
-
1889
- class Stream:
1890
- @staticmethod
1891
- async def subscribe(
1892
- run_id: UUID,
1893
- *,
1894
- stream_mode: "StreamMode | None" = None,
1895
- ) -> asyncio.Queue:
1896
- """Subscribe to the run stream, returning a queue."""
1897
- stream_manager = get_stream_manager()
1898
- queue = await stream_manager.add_queue(_ensure_uuid(run_id))
1899
-
1900
- # If there's a control message already stored, send it to the new subscriber
1901
- if control_messages := stream_manager.control_queues.get(run_id):
1902
- for control_msg in control_messages:
1903
- await queue.put(control_msg)
1904
- return queue
1905
-
1906
- @staticmethod
1907
- async def join(
1908
- run_id: UUID,
1909
- *,
1910
- thread_id: UUID,
1911
- ignore_404: bool = False,
1912
- cancel_on_disconnect: bool = False,
1913
- stream_mode: "StreamMode | asyncio.Queue | None" = None,
1914
- ctx: Auth.types.BaseAuthContext | None = None,
1915
- ) -> AsyncIterator[tuple[bytes, bytes]]:
1916
- """Stream the run output."""
1917
- queue = (
1918
- stream_mode
1919
- if isinstance(stream_mode, asyncio.Queue)
1920
- else await Runs.Stream.subscribe(run_id, stream_mode=stream_mode)
1921
- )
1922
-
1923
- try:
1924
- async with connect() as conn:
1925
- filters = await Runs.handle_event(
1926
- ctx,
1927
- "read",
1928
- Auth.types.ThreadsRead(thread_id=thread_id),
1929
- )
1930
- if filters:
1931
- thread = await Threads._get_with_filters(
1932
- cast(InMemConnectionProto, conn), thread_id, filters
1933
- )
1934
- if not thread:
1935
- raise WrappedHTTPException(
1936
- HTTPException(
1937
- status_code=404, detail="Thread not found"
1938
- )
1939
- )
1940
- channel_prefix = f"run:{run_id}:stream:"
1941
- len_prefix = len(channel_prefix.encode())
1942
-
1943
- while True:
1944
- try:
1945
- # Wait for messages with a timeout
1946
- message = await asyncio.wait_for(queue.get(), timeout=0.5)
1947
- topic, data = message.topic, message.data
1948
-
1949
- if topic.decode() == f"run:{run_id}:control":
1950
- if data == b"done":
1951
- break
1952
- else:
1953
- # Extract mode from topic
1954
- yield topic[len_prefix:], data
1955
- logger.debug(
1956
- "Streamed run event",
1957
- run_id=str(run_id),
1958
- stream_mode=topic[len_prefix:],
1959
- data=data,
1960
- )
1961
- except TimeoutError:
1962
- # Check if the run is still pending
1963
- run_iter = await Runs.get(
1964
- conn, run_id, thread_id=thread_id, ctx=ctx
1965
- )
1966
- run = await anext(run_iter, None)
1967
-
1968
- if ignore_404 and run is None:
1969
- break
1970
- elif run is None:
1971
- yield (
1972
- b"error",
1973
- HTTPException(
1974
- status_code=404, detail="Run not found"
1975
- ),
1976
- )
1977
- break
1978
- elif run["status"] not in ("pending", "running"):
1979
- break
1980
- except WrappedHTTPException as e:
1981
- raise e.http_exception from None
1982
- except:
1983
- if cancel_on_disconnect:
1984
- create_task(cancel_run(thread_id, run_id))
1985
- raise
1986
- finally:
1987
- stream_manager = get_stream_manager()
1988
- await stream_manager.remove_queue(run_id, queue)
1989
-
1990
- @staticmethod
1991
- async def publish(
1992
- run_id: UUID,
1993
- event: str,
1994
- message: bytes,
1995
- ) -> None:
1996
- """Publish a message to all subscribers of the run stream."""
1997
- topic = f"run:{run_id}:stream:{event}".encode()
1998
-
1999
- stream_manager = get_stream_manager()
2000
- # Send to all queues subscribed to this run_id
2001
- await stream_manager.put(run_id, Message(topic=topic, data=message))
2002
-
2003
-
2004
async def listen_for_cancellation(
    queue: asyncio.Queue, run_id: UUID, done: "ValueEvent"
):
    """Listen for cancellation messages and set the done event accordingly."""
    manager = get_stream_manager()
    control_key = f"run:{run_id}:control"

    # Replay control messages that were published before we subscribed.
    backlog = manager.control_queues.get(run_id)
    if backlog:
        for msg in backlog:
            if msg.data == b"rollback":
                done.set(UserRollback())
            elif msg.data == b"interrupt":
                done.set(UserInterrupt())

    while not done.is_set():
        try:
            # This task gets cancelled when Runs.enter exits anyway,
            # so a fairly long timeout is acceptable here.
            msg = await asyncio.wait_for(queue.get(), timeout=240)
        except TimeoutError:
            break
        payload = msg.data
        if payload == b"rollback":
            done.set(UserRollback())
        elif payload == b"interrupt":
            done.set(UserInterrupt())
        elif payload == b"done":
            done.set()
            break

        # Retain control messages so late subscribers can replay them.
        if msg.topic.decode() == control_key:
            manager.control_queues[run_id].append(msg)
2038
-
2039
-
2040
class Crons:
    """Cron (scheduled run) operations for the in-memory server.

    The in-memory backend does not support scheduled runs; every method
    raises NotImplementedError so callers fail loudly instead of
    silently doing nothing.
    """

    @staticmethod
    async def put(
        conn: InMemConnectionProto,
        *,
        payload: dict,
        schedule: str,
        cron_id: UUID | None = None,
        thread_id: UUID | None = None,
        end_time: datetime | None = None,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Cron]:
        """Create or update a cron. Not supported in-memory."""
        raise NotImplementedError

    @staticmethod
    async def delete(
        conn: InMemConnectionProto,
        cron_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[UUID]:
        """Delete a cron by id. Not supported in-memory."""
        raise NotImplementedError

    @staticmethod
    async def next(
        conn: InMemConnectionProto,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Cron]:
        """Yield crons due for execution. Not supported in-memory.

        The bare ``yield`` makes this an async generator (matching the
        declared return type); the first anext() yields None and the
        next step raises.
        """
        yield
        raise NotImplementedError("The in-mem server does not implement Crons.")

    @staticmethod
    async def set_next_run_date(
        conn: InMemConnectionProto,
        cron_id: UUID,
        next_run_date: datetime,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> None:
        """Record when the cron should next fire. Not supported in-memory."""
        raise NotImplementedError

    @staticmethod
    async def search(
        conn: InMemConnectionProto,
        *,
        assistant_id: UUID | None,
        thread_id: UUID | None,
        limit: int,
        offset: int,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Cron]:
        """Search crons with pagination. Not supported in-memory."""
        raise NotImplementedError
2090
-
2091
-
2092
async def cancel_run(
    thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
) -> None:
    """Open a fresh connection and cancel the given run on its thread."""
    async with connect() as connection:
        await Runs.cancel(connection, [run_id], thread_id=thread_id, ctx=ctx)
2097
-
2098
-
2099
def _delete_checkpoints_for_thread(
    thread_id: str | UUID,
    conn: InMemConnectionProto,
    run_id: str | UUID | None = None,
):
    """Remove checkpoints for a thread from the in-memory checkpointer.

    If *run_id* is given, only the checkpoints whose metadata records that
    run are removed; otherwise the thread's entire checkpoint storage and
    its pending writes are dropped.
    """
    checkpointer = Checkpointer(conn)
    thread_id = str(thread_id)
    if thread_id not in checkpointer.storage:
        return
    if run_id:
        # Look through metadata
        run_id = str(run_id)
        # Iterate over list() copies so entries can be deleted while scanning.
        for checkpoint_ns, checkpoints in list(checkpointer.storage[thread_id].items()):
            for checkpoint_id, (_, metadata_b, _) in list(checkpoints.items()):
                # Metadata is stored serialized; decode to inspect run_id.
                metadata = checkpointer.serde.loads_typed(metadata_b)
                if metadata.get("run_id") == run_id:
                    del checkpointer.storage[thread_id][checkpoint_ns][checkpoint_id]
            # Drop the namespace entirely once it has no checkpoints left.
            if not checkpointer.storage[thread_id][checkpoint_ns]:
                del checkpointer.storage[thread_id][checkpoint_ns]
    else:
        del checkpointer.storage[thread_id]
        # Keys are (thread_id, checkpoint_ns, checkpoint_id)
        # NOTE(review): writes are cleared only on full-thread deletion here;
        # a run-targeted delete leaves the thread's pending writes intact —
        # confirm that is the intended behavior.
        checkpointer.writes = defaultdict(
            dict, {k: v for k, v in checkpointer.writes.items() if k[0] != thread_id}
        )
2124
-
2125
-
2126
def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -> bool:
    """Check if metadata matches the filter conditions.

    Args:
        metadata: The metadata to check
        filters: The filter conditions to apply

    Returns:
        True if the metadata matches all filter conditions, False otherwise
    """
    if not filters:
        return True

    for key, condition in filters.items():
        present = key in metadata
        if isinstance(condition, dict):
            # Operator form, e.g. {"$eq": value} or {"$contains": value}.
            op, expected = next(iter(condition.items()))
            if op == "$eq":
                if not present or metadata[key] != expected:
                    return False
            elif op == "$contains":
                contains = (
                    present
                    and isinstance(metadata[key], list)
                    and expected in metadata[key]
                )
                if not contains:
                    return False
            # Unrecognized operators fall through without constraining the
            # match, same as the original behavior.
        elif not present or metadata[key] != condition:
            # Direct equality
            return False

    return True
2160
-
2161
-
2162
- async def _empty_generator():
2163
- if False:
2164
- yield
2165
-
2166
-
2167
# Public API of this in-memory ops module.
__all__ = [
    "Assistants",
    "Crons",
    "Runs",
    "Threads",
]