langgraph-api 0.0.47__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langgraph-api might be problematic. Click here for more details.

langgraph_storage/ops.py DELETED
@@ -1,2175 +0,0 @@
1
- """Implementation of the LangGraph API using in-memory checkpointer & store."""
2
-
3
- import asyncio
4
- import base64
5
- import copy
6
- import json
7
- import logging
8
- import uuid
9
- from collections import defaultdict
10
- from collections.abc import AsyncIterator, Sequence
11
- from contextlib import asynccontextmanager
12
- from copy import deepcopy
13
- from datetime import UTC, datetime, timedelta
14
- from typing import Any, Literal, cast
15
- from uuid import UUID, uuid4
16
-
17
- import structlog
18
- from langgraph.checkpoint.serde.jsonplus import _msgpack_ext_hook_to_json
19
- from langgraph.pregel.debug import CheckpointPayload
20
- from langgraph.pregel.types import StateSnapshot
21
- from langgraph_sdk import Auth
22
- from starlette.exceptions import HTTPException
23
-
24
- from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent, create_task
25
- from langgraph_api.auth.custom import handle_event
26
- from langgraph_api.command import map_cmd
27
- from langgraph_api.config import ThreadTTLConfig
28
- from langgraph_api.errors import UserInterrupt, UserRollback
29
- from langgraph_api.graph import get_graph
30
- from langgraph_api.schema import (
31
- Assistant,
32
- Checkpoint,
33
- Config,
34
- Cron,
35
- IfNotExists,
36
- MetadataInput,
37
- MetadataValue,
38
- MultitaskStrategy,
39
- OnConflictBehavior,
40
- QueueStats,
41
- Run,
42
- RunStatus,
43
- StreamMode,
44
- Thread,
45
- ThreadStatus,
46
- ThreadUpdateResponse,
47
- )
48
- from langgraph_api.serde import Fragment
49
- from langgraph_api.utils import fetchone, get_auth_ctx
50
- from langgraph_storage.checkpoint import Checkpointer
51
- from langgraph_storage.database import InMemConnectionProto, connect
52
- from langgraph_storage.inmem_stream import Message, get_stream_manager
53
-
54
- logger = structlog.stdlib.get_logger(__name__)
55
-
56
-
57
- def _ensure_uuid(id_: str | uuid.UUID | None) -> uuid.UUID:
58
- if isinstance(id_, str):
59
- return uuid.UUID(id_)
60
- if id_ is None:
61
- return uuid4()
62
- return id_
63
-
64
-
65
class WrappedHTTPException(Exception):
    """Plain ``Exception`` that carries a Starlette ``HTTPException``.

    The wrapped exception is available as ``self.http_exception`` and is also
    passed to ``Exception.__init__``. That way ``args``, ``str()`` and
    pickling behave like a normal exception, instead of producing an empty
    message or failing to unpickle.
    """

    def __init__(self, http_exception: HTTPException):
        super().__init__(http_exception)
        self.http_exception = http_exception
68
-
69
-
70
- # Right now the whole API types as UUID but frequently passes a str
71
- # We ensure UUIDs for everything EXCEPT the checkpoint storage/writes,
72
- # which we leave as strings. This is because I'm too lazy to subclass fully
73
- # and we use non-UUID examples in the OSS version
74
-
75
-
76
class Authenticated:
    """Mixin that routes storage operations through the custom auth handlers.

    Subclasses set ``resource``. ``handle_event`` builds an
    ``Auth.types.AuthContext`` for that resource and the requested action,
    then returns whatever filters the registered handler produces (or
    ``None`` when there is no auth context).
    """

    resource: Literal["threads", "crons", "assistants"]

    @classmethod
    def _context(
        cls,
        ctx: Auth.types.BaseAuthContext | None,
        # "search" is included because handle_event forwards search actions here.
        action: Literal["create", "read", "update", "delete", "search", "create_run"],
    ) -> Auth.types.AuthContext | None:
        """Build the resource-scoped auth context, or ``None`` without ``ctx``."""
        if not ctx:
            return None
        return Auth.types.AuthContext(
            user=ctx.user,
            permissions=ctx.permissions,
            resource=cls.resource,
            action=action,
        )

    @classmethod
    async def handle_event(
        cls,
        ctx: Auth.types.BaseAuthContext | None,
        action: Literal["create", "read", "update", "delete", "search", "create_run"],
        value: Any,
    ) -> Auth.types.FilterType | None:
        """Run the auth handler for ``action`` on ``value`` and return its filters.

        Args:
            ctx: The auth context; falls back to ``get_auth_ctx()`` when falsy.
            action: The action being performed on ``cls.resource``.
            value: The action-specific payload passed to the handler.

        Returns:
            The handler's filters, or ``None`` when no auth context is available.
        """
        ctx = ctx or get_auth_ctx()
        if not ctx:
            return None
        # Resolves to the module-level ``handle_event`` imported from
        # langgraph_api.auth.custom, not to this classmethod.
        return await handle_event(cls._context(ctx, action), value)
105
-
106
-
107
class Assistants(Authenticated):
    """In-memory storage operations for assistants and their version history.

    Assistants live in ``conn.store["assistants"]`` and every saved revision
    in ``conn.store["assistant_versions"]``. Each operation first runs the
    custom auth handler for its action, and any filters it returns are matched
    against the assistant's metadata. Results are returned as async iterators.
    """

    resource = "assistants"

    @staticmethod
    async def search(
        conn: InMemConnectionProto,
        *,
        graph_id: str | None,
        metadata: MetadataInput,
        limit: int,
        offset: int,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Assistant]:
        """Search assistants, newest first.

        Args:
            conn: The connection to the in-memory store.
            graph_id: If set, only assistants for this graph are returned.
            metadata: If non-empty, assistants must contain these key/value
                pairs (see ``is_jsonb_contained``).
            limit: Maximum number of assistants to yield.
            offset: Number of matching assistants to skip.
            ctx: The auth context; falls back to ``get_auth_ctx()``.

        Returns:
            An async iterator over one page of matching assistants. The store
            is read when iteration starts, not when this coroutine returns.
        """
        metadata = metadata if metadata is not None else {}
        filters = await Assistants.handle_event(
            ctx,
            "search",
            Auth.types.AssistantsSearch(
                graph_id=graph_id, metadata=metadata, limit=limit, offset=offset
            ),
        )

        async def filter_and_yield() -> AsyncIterator[Assistant]:
            assistants = conn.store["assistants"]
            filtered_assistants = [
                assistant
                for assistant in assistants
                if (not graph_id or assistant["graph_id"] == graph_id)
                and (
                    not metadata or is_jsonb_contained(assistant["metadata"], metadata)
                )
                and (not filters or _check_filter_match(assistant["metadata"], filters))
            ]
            filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
            for assistant in filtered_assistants[offset : offset + limit]:
                yield assistant

        return filter_and_yield()

    @staticmethod
    async def get(
        conn: InMemConnectionProto,
        assistant_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Assistant]:
        """Get an assistant by ID.

        Returns:
            An async iterator yielding the assistant, or nothing if it does not
            exist or the caller's auth filters exclude it.
        """
        assistant_id = _ensure_uuid(assistant_id)
        filters = await Assistants.handle_event(
            ctx,
            "read",
            Auth.types.AssistantsRead(assistant_id=assistant_id),
        )

        async def _yield_result():
            for assistant in conn.store["assistants"]:
                if assistant["assistant_id"] == assistant_id and (
                    not filters or _check_filter_match(assistant["metadata"], filters)
                ):
                    yield assistant

        return _yield_result()

    @staticmethod
    async def put(
        conn: InMemConnectionProto,
        assistant_id: UUID,
        *,
        graph_id: str,
        config: Config,
        metadata: MetadataInput,
        if_exists: OnConflictBehavior,
        name: str,
        ctx: Auth.types.BaseAuthContext | None = None,
        description: str | None = None,
    ) -> AsyncIterator[Assistant]:
        """Insert an assistant.

        If an assistant with ``assistant_id`` already exists, a 409 is raised
        when the caller's auth filters exclude it or when ``if_exists`` is
        ``"raise"``. With ``"do_nothing"``, the existing assistant is returned
        unchanged. A new assistant starts at version 1, and a matching row is
        appended to ``assistant_versions``.

        Raises:
            HTTPException: 409 on the conflicts described above.
        """
        assistant_id = _ensure_uuid(assistant_id)
        metadata = metadata if metadata is not None else {}
        filters = await Assistants.handle_event(
            ctx,
            "create",
            Auth.types.AssistantsCreate(
                assistant_id=assistant_id,
                graph_id=graph_id,
                config=config,
                metadata=metadata,
                name=name,
            ),
        )
        existing_assistant = next(
            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
            None,
        )
        if existing_assistant:
            if filters and not _check_filter_match(
                existing_assistant["metadata"], filters
            ):
                raise HTTPException(
                    status_code=409, detail=f"Assistant {assistant_id} already exists"
                )
            if if_exists == "raise":
                raise HTTPException(
                    status_code=409, detail=f"Assistant {assistant_id} already exists"
                )
            elif if_exists == "do_nothing":

                async def _yield_existing():
                    yield existing_assistant

                return _yield_existing()

        now = datetime.now(UTC)
        new_assistant: Assistant = {
            "assistant_id": assistant_id,
            "graph_id": graph_id,
            "config": config or {},
            "metadata": metadata or {},
            "name": name,
            "created_at": now,
            "updated_at": now,
            "version": 1,
            "description": description,
        }
        # Same shape as the version rows written by ``patch`` (which include
        # "description"), so every version row carries the same keys.
        new_version = {
            "assistant_id": assistant_id,
            "version": 1,
            "graph_id": graph_id,
            "config": config or {},
            "metadata": metadata or {},
            "created_at": now,
            "name": name,
            "description": description,
        }
        conn.store["assistants"].append(new_assistant)
        conn.store["assistant_versions"].append(new_version)

        async def _yield_new():
            yield new_assistant

        return _yield_new()

    @staticmethod
    async def patch(
        conn: InMemConnectionProto,
        assistant_id: UUID,
        *,
        config: dict | None = None,
        graph_id: str | None = None,
        metadata: MetadataInput | None = None,
        name: str | None = None,
        description: str | None = None,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Assistant]:
        """Update an assistant, recording the result as a new version.

        Fields left as ``None`` keep their current value. ``metadata`` is
        merged into the existing metadata (new keys win), so omitting it keeps
        the current metadata.

        Args:
            conn: The connection to the in-memory store.
            assistant_id: The assistant ID.
            graph_id: The graph ID.
            config: The assistant config.
            metadata: Metadata keys to add or overwrite.
            name: The assistant name.
            description: The assistant description.
            ctx: The auth context.

        Returns:
            An async iterator yielding the updated assistant.

        Raises:
            HTTPException: 404 if the assistant does not exist or the caller's
                auth filters exclude it.
        """
        assistant_id = _ensure_uuid(assistant_id)
        metadata = metadata if metadata is not None else {}
        filters = await Assistants.handle_event(
            ctx,
            "update",
            Auth.types.AssistantsUpdate(
                assistant_id=assistant_id,
                graph_id=graph_id,
                config=config,
                metadata=metadata,
                name=name,
            ),
        )
        assistant = next(
            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
            None,
        )
        if not assistant:
            raise HTTPException(
                status_code=404, detail=f"Assistant {assistant_id} not found"
            )
        elif filters and not _check_filter_match(assistant["metadata"], filters):
            raise HTTPException(
                status_code=404, detail=f"Assistant {assistant_id} not found"
            )

        now = datetime.now(UTC)
        # Version numbers are per assistant. ``default=0`` avoids a ValueError
        # when this assistant has no version rows, even if other assistants do.
        new_version = (
            max(
                (
                    v["version"]
                    for v in conn.store["assistant_versions"]
                    if v["assistant_id"] == assistant_id
                ),
                default=0,
            )
            + 1
        )

        # Always merge: with no new metadata this is a copy of the current
        # metadata, instead of replacing it with an empty dict.
        merged_metadata = {**assistant["metadata"], **metadata}
        new_version_entry = {
            "assistant_id": assistant_id,
            "version": new_version,
            "graph_id": graph_id if graph_id is not None else assistant["graph_id"],
            "config": config if config is not None else assistant["config"],
            "metadata": merged_metadata,
            "created_at": now,
            "name": name if name is not None else assistant["name"],
            # .get(): rows persisted by older releases may lack "description".
            "description": description
            if description is not None
            else assistant.get("description"),
        }
        conn.store["assistant_versions"].append(new_version_entry)

        assistant.update(
            {
                "graph_id": new_version_entry["graph_id"],
                "config": new_version_entry["config"],
                "metadata": new_version_entry["metadata"],
                "name": new_version_entry["name"],
                "description": new_version_entry["description"],
                "updated_at": now,
                "version": new_version,
            }
        )

        async def _yield_updated():
            yield assistant

        return _yield_updated()

    @staticmethod
    async def delete(
        conn: InMemConnectionProto,
        assistant_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[UUID]:
        """Delete an assistant by ID, along with its versions and runs.

        Raises:
            HTTPException: 404 if the assistant does not exist or the caller's
                auth filters exclude it.
        """
        assistant_id = _ensure_uuid(assistant_id)
        filters = await Assistants.handle_event(
            ctx,
            "delete",
            Auth.types.AssistantsDelete(
                assistant_id=assistant_id,
            ),
        )
        assistant = next(
            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
            None,
        )

        if not assistant:
            raise HTTPException(
                status_code=404, detail=f"Assistant with ID {assistant_id} not found"
            )
        elif filters and not _check_filter_match(assistant["metadata"], filters):
            raise HTTPException(
                status_code=404, detail=f"Assistant with ID {assistant_id} not found"
            )

        conn.store["assistants"] = [
            a for a in conn.store["assistants"] if a["assistant_id"] != assistant_id
        ]
        # Cascade delete assistant versions and runs on this assistant.
        # NOTE(review): crons referencing this assistant are not removed here.
        conn.store["assistant_versions"] = [
            v
            for v in conn.store["assistant_versions"]
            if v["assistant_id"] != assistant_id
        ]
        # Snapshot the matching runs first. Runs.delete (defined elsewhere)
        # presumably removes entries from conn.store["runs"], and iterating the
        # live list while it shrinks could skip runs.
        runs_to_delete = [
            run for run in conn.store["runs"] if run["assistant_id"] == assistant_id
        ]
        for run in runs_to_delete:
            res = await Runs.delete(
                conn, run["run_id"], thread_id=run["thread_id"], ctx=ctx
            )
            await anext(res)

        async def _yield_deleted():
            yield assistant_id

        return _yield_deleted()

    @staticmethod
    async def set_latest(
        conn: InMemConnectionProto,
        assistant_id: UUID,
        version: int,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Assistant]:
        """Point an assistant at one of its existing versions.

        Only ``config``, ``metadata`` and ``version`` are restored from the
        version row (and ``updated_at`` is bumped). ``graph_id``, ``name`` and
        ``description`` are left unchanged.

        Raises:
            HTTPException: 404 if the assistant is missing or hidden by auth
                filters, or if the version does not exist.
        """
        assistant_id = _ensure_uuid(assistant_id)
        filters = await Assistants.handle_event(
            ctx,
            "update",
            Auth.types.AssistantsUpdate(
                assistant_id=assistant_id,
                version=version,
            ),
        )
        assistant = next(
            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
            None,
        )
        if not assistant:
            raise HTTPException(
                status_code=404, detail=f"Assistant {assistant_id} not found"
            )
        elif filters and not _check_filter_match(assistant["metadata"], filters):
            raise HTTPException(
                status_code=404, detail=f"Assistant {assistant_id} not found"
            )

        version_data = next(
            (
                v
                for v in conn.store["assistant_versions"]
                if v["assistant_id"] == assistant_id and v["version"] == version
            ),
            None,
        )
        if not version_data:
            raise HTTPException(
                status_code=404,
                detail=f"Version {version} not found for assistant {assistant_id}",
            )

        assistant.update(
            {
                "config": version_data["config"],
                "metadata": version_data["metadata"],
                "version": version_data["version"],
                "updated_at": datetime.now(UTC),
            }
        )

        async def _yield_updated():
            yield assistant

        return _yield_updated()

    @staticmethod
    async def get_versions(
        conn: InMemConnectionProto,
        assistant_id: UUID,
        metadata: MetadataInput,
        limit: int,
        offset: int,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Assistant]:
        """Get the versions of an assistant, newest first.

        Versions are filtered by ``metadata`` containment and by the caller's
        auth filters, both matched against each version's own metadata.

        Raises:
            HTTPException: 404 if the assistant does not exist.
        """
        assistant_id = _ensure_uuid(assistant_id)
        filters = await Assistants.handle_event(
            ctx,
            "read",
            Auth.types.AssistantsRead(assistant_id=assistant_id),
        )
        assistant = next(
            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
            None,
        )
        if not assistant:
            raise HTTPException(
                status_code=404, detail=f"Assistant {assistant_id} not found"
            )
        versions = [
            v
            for v in conn.store["assistant_versions"]
            if v["assistant_id"] == assistant_id
            and (not metadata or is_jsonb_contained(v["metadata"], metadata))
            and (not filters or _check_filter_match(v["metadata"], filters))
        ]

        # Previously, the name was not included in the assistant_versions
        # table, so backfill it here. This mutates the stored version rows.
        for v in versions:
            if "name" not in v:
                v["name"] = assistant["name"]

        versions.sort(key=lambda x: x["version"], reverse=True)

        async def _yield_versions():
            for version in versions[offset : offset + limit]:
                yield version

        return _yield_versions()
507
-
508
-
509
def is_jsonb_contained(superset: dict[str, Any], subset: dict[str, Any]) -> bool:
    """
    Implements Postgres' @> (containment) operator for dictionaries.
    Returns True if superset contains all key/value pairs from subset.

    Nested objects are compared recursively. Arrays use containment
    semantics: every element of the subset array must be contained in some
    element of the superset array (order and duplicates are ignored). Other
    values are compared with ``==``. A ``superset`` that is not a dict (for
    example a thread whose ``values`` is ``None``) contains only the empty
    subset; it no longer raises ``TypeError``.
    """
    if not subset:
        return True
    if not isinstance(superset, dict):
        return False
    return all(
        key in superset and _jsonb_value_contains(superset[key], value)
        for key, value in subset.items()
    )


def _jsonb_value_contains(superset_value: Any, subset_value: Any) -> bool:
    """Return True if JSON value ``superset_value`` contains ``subset_value``."""
    if isinstance(subset_value, dict):
        return isinstance(superset_value, dict) and is_jsonb_contained(
            superset_value, subset_value
        )
    if isinstance(subset_value, list):
        return isinstance(superset_value, list) and all(
            any(_jsonb_value_contains(candidate, item) for candidate in superset_value)
            for item in subset_value
        )
    return superset_value == subset_value
523
-
524
-
525
def bytes_decoder(obj):
    """JSON ``object_hook`` that restores base64-encoded bytes.

    An object of the form ``{"__type__": "bytes", "value": <base64 str>}`` is
    decoded back to ``bytes``; any other object is returned unchanged.
    """
    if obj.get("__type__") != "bytes":
        return obj
    encoded = obj["value"].encode("utf-8")
    return base64.b64decode(encoded)
530
-
531
-
532
- def _replace_thread_id(data, new_thread_id, thread_id):
533
- class BytesEncoder(json.JSONEncoder):
534
- """Custom JSON encoder that handles bytes by converting them to base64."""
535
-
536
- def default(self, obj):
537
- if isinstance(obj, bytes | bytearray):
538
- return {
539
- "__type__": "bytes",
540
- "value": base64.b64encode(
541
- obj.replace(
542
- str(thread_id).encode(), str(new_thread_id).encode()
543
- )
544
- ).decode("utf-8"),
545
- }
546
-
547
- return super().default(obj)
548
-
549
- try:
550
- json_str = json.dumps(data, cls=BytesEncoder, indent=2)
551
- except Exception as e:
552
- raise ValueError(data) from e
553
- json_str = json_str.replace(str(thread_id), str(new_thread_id))
554
-
555
- # Decoding back from JSON
556
- d = json.loads(json_str, object_hook=bytes_decoder)
557
- return d
558
-
559
-
560
- class Threads(Authenticated):
561
- resource = "threads"
562
-
563
@staticmethod
async def search(
    conn: InMemConnectionProto,
    *,
    metadata: MetadataInput,
    values: MetadataInput,
    status: ThreadStatus | None,
    limit: int,
    offset: int,
    ctx: Auth.types.BaseAuthContext | None = None,
) -> AsyncIterator[Thread]:
    """Search threads, newest first.

    Args:
        conn: The connection to the in-memory store.
        metadata: If non-empty, threads must contain these metadata pairs.
        values: If non-empty, the thread's ``values`` must contain these
            pairs. Threads whose values are missing or ``None`` never match.
        status: If set, only threads with this status are returned.
        limit: Maximum number of threads to yield.
        offset: Number of matching threads to skip.
        ctx: The auth context; falls back to ``get_auth_ctx()``.

    Returns:
        An async iterator over one page of matching threads.
    """
    threads = conn.store["threads"]
    filtered_threads: list[Thread] = []
    metadata = metadata if metadata is not None else {}
    values = values if values is not None else {}
    filters = await Threads.handle_event(
        ctx,
        "search",
        Auth.types.ThreadsSearch(
            metadata=metadata,
            values=values,
            status=status,
            limit=limit,
            offset=offset,
        ),
    )

    for thread in threads:
        if filters and not _check_filter_match(thread["metadata"], filters):
            continue

        if metadata and not is_jsonb_contained(thread["metadata"], metadata):
            continue

        # A thread without values (key missing or None) cannot contain the
        # requested values, so exclude it rather than letting it match.
        if values and not is_jsonb_contained(thread.get("values") or {}, values):
            continue

        if status and thread.get("status") != status:
            continue

        filtered_threads.append(thread)

    sorted_threads = sorted(
        filtered_threads, key=lambda x: x["created_at"], reverse=True
    )
    paginated_threads = sorted_threads[offset : offset + limit]

    async def thread_iterator() -> AsyncIterator[Thread]:
        for thread in paginated_threads:
            yield thread

    return thread_iterator()
623
-
624
@staticmethod
async def _get_with_filters(
    conn: InMemConnectionProto,
    thread_id: UUID,
    filters: Auth.types.FilterType | None,
) -> Thread | None:
    """Return the stored thread with ``thread_id`` if it passes ``filters``.

    Only the first thread with a matching ID is considered. ``None`` is
    returned when no such thread exists, or when filters are given and the
    thread's metadata does not satisfy them.
    """
    wanted = _ensure_uuid(thread_id)
    for candidate in conn.store["threads"]:
        if candidate["thread_id"] != wanted:
            continue
        if filters and not _check_filter_match(candidate["metadata"], filters):
            return None
        return candidate
    return None
645
-
646
@staticmethod
async def _get(
    conn: InMemConnectionProto,
    thread_id: UUID,
    ctx: Auth.types.BaseAuthContext | None = None,
) -> Thread | None:
    """Authorize a ``read`` of ``thread_id`` and return the thread, or ``None``."""
    tid = _ensure_uuid(thread_id)
    read_request = Auth.types.ThreadsRead(thread_id=tid)
    filters = await Threads.handle_event(ctx, "read", read_request)
    return await Threads._get_with_filters(conn, tid, filters)
660
-
661
@staticmethod
async def get(
    conn: InMemConnectionProto,
    thread_id: UUID,
    ctx: Auth.types.BaseAuthContext | None = None,
) -> AsyncIterator[Thread]:
    """Return an async iterator yielding the thread with ``thread_id``.

    Raises:
        HTTPException: 404 if the thread does not exist or the caller's auth
            filters exclude it.
    """
    found = await Threads._get(conn, thread_id, ctx)
    if not found:
        raise HTTPException(
            status_code=404, detail=f"Thread with ID {thread_id} not found"
        )

    async def _single() -> AsyncIterator[Thread]:
        yield found

    return _single()
680
-
681
@staticmethod
async def put(
    conn: InMemConnectionProto,
    thread_id: UUID,
    *,
    metadata: MetadataInput,
    if_exists: OnConflictBehavior,
    ttl: ThreadTTLConfig | None = None,
    ctx: Auth.types.BaseAuthContext | None = None,
) -> AsyncIterator[Thread]:
    """Insert a thread, resolving conflicts with an existing one.

    Args:
        conn: The connection to the in-memory store.
        thread_id: ID of the thread to create (str or UUID).
        metadata: Initial metadata (``None`` means ``{}``); it is deep-copied
            into the new thread.
        if_exists: ``"raise"`` gives a 409 when the thread exists;
            ``"do_nothing"`` returns the existing thread unchanged.
        ttl: Accepted for interface compatibility; not used here.
        ctx: The auth context; falls back to ``get_auth_ctx()``.

    Raises:
        HTTPException: 409 if the thread exists and either ``if_exists`` is
            ``"raise"`` or the caller's auth filters exclude it.
    """
    tid = _ensure_uuid(thread_id)
    meta = {} if metadata is None else metadata

    current = next((t for t in conn.store["threads"] if t["thread_id"] == tid), None)
    filters = await Threads.handle_event(
        ctx,
        "create",
        Auth.types.ThreadsCreate(thread_id=tid, metadata=meta, if_exists=if_exists),
    )

    if current:
        # A thread hidden by the caller's filters is reported as a conflict too.
        hidden = filters and not _check_filter_match(current["metadata"], filters)
        if hidden or if_exists == "raise":
            raise HTTPException(
                status_code=409, detail=f"Thread with ID {tid} already exists"
            )
        if if_exists == "do_nothing":

            async def _yield_current() -> AsyncIterator[Thread]:
                yield current

            return _yield_current()

    created: Thread = {
        "thread_id": tid,
        "created_at": datetime.now(UTC),
        "updated_at": datetime.now(UTC),
        "metadata": copy.deepcopy(meta),
        "status": "idle",
        "config": {},
        "values": None,
    }
    conn.store["threads"].append(created)

    async def _yield_created() -> AsyncIterator[Thread]:
        yield created

    return _yield_created()
744
-
745
@staticmethod
async def patch(
    conn: InMemConnectionProto,
    thread_id: UUID,
    *,
    metadata: MetadataValue,
    ctx: Auth.types.BaseAuthContext | None = None,
) -> AsyncIterator[Thread]:
    """Merge ``metadata`` into an existing thread's metadata.

    The stored thread dict is replaced by an updated deep copy. The auth
    handler only runs when the thread exists.

    Returns:
        An async iterator yielding the updated thread, or an empty iterator
        if the thread does not exist or the caller's filters exclude it.
    """
    thread_list = conn.store["threads"]
    thread_id = _ensure_uuid(thread_id)

    async def _empty() -> AsyncIterator[Thread]:
        return
        yield  # unreachable; makes this an async generator

    position = next(
        (i for i, t in enumerate(thread_list) if t["thread_id"] == thread_id), None
    )
    if position is None:
        return _empty()

    filters = await Threads.handle_event(
        ctx,
        "update",
        Auth.types.ThreadsUpdate(thread_id=thread_id, metadata=metadata),
    )
    stored = thread_list[position]
    if filters and not _check_filter_match(stored["metadata"], filters):
        return _empty()

    updated = copy.deepcopy(stored)
    updated["metadata"] = {**updated["metadata"], **metadata}
    updated["updated_at"] = datetime.now(UTC)
    thread_list[position] = updated

    async def _single() -> AsyncIterator[Thread]:
        yield updated

    return _single()
787
-
788
@staticmethod
async def set_status(
    conn: InMemConnectionProto,
    thread_id: UUID,
    checkpoint: CheckpointPayload | None,
    exception: BaseException | None,
) -> None:
    """Refresh a thread's status, values and interrupts from a checkpoint.

    The status is ``"busy"`` while any run on the thread is ``pending`` or
    ``running``. Otherwise it is ``"error"`` if ``exception`` is set,
    ``"interrupted"`` if the checkpoint has ``next`` entries, and ``"idle"``
    in all other cases. No auth context is accepted because this is only used
    internally.

    Raises:
        HTTPException: 404 if the thread does not exist.
    """
    tid = _ensure_uuid(thread_id)
    thread = next((t for t in conn.store["threads"] if t["thread_id"] == tid), None)
    if not thread:
        raise HTTPException(status_code=404, detail=f"Thread {tid} not found.")

    has_next = checkpoint is not None and bool(checkpoint["next"])
    status = "error" if exception else ("interrupted" if has_next else "idle")

    runs_in_flight = any(
        run["status"] in ("pending", "running") and run["thread_id"] == tid
        for run in conn.store["runs"]
    )
    if runs_in_flight:
        status = "busy"

    latest_values = checkpoint["values"] if checkpoint else None
    if checkpoint:
        interrupts = {
            task["id"]: task["interrupts"]
            for task in checkpoint["tasks"]
            if task.get("interrupts")
        }
    else:
        interrupts = {}

    thread.update(
        {
            "updated_at": datetime.now(UTC),
            "values": latest_values,
            "status": status,
            "interrupts": interrupts,
        }
    )
853
-
854
@staticmethod
async def delete(
    conn: InMemConnectionProto,
    thread_id: UUID,
    ctx: Auth.types.BaseAuthContext | None = None,
) -> AsyncIterator[UUID]:
    """Delete a thread by ID and cascade to its runs and checkpoints.

    Returns:
        An async iterator yielding the deleted thread's ID.

    Raises:
        HTTPException: 404 if the thread does not exist or the caller's auth
            filters exclude it.
    """
    thread_list = conn.store["threads"]
    thread_id = _ensure_uuid(thread_id)

    thread_idx = next(
        (idx for idx, t in enumerate(thread_list) if t["thread_id"] == thread_id),
        None,
    )
    filters = await Threads.handle_event(
        ctx,
        "delete",
        Auth.types.ThreadsDelete(thread_id=thread_id),
    )
    # Existence is checked first so the filters are matched against the
    # thread being deleted, never a stale or unbound loop variable.
    if thread_idx is None or (
        filters and not _check_filter_match(thread_list[thread_idx]["metadata"], filters)
    ):
        raise HTTPException(
            status_code=404, detail=f"Thread with ID {thread_id} not found"
        )

    # Cascade delete all runs associated with this thread.
    conn.store["runs"] = [
        run for run in conn.store["runs"] if run["thread_id"] != thread_id
    ]
    _delete_checkpoints_for_thread(thread_id, conn)

    deleted_thread = thread_list.pop(thread_idx)

    async def id_iterator() -> AsyncIterator[UUID]:
        yield deleted_thread["thread_id"]

    return id_iterator()
903
-
904
@staticmethod
async def copy(
    conn: InMemConnectionProto,
    thread_id: UUID,
    ctx: Auth.types.BaseAuthContext | None = None,
) -> AsyncIterator[Thread]:
    """Create a copy of an existing thread under a fresh UUID.

    What is copied:

    - metadata and values are deep-copied;
    - checkpoint storage and pending writes are copied with the old thread
      ID rewritten to the new one (via ``_replace_thread_id``);
    - blobs are re-keyed to the new thread but share the same values.

    The copy starts with status ``"idle"`` and an empty config.

    Returns:
        An async iterator yielding the new thread, or an empty iterator if
        the source thread does not exist or the caller may not read it.
    """
    thread_id = _ensure_uuid(thread_id)
    new_thread_id = uuid4()
    # Authorize reading the *source* thread; the freshly generated
    # new_thread_id would mean nothing to an auth handler.
    filters = await Threads.handle_event(
        ctx,
        "read",
        Auth.types.ThreadsRead(
            thread_id=thread_id,
        ),
    )
    async with conn.pipeline():
        original_thread = next(
            (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
        )
        if not original_thread:
            return _empty_generator()
        if filters and not _check_filter_match(original_thread["metadata"], filters):
            return _empty_generator()

        new_thread: Thread = {
            "thread_id": new_thread_id,
            "created_at": datetime.now(tz=UTC),
            "updated_at": datetime.now(tz=UTC),
            "metadata": deepcopy(original_thread["metadata"]),
            "status": "idle",
            "config": {},
            # Every other thread record carries "values"; copy them so the
            # new thread matches the checkpoints copied below.
            "values": deepcopy(original_thread.get("values")),
        }
        conn.store["threads"].append(new_thread)

        checkpointer = Checkpointer(conn)
        source_key = str(thread_id)
        target_key = str(new_thread_id)
        checkpointer.storage[target_key] = _replace_thread_id(
            checkpointer.storage[source_key], new_thread_id, thread_id
        )

        # Copy pending writes (if any). Collect the keys first because new
        # entries are inserted into the same dict.
        write_keys = [k for k in checkpointer.writes if k[0] == source_key]
        for _, checkpoint_ns, checkpoint_id in write_keys:
            checkpointer.writes[(target_key, checkpoint_ns, checkpoint_id)] = {
                k: _replace_thread_id(v, new_thread_id, thread_id)
                for k, v in checkpointer.writes[
                    (source_key, checkpoint_ns, checkpoint_id)
                ].items()
            }

        # Copy the blobs (re-keyed; values are shared, not rewritten).
        for k in list(checkpointer.blobs):
            if str(k[0]) == source_key:
                checkpointer.blobs[(target_key, *k[1:])] = checkpointer.blobs[k]

    async def row_generator() -> AsyncIterator[Thread]:
        yield new_thread

    return row_generator()
977
-
978
@staticmethod
async def sweep_ttl(
    conn: InMemConnectionProto,
    *,
    limit: int | None = None,
    batch_size: int = 100,
) -> tuple[int, int]:
    """Sweep threads whose TTL has expired.

    Thread TTLs are not implemented for the in-memory server, so this is a
    no-op that always returns ``(0, 0)``. The parameters exist only for
    interface compatibility.
    """
    del conn, limit, batch_size  # intentionally unused
    return 0, 0
987
-
988
- class State(Authenticated):
989
- # We will treat this like a runs resource for now.
990
- resource = "threads"
991
-
992
@staticmethod
async def get(
    conn: InMemConnectionProto,
    config: Config,
    subgraphs: bool = False,
    ctx: Auth.types.BaseAuthContext | None = None,
) -> StateSnapshot:
    """Get the state snapshot of the thread named in ``config``.

    Auth is enforced by ``Threads.get`` (404 when the thread is missing or
    hidden), so no filters are applied afterwards. An empty snapshot is
    returned when the thread's metadata has no ``graph_id``. Otherwise the
    graph computes the state, and an empty root ``checkpoint_ns`` is removed
    from the snapshot's metadata.
    """

    def _empty_snapshot() -> StateSnapshot:
        return StateSnapshot(
            values={},
            next=[],
            config=None,
            metadata=None,
            created_at=None,
            parent_config=None,
            tasks=(),
        )

    checkpointer = await asyncio.to_thread(
        Checkpointer, conn, unpack_hook=_msgpack_ext_hook_to_json
    )
    thread_id = _ensure_uuid(config["configurable"]["thread_id"])
    thread = await anext(await Threads.get(conn, thread_id, ctx=ctx))
    checkpoint = await checkpointer.aget(config)

    if not thread:
        return _empty_snapshot()

    thread_metadata = thread.get("metadata", {})
    thread_config = thread.get("config", {})
    graph_id = thread_metadata.get("graph_id")
    if not graph_id:
        return _empty_snapshot()

    # Hand the already-loaded checkpoint to the checkpointer (presumably so
    # the graph can reuse it instead of reloading).
    checkpointer.latest_iter = checkpoint
    async with get_graph(graph_id, thread_config, checkpointer=checkpointer) as graph:
        snapshot = await graph.aget_state(config, subgraphs=subgraphs)
        snapshot_meta = snapshot.metadata
        strip_root_ns = (
            snapshot_meta is not None
            and "checkpoint_ns" in snapshot_meta
            and snapshot_meta["checkpoint_ns"] == ""
        )
        if strip_root_ns:
            snapshot_meta.pop("checkpoint_ns")
        return snapshot
1047
-
1048
@staticmethod
async def post(
    conn: InMemConnectionProto,
    config: Config,
    values: Sequence[dict] | dict[str, Any] | None,
    as_node: str | None = None,
    ctx: Auth.types.BaseAuthContext | None = None,
) -> ThreadUpdateResponse:
    """Add state to a thread.

    Applies ``values`` (attributed to ``as_node``) through the thread's graph,
    then copies the thread's resulting state values onto the stored thread
    record. ``config["configurable"]`` is mutated to carry ``graph_id`` when it
    was not already set.

    Raises:
        HTTPException: 404 if the thread is missing, 403 if the auth filters do
            not match the thread metadata, 400 if the thread has no graph ID.
    """
    thread_id = _ensure_uuid(config["configurable"]["thread_id"])
    auth_filters = await Threads.handle_event(
        ctx,
        "update",
        Auth.types.ThreadsUpdate(thread_id=thread_id),
    )

    checkpointer = Checkpointer(conn)

    thread = await fetchone(
        await Threads.get(conn, thread_id, ctx=ctx),
        not_found_detail=f"Thread {thread_id} not found.",
    )
    latest_checkpoint = await checkpointer.aget(config)

    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")
    if not _check_filter_match(thread["metadata"], auth_filters):
        raise HTTPException(status_code=403, detail="Forbidden")

    thread_metadata = thread["metadata"]
    thread_config = thread["config"]

    graph_id = thread_metadata.get("graph_id")
    if not graph_id:
        raise HTTPException(status_code=400, detail="Thread has no graph ID.")

    configurable = config["configurable"]
    configurable.setdefault("graph_id", graph_id)

    checkpointer.latest_iter = latest_checkpoint
    async with get_graph(
        graph_id, thread_config, checkpointer=checkpointer
    ) as graph:
        # Shallow copy of the config with an explicit checkpoint namespace.
        update_config = {
            **config,
            "configurable": {
                **configurable,
                "checkpoint_ns": configurable.get("checkpoint_ns", ""),
            },
        }
        next_config = await graph.aupdate_state(
            update_config, values, as_node=as_node
        )

        # Re-read the state so the stored thread mirrors the new values.
        state = await Threads.State.get(conn, config, subgraphs=False, ctx=ctx)
        stored_thread = next(
            (t for t in conn.store["threads"] if t["thread_id"] == thread_id),
            None,
        )
        if stored_thread is not None:
            stored_thread["values"] = state.values

        new_configurable = next_config["configurable"]
        return ThreadUpdateResponse(
            checkpoint=new_configurable,
            # Including deprecated fields
            configurable=new_configurable,
            checkpoint_id=new_configurable["checkpoint_id"],
        )
1116
-
1117
@staticmethod
async def bulk(
    conn: InMemConnectionProto,
    *,
    config: Config,
    supersteps: Sequence[dict],
    ctx: Auth.types.BaseAuthContext | None = None,
) -> ThreadUpdateResponse:
    """Update a thread with a batch of state updates.

    Each item of ``supersteps`` is a dict whose ``"updates"`` list holds dicts
    with optional ``"command"``, ``"values"`` and ``"as_node"`` keys. Each
    update becomes a ``StateUpdate``: a command is mapped via ``map_cmd`` and
    takes precedence over ``"values"``. ``config["configurable"]`` is mutated
    to default ``graph_id`` and ``checkpoint_ns``.

    Raises:
        HTTPException: 404 if the thread is missing, 403 if the auth filters do
            not match the thread metadata, 400 if the thread has no graph ID.
    """

    from langgraph.pregel.types import StateUpdate

    thread_id = _ensure_uuid(config["configurable"]["thread_id"])
    filters = await Threads.handle_event(
        ctx,
        "update",
        Auth.types.ThreadsUpdate(thread_id=thread_id),
    )

    thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
    thread = await fetchone(
        thread_iter, not_found_detail=f"Thread {thread_id} not found."
    )

    # Check existence before dereferencing the thread record, otherwise a
    # missing thread surfaces as a TypeError instead of a 404.
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    thread_config = thread["config"]
    metadata = thread["metadata"]

    if not _check_filter_match(metadata, filters):
        raise HTTPException(status_code=403, detail="Forbidden")

    graph_id = metadata.get("graph_id")
    if not graph_id:
        raise HTTPException(status_code=400, detail="Thread has no graph ID")

    config["configurable"].setdefault("graph_id", graph_id)
    config["configurable"].setdefault("checkpoint_ns", "")

    async with get_graph(
        graph_id, thread_config, checkpointer=Checkpointer(conn)
    ) as graph:
        next_config = await graph.abulk_update_state(
            config,
            [
                [
                    StateUpdate(
                        map_cmd(update.get("command"))
                        if update.get("command")
                        else update.get("values"),
                        update.get("as_node"),
                    )
                    for update in superstep.get("updates", [])
                ]
                for superstep in supersteps
            ],
        )

        state = await Threads.State.get(conn, config, subgraphs=False, ctx=ctx)

        # update thread values
        for stored_thread in conn.store["threads"]:
            if stored_thread["thread_id"] == thread_id:
                stored_thread["values"] = state.values
                break

        return ThreadUpdateResponse(
            checkpoint=next_config["configurable"],
        )
1188
-
1189
@staticmethod
async def list(
    conn: InMemConnectionProto,
    *,
    config: Config,
    limit: int = 10,
    before: str | Checkpoint | None = None,
    metadata: MetadataInput = None,
    ctx: Auth.types.BaseAuthContext | None = None,
) -> list[StateSnapshot]:
    """Get the history of a thread.

    Args:
        conn: In-memory connection.
        config: Config whose ``configurable.thread_id`` selects the thread; it
            is also passed to ``graph.aget_state_history``.
        limit: Maximum number of snapshots to return.
        before: A checkpoint id (string) or checkpoint config to page before.
        metadata: Metadata filter forwarded as ``filter=``.
        ctx: Optional auth context.

    Returns:
        The state snapshots, or ``[]`` when the auth filters do not match the
        thread metadata or the thread has no ``graph_id``.
    """

    thread_id = _ensure_uuid(config["configurable"]["thread_id"])
    filters = await Threads.handle_event(
        ctx,
        "read",
        Auth.types.ThreadsRead(thread_id=thread_id),
    )
    # Use the normalized id, consistent with the auth event above.
    thread = await fetchone(await Threads.get(conn, thread_id, ctx=ctx))

    # Parse thread metadata and config
    thread_metadata = thread["metadata"]
    if not _check_filter_match(thread_metadata, filters):
        return []

    # Without a graph there is no history to report.
    graph_id = thread_metadata.get("graph_id")
    if not graph_id:
        return []

    thread_config = thread["config"]
    async with get_graph(
        graph_id,
        thread_config,
        checkpointer=await asyncio.to_thread(
            Checkpointer, conn, unpack_hook=_msgpack_ext_hook_to_json
        ),
    ) as graph:
        # A bare string is treated as a checkpoint id.
        before_param = (
            {"configurable": {"checkpoint_id": before}}
            if isinstance(before, str)
            else before
        )

        return [
            state
            async for state in graph.aget_state_history(
                config, limit=limit, filter=metadata, before=before_param
            )
        ]
1244
-
1245
-
1246
# Module-wide lock held by Runs.next while it selects pending runs and marks
# them "running", so concurrent pollers in this process do not claim the same
# run. Note that Runs.next yields while still holding it.
RUN_LOCK: asyncio.Lock = asyncio.Lock()
1247
-
1248
-
1249
class Runs(Authenticated):
    """In-memory run storage: queueing, lifecycle, cancellation and streaming."""

    resource = "threads"

    @staticmethod
    async def stats(conn: InMemConnectionProto) -> QueueStats:
        """Get stats about the queue.

        Returns the number of pending and running runs, plus the age in seconds
        of the oldest run and of the median run among them. Ages are ``None``
        when no run has a ``created_at``. Runs scheduled in the future (via
        ``after_seconds``) report a negative age.
        """
        now = datetime.now(UTC)
        pending_runs = [run for run in conn.store["runs"] if run["status"] == "pending"]
        running_runs = [run for run in conn.store["runs"] if run["status"] == "running"]

        if not pending_runs and not running_runs:
            return {
                "n_pending": 0,
                "max_age_secs": None,
                "med_age_secs": None,
                "n_running": 0,
            }

        # Creation timestamps in ascending order, skipping missing values.
        created_times = sorted(
            created
            for created in (
                run.get("created_at") for run in (pending_runs + running_runs)
            )
            if created is not None
        )

        if not created_times:
            return {
                "n_pending": len(pending_runs),
                "n_running": len(running_runs),
                "max_age_secs": None,
                "med_age_secs": None,
            }

        # Earliest timestamp = oldest run; report ages in seconds to match
        # the ``*_secs`` field names.
        oldest_time = created_times[0]
        median_time = created_times[len(created_times) // 2]

        return {
            "n_pending": len(pending_runs),
            "n_running": len(running_runs),
            "max_age_secs": (now - oldest_time).total_seconds(),
            "med_age_secs": (now - median_time).total_seconds(),
        }

    @staticmethod
    async def next(wait: bool, limit: int = 1) -> AsyncIterator[tuple[Run, int]]:
        """Get the next run from the queue, and the attempt number.
        1 is the first attempt, 2 is the first retry, etc.

        Considers up to ``limit`` pending runs created before this call, oldest
        first, and skips runs whose thread is missing or already has a running
        run. Claimed runs are marked "running" while ``RUN_LOCK`` is held.
        """
        now = datetime.now(UTC)

        if wait:
            await asyncio.sleep(0.5)
        else:
            await asyncio.sleep(0)

        async with connect() as conn, RUN_LOCK:
            pending_runs = sorted(
                [
                    run
                    for run in conn.store["runs"]
                    if run["status"] == "pending" and run.get("created_at", now) < now
                ],
                key=lambda x: x.get("created_at", datetime.min),
            )

            if not pending_runs:
                return

            # Try to lock and get the first available run
            for _, run in zip(range(limit), pending_runs, strict=False):
                # The status may have changed while a previously yielded run
                # was being handled by the consumer; re-check before claiming.
                if run["status"] != "pending":
                    continue

                run_id = run["run_id"]
                thread_id = run["thread_id"]
                thread = next(
                    (t for t in conn.store["threads"] if t["thread_id"] == thread_id),
                    None,
                )

                if thread is None:
                    await logger.awarning(
                        "Unexpected missing thread in Runs.next",
                        thread_id=run["thread_id"],
                    )
                    continue

                # Only one run per thread may be running at a time.
                if any(
                    other["status"] == "running"
                    for other in conn.store["runs"]
                    if other["thread_id"] == thread_id
                ):
                    continue
                # Increment attempt counter
                attempt = await conn.retry_counter.increment(run_id)
                # Set run as "running"
                run["status"] = "running"
                yield run, attempt

    @staticmethod
    @asynccontextmanager
    async def enter(
        run_id: UUID, loop: asyncio.AbstractEventLoop
    ) -> AsyncIterator[ValueEvent]:
        """Enter a run, listen for cancellation while running, signal when done.

        This method should be called as a context manager by a worker executing
        a run. When the body exits normally, a ``b"done"`` control message is
        broadcast, stored for late subscribers, and this run's queue is removed.
        ``loop`` is accepted but not used here.

        NOTE(review): the cleanup after ``yield`` is not in a ``finally``, so
        it is skipped if the body raises. Confirm whether that is intended.
        """
        stream_manager = get_stream_manager()
        # Get queue for this run
        queue = await Runs.Stream.subscribe(run_id)

        async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
            done = ValueEvent()
            tg.create_task(listen_for_cancellation(queue, run_id, done))

            # Give done event to caller
            yield done
            # Signal done to all subscribers
            control_message = Message(
                topic=f"run:{run_id}:control".encode(), data=b"done"
            )

            # Store the control message for late subscribers
            await stream_manager.put(run_id, control_message)
            stream_manager.control_queues[run_id].append(control_message)
            # Clean up this queue
            await stream_manager.remove_queue(run_id, queue)

    @staticmethod
    async def sweep(conn: InMemConnectionProto) -> list[UUID]:
        """Sweep runs that are no longer running"""
        return []

    @staticmethod
    def _merge_jsonb(*objects: dict) -> dict:
        """Mimics PostgreSQL's JSONB merge behavior"""
        result = {}
        for obj in objects:
            if obj is not None:
                result.update(copy.deepcopy(obj))
        return result

    @staticmethod
    def _get_configurable(config: dict) -> dict:
        """Extract configurable from config, mimicking PostgreSQL's coalesce"""
        return config.get("configurable", {})

    @staticmethod
    async def put(
        conn: InMemConnectionProto,
        assistant_id: UUID,
        kwargs: dict,
        *,
        thread_id: UUID | None = None,
        user_id: str | None = None,
        run_id: UUID | None = None,
        status: RunStatus | None = "pending",
        metadata: MetadataInput,
        prevent_insert_if_inflight: bool,
        multitask_strategy: MultitaskStrategy = "reject",
        if_not_exists: IfNotExists = "reject",
        after_seconds: int = 0,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Run]:
        """Create a run.

        Creates the thread when it does not exist and either ``thread_id`` is
        None or ``if_not_exists == "create"``; otherwise marks the existing
        thread busy and merges the assistant/thread/run config into it.

        Returns an async iterator that yields:
        - nothing, if the assistant is unknown, the auth filters reject the
          existing thread, or the thread is missing and must not be created;
        - the in-flight runs, if ``prevent_insert_if_inflight`` is set and any
          exist for the thread;
        - otherwise, the new run followed by any in-flight runs.
        """
        assistant_id = _ensure_uuid(assistant_id)
        assistant = next(
            (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
            None,
        )

        if not assistant:
            return _empty_generator()

        thread_id = _ensure_uuid(thread_id) if thread_id else None
        run_id = _ensure_uuid(run_id) if run_id else None
        metadata = metadata if metadata is not None else {}
        config = kwargs.get("config", {})

        # Handle thread creation/update
        existing_thread = next(
            (t for t in conn.store["threads"] if t["thread_id"] == thread_id), None
        )
        filters = await Runs.handle_event(
            ctx,
            "create_run",
            Auth.types.RunsCreate(
                thread_id=thread_id,
                assistant_id=assistant_id,
                run_id=run_id,
                status=status,
                metadata=metadata,
                prevent_insert_if_inflight=prevent_insert_if_inflight,
                multitask_strategy=multitask_strategy,
                if_not_exists=if_not_exists,
                after_seconds=after_seconds,
                kwargs=kwargs,
            ),
        )
        if existing_thread and filters:
            # Reject if the user doesn't own the thread
            if not _check_filter_match(existing_thread["metadata"], filters):
                return _empty_generator()

        if not existing_thread and (thread_id is None or if_not_exists == "create"):
            # Create new thread
            if thread_id is None:
                thread_id = uuid4()
            thread = Thread(
                thread_id=thread_id,
                status="busy",
                metadata={
                    "graph_id": assistant["graph_id"],
                    "assistant_id": str(assistant_id),
                    **metadata,
                },
                config=Runs._merge_jsonb(
                    assistant["config"],
                    config,
                    {
                        "configurable": Runs._merge_jsonb(
                            Runs._get_configurable(assistant["config"]),
                            Runs._get_configurable(config),
                        )
                    },
                ),
                created_at=datetime.now(UTC),
                updated_at=datetime.now(UTC),
                values=b"",
            )
            await logger.ainfo("Creating thread", thread_id=thread_id)
            conn.store["threads"].append(thread)
        elif existing_thread:
            # Update existing thread
            if existing_thread["status"] != "busy":
                existing_thread["status"] = "busy"
                existing_thread["metadata"] = Runs._merge_jsonb(
                    existing_thread["metadata"],
                    {
                        "graph_id": assistant["graph_id"],
                        "assistant_id": str(assistant_id),
                    },
                )
                existing_thread["config"] = Runs._merge_jsonb(
                    assistant["config"],
                    existing_thread["config"],
                    config,
                    {
                        "configurable": Runs._merge_jsonb(
                            Runs._get_configurable(assistant["config"]),
                            Runs._get_configurable(existing_thread["config"]),
                            Runs._get_configurable(config),
                        )
                    },
                )
                existing_thread["updated_at"] = datetime.now(UTC)
        else:
            return _empty_generator()

        # Check for inflight runs if needed
        inflight_runs = [
            r
            for r in conn.store["runs"]
            if r["thread_id"] == thread_id and r["status"] in ("pending", "running")
        ]
        if prevent_insert_if_inflight and inflight_runs:

            async def _return_inflight():
                for run in inflight_runs:
                    yield run

            return _return_inflight()

        # A run must be addressable by id (get/cancel/delete); generate one if
        # the caller did not supply it instead of storing run_id=None.
        if run_id is None:
            run_id = uuid4()

        # Create new run
        configurable = Runs._merge_jsonb(
            Runs._get_configurable(assistant["config"]),
            (
                Runs._get_configurable(existing_thread["config"])
                if existing_thread
                else {}
            ),
            Runs._get_configurable(config),
            {
                "run_id": str(run_id),
                "thread_id": str(thread_id),
                "graph_id": assistant["graph_id"],
                "assistant_id": str(assistant_id),
                "user_id": (
                    config.get("configurable", {}).get("user_id")
                    or (
                        existing_thread["config"].get("configurable", {}).get("user_id")
                        if existing_thread
                        else None
                    )
                    or assistant["config"].get("configurable", {}).get("user_id")
                    or user_id
                ),
            },
        )
        merged_metadata = Runs._merge_jsonb(
            assistant["metadata"],
            existing_thread["metadata"] if existing_thread else {},
            metadata,
        )
        new_run = Run(
            run_id=run_id,
            thread_id=thread_id,
            assistant_id=assistant_id,
            metadata=merged_metadata,
            status=status,
            kwargs=Runs._merge_jsonb(
                kwargs,
                {
                    "config": Runs._merge_jsonb(
                        assistant["config"],
                        config,
                        {"configurable": configurable},
                        {
                            "metadata": merged_metadata,
                        },
                    )
                },
            ),
            multitask_strategy=multitask_strategy,
            created_at=datetime.now(UTC) + timedelta(seconds=after_seconds),
            updated_at=datetime.now(UTC),
        )
        conn.store["runs"].append(new_run)

        async def _yield_new():
            yield new_run
            for r in inflight_runs:
                yield r

        return _yield_new()

    @staticmethod
    async def get(
        conn: InMemConnectionProto,
        run_id: UUID,
        *,
        thread_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Run]:
        """Get a run by ID.

        Returns an async iterator yielding the matching run, or nothing if it
        does not exist on ``thread_id`` or the auth filters reject the thread.
        """

        run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
        filters = await Runs.handle_event(
            ctx,
            "read",
            Auth.types.ThreadsRead(thread_id=thread_id),
        )

        async def _yield_result():
            matching_run = None
            for run in conn.store["runs"]:
                if run["run_id"] == run_id and run["thread_id"] == thread_id:
                    matching_run = run
                    break
            if matching_run:
                if filters:
                    thread = await Threads._get_with_filters(
                        conn, matching_run["thread_id"], filters
                    )
                    if not thread:
                        return
                yield matching_run

        return _yield_result()

    @staticmethod
    async def delete(
        conn: InMemConnectionProto,
        run_id: UUID,
        *,
        thread_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[UUID]:
        """Delete a run by ID, along with the thread checkpoints tagged with it.

        Raises:
            HTTPException: 404 if no run with this id exists on the thread.
        """
        run_id, thread_id = _ensure_uuid(run_id), _ensure_uuid(thread_id)
        filters = await Runs.handle_event(
            ctx,
            "delete",
            Auth.types.ThreadsDelete(run_id=run_id, thread_id=thread_id),
        )

        if filters:
            thread = await Threads._get_with_filters(conn, thread_id, filters)
            if not thread:
                return _empty_generator()
        _delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
        found = False
        for i, run in enumerate(conn.store["runs"]):
            if run["run_id"] == run_id and run["thread_id"] == thread_id:
                del conn.store["runs"][i]
                found = True
                break
        if not found:
            raise HTTPException(status_code=404, detail="Run not found")

        async def _yield_deleted():
            await logger.ainfo("Run deleted", run_id=run_id)
            yield run_id

        return _yield_deleted()

    @staticmethod
    async def join(
        run_id: UUID,
        *,
        thread_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> Fragment:
        """Wait for a run to complete. If already done, return immediately.

        Returns:
            the final state of the run.
        """
        async with connect() as conn:
            # Validate ownership
            thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
            await fetchone(thread_iter)
        last_chunk: bytes | None = None
        # wait for the run to complete
        # Rely on this join's auth
        async for mode, chunk in Runs.Stream.join(
            run_id, thread_id=thread_id, stream_mode="values", ctx=ctx, ignore_404=True
        ):
            if mode == b"values":
                last_chunk = chunk
        # if we received a final chunk, return it
        if last_chunk is not None:
            # ie. if the run completed while we were waiting for it
            return Fragment(last_chunk)
        else:
            # otherwise, the run had already finished, so fetch the state from thread
            async with connect() as conn:
                thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
                thread = await fetchone(thread_iter)
                return thread["values"]

    @staticmethod
    async def cancel(
        conn: InMemConnectionProto,
        run_ids: Sequence[UUID] | None = None,
        *,
        action: Literal["interrupt", "rollback"] = "interrupt",
        thread_id: UUID | None = None,
        status: Literal["pending", "running", "all"] | None = None,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> None:
        """
        Cancel runs in memory. Must provide either:
        1) thread_id + run_ids, or
        2) status in {"pending", "running", "all"}.

        Steps:
        - Validate arguments (one usage pattern or the other).
        - Auth check: 'update' event via handle_event().
        - Gather runs matching either the (thread_id, run_ids) set or the given status.
        - If auth filters apply, keep only runs whose own thread passes them.
        - For each run found:
            * Send a cancellation message through the stream manager.
            * If 'pending' or 'running': mark 'interrupted', or, for
              action='rollback' with no subscribed queues, delete the run.
            * Otherwise, log a warning for non-cancelable states.
        - 404 if no runs are found or authorized.
        """
        # 1. Validate arguments
        if status is not None:
            # If status is set, user must NOT specify thread_id or run_ids
            if thread_id is not None or run_ids is not None:
                raise HTTPException(
                    status_code=422,
                    detail="Cannot specify 'thread_id' or 'run_ids' when using 'status'",
                )
        else:
            # If status is not set, user must specify both thread_id and run_ids
            if thread_id is None or run_ids is None:
                raise HTTPException(
                    status_code=422,
                    detail="Must provide either a status or both 'thread_id' and 'run_ids'",
                )

        # Convert and normalize inputs
        if run_ids is not None:
            run_ids = [_ensure_uuid(rid) for rid in run_ids]
        if thread_id is not None:
            thread_id = _ensure_uuid(thread_id)

        filters = await Runs.handle_event(
            ctx,
            "update",
            Auth.types.ThreadsUpdate(
                thread_id=thread_id,  # type: ignore
                action=action,
                metadata={
                    "run_ids": run_ids,
                    "status": status,
                },
            ),
        )

        status_list: tuple[str, ...] = ()
        if status is not None:
            if status == "all":
                status_list = ("pending", "running")
            elif status in ("pending", "running"):
                status_list = (status,)
            else:
                raise ValueError(f"Unsupported status: {status}")

        def is_run_match(r: dict) -> bool:
            """
            Check whether a run in `conn.store["runs"]` meets the selection criteria.
            """
            if status_list:
                return r["status"] in status_list
            else:
                return r["thread_id"] == thread_id and r["run_id"] in run_ids  # type: ignore

        candidate_runs = [r for r in conn.store["runs"] if is_run_match(r)]

        if filters:
            # Authorize every run against its own thread. Status-based
            # cancellation has no thread_id, so a single thread check would
            # otherwise let it reach runs on threads the caller cannot update.
            thread_allowed: dict[UUID, bool] = {}
            authorized_runs = []
            for r in candidate_runs:
                run_thread_id = r["thread_id"]
                if run_thread_id not in thread_allowed:
                    thread_allowed[run_thread_id] = bool(
                        await Threads._get_with_filters(conn, run_thread_id, filters)
                    )
                if thread_allowed[run_thread_id]:
                    authorized_runs.append(r)
            candidate_runs = authorized_runs

        if not candidate_runs:
            raise HTTPException(status_code=404, detail="No runs found to cancel.")

        stream_manager = get_stream_manager()
        coros = []
        for run in candidate_runs:
            run_id = run["run_id"]
            control_message = Message(
                topic=f"run:{run_id}:control".encode(),
                data=action.encode(),
            )
            coros.append(stream_manager.put(run_id, control_message))

            queues = stream_manager.get_queues(run_id)

            if run["status"] in ("pending", "running"):
                if queues or action != "rollback":
                    run["status"] = "interrupted"
                    run["updated_at"] = datetime.now(tz=UTC)
                else:
                    await logger.ainfo(
                        "Eagerly deleting pending run with rollback action",
                        run_id=str(run_id),
                        status=run["status"],
                    )
                    coros.append(Runs.delete(conn, run_id, thread_id=run["thread_id"]))
            else:
                await logger.awarning(
                    "Attempted to cancel non-pending run.",
                    run_id=str(run_id),
                    status=run["status"],
                )

        if coros:
            await asyncio.gather(*coros)

        await logger.ainfo(
            "Cancelled runs",
            run_ids=[str(r["run_id"]) for r in candidate_runs],
            thread_id=str(thread_id) if thread_id else None,
            status=status,
            action=action,
        )

    @staticmethod
    async def search(
        conn: InMemConnectionProto,
        thread_id: UUID,
        *,
        limit: int = 10,
        offset: int = 0,
        metadata: MetadataInput,
        status: RunStatus | None = None,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Run]:
        """List all runs by thread.

        Filters by metadata containment and optional status, sorts newest
        first by ``created_at``, and yields the ``offset``/``limit`` page.
        """
        runs = conn.store["runs"]
        metadata = metadata if metadata is not None else {}
        thread_id = _ensure_uuid(thread_id)
        filters = await Runs.handle_event(
            ctx,
            "search",
            Auth.types.ThreadsSearch(thread_id=thread_id, metadata=metadata),
        )
        filtered_runs = [
            run
            for run in runs
            if run["thread_id"] == thread_id
            and is_jsonb_contained(run["metadata"], metadata)
            and (status is None or run["status"] == status)
        ]
        # Every candidate belongs to the same thread, so the auth check is
        # done once rather than once per run.
        if (
            filters
            and filtered_runs
            and not (await Threads._get_with_filters(conn, thread_id, filters))
        ):
            filtered_runs = []
        sorted_runs = sorted(filtered_runs, key=lambda x: x["created_at"], reverse=True)
        sliced_runs = sorted_runs[offset : offset + limit]

        async def _return():
            for run in sliced_runs:
                yield run

        return _return()

    @staticmethod
    async def set_status(
        conn: InMemConnectionProto, run_id: UUID, status: RunStatus
    ) -> None:
        """Set the status of a run (returns the updated run, or None if absent)."""
        # Find the run in the store
        run_id = _ensure_uuid(run_id)
        run = next((run for run in conn.store["runs"] if run["run_id"] == run_id), None)

        if run:
            # Update the status and updated_at timestamp
            run["status"] = status
            run["updated_at"] = datetime.now(tz=UTC)
            return run
        return None

    class Stream:
        @staticmethod
        async def subscribe(
            run_id: UUID,
            *,
            stream_mode: "StreamMode | None" = None,
        ) -> asyncio.Queue:
            """Subscribe to the run stream, returning a queue.

            Control messages already stored for the run are replayed into the
            new queue.
            """
            # Normalize once so the queue registration and the control-message
            # lookup use the same key even when run_id arrives as a string.
            run_id = _ensure_uuid(run_id)
            stream_manager = get_stream_manager()
            queue = await stream_manager.add_queue(run_id)

            # If there's a control message already stored, send it to the new subscriber
            if control_messages := stream_manager.control_queues.get(run_id):
                for control_msg in control_messages:
                    await queue.put(control_msg)
            return queue

        @staticmethod
        async def join(
            run_id: UUID,
            *,
            thread_id: UUID,
            ignore_404: bool = False,
            cancel_on_disconnect: bool = False,
            stream_mode: "StreamMode | asyncio.Queue | None" = None,
            ctx: Auth.types.BaseAuthContext | None = None,
        ) -> AsyncIterator[tuple[bytes, bytes]]:
            """Stream the run output.

            Yields ``(mode, data)`` pairs until a ``b"done"`` control message
            arrives. Every 0.5s without messages, the run is re-fetched; the
            stream ends if the run is gone (yielding a 404 error unless
            ``ignore_404``) or no longer pending/running.
            """
            log = logger.isEnabledFor(logging.DEBUG)
            queue = (
                stream_mode
                if isinstance(stream_mode, asyncio.Queue)
                else await Runs.Stream.subscribe(run_id, stream_mode=stream_mode)
            )

            try:
                async with connect() as conn:
                    filters = await Runs.handle_event(
                        ctx,
                        "read",
                        Auth.types.ThreadsRead(thread_id=thread_id),
                    )
                    if filters:
                        thread = await Threads._get_with_filters(
                            cast(InMemConnectionProto, conn), thread_id, filters
                        )
                        if not thread:
                            raise WrappedHTTPException(
                                HTTPException(
                                    status_code=404, detail="Thread not found"
                                )
                            )
                    channel_prefix = f"run:{run_id}:stream:"
                    len_prefix = len(channel_prefix.encode())

                    while True:
                        try:
                            # Wait for messages with a timeout
                            message = await asyncio.wait_for(queue.get(), timeout=0.5)
                            topic, data = message.topic, message.data

                            if topic.decode() == f"run:{run_id}:control":
                                if data == b"done":
                                    break
                            else:
                                # Extract mode from topic
                                yield topic[len_prefix:], data
                                if log:
                                    await logger.adebug(
                                        "Streamed run event",
                                        run_id=str(run_id),
                                        stream_mode=topic[len_prefix:],
                                        data=data,
                                    )
                        except TimeoutError:
                            # Check if the run is still pending
                            run_iter = await Runs.get(
                                conn, run_id, thread_id=thread_id, ctx=ctx
                            )
                            run = await anext(run_iter, None)

                            if ignore_404 and run is None:
                                break
                            elif run is None:
                                yield (
                                    b"error",
                                    HTTPException(
                                        status_code=404, detail="Run not found"
                                    ),
                                )
                                break
                            elif run["status"] not in ("pending", "running"):
                                break
            except WrappedHTTPException as e:
                raise e.http_exception from None
            except BaseException:
                # Deliberately broad (covers client disconnect / cancellation).
                if cancel_on_disconnect:
                    create_task(cancel_run(thread_id, run_id))
                raise
            finally:
                stream_manager = get_stream_manager()
                await stream_manager.remove_queue(run_id, queue)

        @staticmethod
        async def publish(
            run_id: UUID,
            event: str,
            message: bytes,
        ) -> None:
            """Publish a message to all subscribers of the run stream."""
            topic = f"run:{run_id}:stream:{event}".encode()

            stream_manager = get_stream_manager()
            # Send to all queues subscribed to this run_id
            await stream_manager.put(run_id, Message(topic=topic, data=message))
2005
-
2006
-
2007
async def listen_for_cancellation(
    queue: asyncio.Queue, run_id: UUID, done: "ValueEvent"
):
    """Listen for cancellation messages and set the done event accordingly.

    First replays control messages already stored for ``run_id`` in the stream
    manager. It then consumes ``queue`` until ``done`` is set:

    - ``b"rollback"`` sets ``done`` with ``UserRollback()``.
    - ``b"interrupt"`` sets ``done`` with ``UserInterrupt()``.
    - ``b"done"`` sets ``done`` with no value and stops.

    Messages on this run's control topic are appended to ``control_queues`` so
    late subscribers can replay them.

    Started as a task by ``Runs.enter``, whose task group cancels it on exit.
    """
    stream_manager = get_stream_manager()
    control_key = f"run:{run_id}:control"

    if existing_queue := stream_manager.control_queues.get(run_id):
        for message in existing_queue:
            payload = message.data
            if payload == b"rollback":
                done.set(UserRollback())
            elif payload == b"interrupt":
                done.set(UserInterrupt())

    while not done.is_set():
        try:
            # The timeout only bounds how long we go between re-checks of
            # ``done``; this task is cancelled when Runs.enter exits.
            message = await asyncio.wait_for(queue.get(), timeout=240)
        except TimeoutError:
            # Keep listening: stopping here would silently ignore any
            # cancellation requested after a long quiet period of the run.
            continue

        payload = message.data
        if payload == b"rollback":
            done.set(UserRollback())
        elif payload == b"interrupt":
            done.set(UserInterrupt())
        elif payload == b"done":
            done.set()
            break

        # Store control messages for late subscribers
        if message.topic.decode() == control_key:
            stream_manager.control_queues[run_id].append(message)
2041
-
2042
-
2043
class Crons:
    """Cron operations for the in-memory server, which does not support crons.

    Every method raises ``NotImplementedError``.
    """

    @staticmethod
    async def put(
        conn: InMemConnectionProto,
        *,
        payload: dict,
        schedule: str,
        cron_id: UUID | None = None,
        thread_id: UUID | None = None,
        end_time: datetime | None = None,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Cron]:
        """Create a cron (unsupported; always raises NotImplementedError)."""
        raise NotImplementedError()

    @staticmethod
    async def delete(
        conn: InMemConnectionProto,
        cron_id: UUID,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[UUID]:
        """Delete a cron (unsupported; always raises NotImplementedError)."""
        raise NotImplementedError()

    @staticmethod
    async def next(
        conn: InMemConnectionProto,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Cron]:
        """Async generator stub for due crons.

        The ``yield`` makes this an async generator. Its first step produces
        ``None``; advancing again raises ``NotImplementedError``.
        """
        yield None
        raise NotImplementedError("The in-mem server does not implement Crons.")

    @staticmethod
    async def set_next_run_date(
        conn: InMemConnectionProto,
        cron_id: UUID,
        next_run_date: datetime,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> None:
        """Reschedule a cron (unsupported; always raises NotImplementedError)."""
        raise NotImplementedError()

    @staticmethod
    async def search(
        conn: InMemConnectionProto,
        *,
        assistant_id: UUID | None,
        thread_id: UUID | None,
        limit: int,
        offset: int,
        ctx: Auth.types.BaseAuthContext | None = None,
    ) -> AsyncIterator[Cron]:
        """Search crons (unsupported; always raises NotImplementedError)."""
        raise NotImplementedError()
2093
-
2094
-
2095
async def cancel_run(
    thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
) -> None:
    """Cancel a single run that belongs to ``thread_id``.

    Opens a connection with ``connect()`` and delegates to ``Runs.cancel``.
    It passes a one-element list of run ids and forwards the optional auth
    context.
    """
    async with connect() as conn:
        target_runs = [run_id]
        await Runs.cancel(conn, target_runs, thread_id=thread_id, ctx=ctx)
2100
-
2101
-
2102
def _delete_checkpoints_for_thread(
    thread_id: str | UUID,
    conn: InMemConnectionProto,
    run_id: str | UUID | None = None,
):
    """Delete stored checkpoints, and their pending writes, for a thread.

    Args:
        thread_id: Thread whose checkpoints should be removed.
        conn: In-memory connection used to obtain the checkpointer.
        run_id: If given, delete only checkpoints whose metadata ``run_id``
            equals ``str(run_id)``. Otherwise delete every checkpoint of the
            thread.

    Pending writes are keyed by ``(thread_id, checkpoint_ns, checkpoint_id)``.
    Writes belonging to a deleted checkpoint are removed as well, so none are
    left orphaned. ``checkpointer.writes`` is mutated in place rather than
    rebound, so the checkpointer keeps the same container object.
    """
    checkpointer = Checkpointer(conn)
    thread_id = str(thread_id)
    if thread_id not in checkpointer.storage:
        return

    if run_id is not None:
        # Selective delete: look through each checkpoint's metadata.
        run_id = str(run_id)
        thread_storage = checkpointer.storage[thread_id]
        deleted: set[tuple[str, str]] = set()
        for checkpoint_ns, checkpoints in list(thread_storage.items()):
            for checkpoint_id, (_, metadata_b, _) in list(checkpoints.items()):
                metadata = checkpointer.serde.loads_typed(metadata_b)
                if metadata.get("run_id") == run_id:
                    del thread_storage[checkpoint_ns][checkpoint_id]
                    deleted.add((checkpoint_ns, checkpoint_id))
            # Drop namespaces that no longer hold any checkpoint.
            if not thread_storage[checkpoint_ns]:
                del thread_storage[checkpoint_ns]
        stale_writes = [
            key
            for key in checkpointer.writes
            if key[0] == thread_id and (key[1], key[2]) in deleted
        ]
    else:
        del checkpointer.storage[thread_id]
        stale_writes = [key for key in checkpointer.writes if key[0] == thread_id]

    # Delete from a snapshot of keys; mutating while iterating the dict would fail.
    for key in stale_writes:
        del checkpointer.writes[key]
2127
-
2128
-
2129
def _check_filter_match(
    metadata: dict, filters: "Auth.types.FilterType | None"
) -> bool:
    """Check whether ``metadata`` satisfies every condition in ``filters``.

    Each filter entry maps a metadata key to either a plain value, meaning
    direct equality, or an operator dict. Supported operators:

    * ``{"$eq": v}``: ``metadata[key] == v``
    * ``{"$contains": v}``: ``metadata[key]`` is a list containing ``v``

    All operators in one operator dict must hold (logical AND), and an empty
    operator dict imposes no constraint. An unrecognized operator never
    matches. A typo in an authorization filter therefore hides data instead
    of exposing it.

    Args:
        metadata: The metadata to check.
        filters: The filter conditions to apply. ``None`` or empty matches
            everything.

    Returns:
        True if the metadata matches all filter conditions, False otherwise.
    """
    if not filters:
        return True

    for key, condition in filters.items():
        # A bare value is shorthand for {"$eq": value}.
        operators = condition if isinstance(condition, dict) else {"$eq": condition}
        for op, expected in operators.items():
            if key not in metadata:
                return False
            actual = metadata[key]
            if op == "$eq":
                if actual != expected:
                    return False
            elif op == "$contains":
                if not isinstance(actual, list) or expected not in actual:
                    return False
            else:
                # Unknown operator: fail closed rather than silently ignoring it.
                return False

    return True
2163
-
2164
-
2165
async def _empty_generator():
    """Async generator that finishes immediately without yielding anything."""
    return
    yield  # unreachable; its presence makes this an async generator function
2168
-
2169
-
2170
# Names exported by `from <this module> import *`: the storage operation namespaces.
__all__ = ["Assistants", "Crons", "Runs", "Threads"]