langgraph-runtime-inmem 0.6.4__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,7 +11,6 @@ import uuid
11
11
  from collections import defaultdict
12
12
  from collections.abc import AsyncIterator, Sequence
13
13
  from contextlib import asynccontextmanager
14
- from copy import deepcopy
15
14
  from datetime import UTC, datetime, timedelta
16
15
  from typing import Any, Literal, cast
17
16
  from uuid import UUID, uuid4
@@ -27,17 +26,24 @@ from starlette.exceptions import HTTPException
27
26
 
28
27
  from langgraph_runtime_inmem.checkpoint import Checkpointer
29
28
  from langgraph_runtime_inmem.database import InMemConnectionProto, connect
30
- from langgraph_runtime_inmem.inmem_stream import Message, get_stream_manager
29
+ from langgraph_runtime_inmem.inmem_stream import (
30
+ THREADLESS_KEY,
31
+ ContextQueue,
32
+ Message,
33
+ get_stream_manager,
34
+ )
31
35
 
32
36
  if typing.TYPE_CHECKING:
33
37
  from langgraph_api.asyncio import ValueEvent
34
38
  from langgraph_api.config import ThreadTTLConfig
35
39
  from langgraph_api.schema import (
36
40
  Assistant,
41
+ AssistantSelectField,
37
42
  Checkpoint,
38
43
  Config,
39
44
  Context,
40
45
  Cron,
46
+ CronSelectField,
41
47
  DeprecatedInterrupt,
42
48
  IfNotExists,
43
49
  MetadataInput,
@@ -46,15 +52,19 @@ if typing.TYPE_CHECKING:
46
52
  OnConflictBehavior,
47
53
  QueueStats,
48
54
  Run,
55
+ RunSelectField,
49
56
  RunStatus,
50
57
  StreamMode,
51
58
  Thread,
59
+ ThreadSelectField,
52
60
  ThreadStatus,
61
+ ThreadStreamMode,
53
62
  ThreadUpdateResponse,
54
63
  )
55
64
  from langgraph_api.schema import Interrupt as InterruptSchema
56
- from langgraph_api.serde import Fragment
65
+ from langgraph_api.utils import AsyncConnectionProto
57
66
 
67
+ StreamHandler = ContextQueue
58
68
 
59
69
  logger = structlog.stdlib.get_logger(__name__)
60
70
 
@@ -136,6 +146,7 @@ class Assistants(Authenticated):
136
146
  offset: int,
137
147
  sort_by: str | None = None,
138
148
  sort_order: str | None = None,
149
+ select: list[AssistantSelectField] | None = None,
139
150
  ctx: Auth.types.BaseAuthContext | None = None,
140
151
  ) -> tuple[AsyncIterator[Assistant], int]:
141
152
  metadata = metadata if metadata is not None else {}
@@ -157,9 +168,6 @@ class Assistants(Authenticated):
157
168
  and (not filters or _check_filter_match(assistant["metadata"], filters))
158
169
  ]
159
170
 
160
- # Get total count before sorting and pagination
161
- total_count = len(filtered_assistants)
162
-
163
171
  # Sort based on sort_by and sort_order
164
172
  sort_by = sort_by.lower() if sort_by else None
165
173
  if sort_by and sort_by in (
@@ -181,22 +189,31 @@ class Assistants(Authenticated):
181
189
  else:
182
190
  filtered_assistants.sort(key=lambda x: x.get(sort_by), reverse=reverse)
183
191
  else:
192
+ sort_by = "created_at"
184
193
  # Default sorting by created_at in descending order
185
194
  filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
186
195
 
187
196
  # Apply pagination
188
197
  paginated_assistants = filtered_assistants[offset : offset + limit]
198
+ cur = offset + limit if len(filtered_assistants) > offset + limit else None
189
199
 
190
200
  async def assistant_iterator() -> AsyncIterator[Assistant]:
191
201
  for assistant in paginated_assistants:
192
- yield assistant
202
+ if select:
203
+ # Filter to only selected fields
204
+ filtered_assistant = {
205
+ k: v for k, v in assistant.items() if k in select
206
+ }
207
+ yield filtered_assistant
208
+ else:
209
+ yield assistant
193
210
 
194
- return assistant_iterator(), total_count
211
+ return assistant_iterator(), cur
195
212
 
196
213
  @staticmethod
197
214
  async def get(
198
215
  conn: InMemConnectionProto,
199
- assistant_id: UUID,
216
+ assistant_id: UUID | str,
200
217
  ctx: Auth.types.BaseAuthContext | None = None,
201
218
  ) -> AsyncIterator[Assistant]:
202
219
  """Get an assistant by ID."""
@@ -212,14 +229,14 @@ class Assistants(Authenticated):
212
229
  if assistant["assistant_id"] == assistant_id and (
213
230
  not filters or _check_filter_match(assistant["metadata"], filters)
214
231
  ):
215
- yield assistant
232
+ yield copy.deepcopy(assistant)
216
233
 
217
234
  return _yield_result()
218
235
 
219
236
  @staticmethod
220
237
  async def put(
221
238
  conn: InMemConnectionProto,
222
- assistant_id: UUID,
239
+ assistant_id: UUID | str,
223
240
  *,
224
241
  graph_id: str,
225
242
  config: Config,
@@ -231,6 +248,8 @@ class Assistants(Authenticated):
231
248
  description: str | None = None,
232
249
  ) -> AsyncIterator[Assistant]:
233
250
  """Insert an assistant."""
251
+ from langgraph_api.graph import GRAPHS
252
+
234
253
  assistant_id = _ensure_uuid(assistant_id)
235
254
  metadata = metadata if metadata is not None else {}
236
255
  filters = await Assistants.handle_event(
@@ -245,6 +264,22 @@ class Assistants(Authenticated):
245
264
  name=name,
246
265
  ),
247
266
  )
267
+
268
+ if config.get("configurable") and context:
269
+ raise HTTPException(
270
+ status_code=400,
271
+ detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
272
+ )
273
+
274
+ if graph_id not in GRAPHS:
275
+ raise HTTPException(status_code=404, detail=f"Graph {graph_id} not found")
276
+
277
+ # Keep config and context up to date with one another
278
+ if config.get("configurable"):
279
+ context = config["configurable"]
280
+ elif context:
281
+ config["configurable"] = context
282
+
248
283
  existing_assistant = next(
249
284
  (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
250
285
  None,
@@ -330,6 +365,7 @@ class Assistants(Authenticated):
330
365
  """
331
366
  assistant_id = _ensure_uuid(assistant_id)
332
367
  metadata = metadata if metadata is not None else {}
368
+ config = config if config is not None else {}
333
369
  filters = await Assistants.handle_event(
334
370
  ctx,
335
371
  "update",
@@ -342,6 +378,19 @@ class Assistants(Authenticated):
342
378
  name=name,
343
379
  ),
344
380
  )
381
+
382
+ if config.get("configurable") and context:
383
+ raise HTTPException(
384
+ status_code=400,
385
+ detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
386
+ )
387
+
388
+ # Keep config and context up to date with one another
389
+ if config.get("configurable"):
390
+ context = config["configurable"]
391
+ elif context:
392
+ config["configurable"] = context
393
+
345
394
  assistant = next(
346
395
  (a for a in conn.store["assistants"] if a["assistant_id"] == assistant_id),
347
396
  None,
@@ -367,19 +416,17 @@ class Assistants(Authenticated):
367
416
  else 1
368
417
  )
369
418
 
370
- # Update assistant_versions table
371
- if metadata:
372
- metadata = {
373
- **assistant["metadata"],
374
- **metadata,
375
- }
376
419
  new_version_entry = {
377
420
  "assistant_id": assistant_id,
378
421
  "version": new_version,
379
422
  "graph_id": graph_id if graph_id is not None else assistant["graph_id"],
380
- "config": config if config is not None else assistant["config"],
381
- "context": context if context is not None else assistant["context"],
382
- "metadata": metadata if metadata is not None else assistant["metadata"],
423
+ "config": config if config else assistant["config"],
424
+ "context": context if context is not None else assistant.get("context", {}),
425
+ "metadata": (
426
+ {**assistant["metadata"], **metadata}
427
+ if metadata is not None
428
+ else assistant["metadata"]
429
+ ),
383
430
  "created_at": now,
384
431
  "name": name if name is not None else assistant["name"],
385
432
  "description": (
@@ -514,6 +561,8 @@ class Assistants(Authenticated):
514
561
  "metadata": version_data["metadata"],
515
562
  "version": version_data["version"],
516
563
  "updated_at": datetime.now(UTC),
564
+ "name": version_data["name"],
565
+ "description": version_data["description"],
517
566
  }
518
567
  )
519
568
 
@@ -572,6 +621,37 @@ class Assistants(Authenticated):
572
621
 
573
622
  return _yield_versions()
574
623
 
624
+ @staticmethod
625
+ async def count(
626
+ conn: InMemConnectionProto,
627
+ *,
628
+ graph_id: str | None = None,
629
+ metadata: MetadataInput = None,
630
+ ctx: Auth.types.BaseAuthContext | None = None,
631
+ ) -> int:
632
+ """Get count of assistants."""
633
+ metadata = metadata if metadata is not None else {}
634
+ filters = await Assistants.handle_event(
635
+ ctx,
636
+ "search",
637
+ Auth.types.AssistantsSearch(
638
+ graph_id=graph_id, metadata=metadata, limit=0, offset=0
639
+ ),
640
+ )
641
+
642
+ count = 0
643
+ for assistant in conn.store["assistants"]:
644
+ if (
645
+ (not graph_id or assistant["graph_id"] == graph_id)
646
+ and (
647
+ not metadata or is_jsonb_contained(assistant["metadata"], metadata)
648
+ )
649
+ and (not filters or _check_filter_match(assistant["metadata"], filters))
650
+ ):
651
+ count += 1
652
+
653
+ return count
654
+
575
655
 
576
656
  def is_jsonb_contained(superset: dict[str, Any], subset: dict[str, Any]) -> bool:
577
657
  """
@@ -649,13 +729,13 @@ def _patch_interrupt(
649
729
  interrupt = Interrupt(**interrupt)
650
730
 
651
731
  return {
652
- "id": interrupt.interrupt_id
653
- if hasattr(interrupt, "interrupt_id")
654
- else None,
732
+ "id": (
733
+ interrupt.interrupt_id if hasattr(interrupt, "interrupt_id") else None
734
+ ),
655
735
  "value": interrupt.value,
656
736
  "resumable": interrupt.resumable,
657
737
  "ns": interrupt.ns,
658
- "when": interrupt.when,
738
+ "when": interrupt.when, # type: ignore[unresolved-attribute]
659
739
  }
660
740
 
661
741
 
@@ -666,6 +746,7 @@ class Threads(Authenticated):
666
746
  async def search(
667
747
  conn: InMemConnectionProto,
668
748
  *,
749
+ ids: list[str] | list[UUID] | None = None,
669
750
  metadata: MetadataInput,
670
751
  values: MetadataInput,
671
752
  status: ThreadStatus | None,
@@ -673,6 +754,7 @@ class Threads(Authenticated):
673
754
  offset: int,
674
755
  sort_by: str | None = None,
675
756
  sort_order: str | None = None,
757
+ select: list[ThreadSelectField] | None = None,
676
758
  ctx: Auth.types.BaseAuthContext | None = None,
677
759
  ) -> tuple[AsyncIterator[Thread], int]:
678
760
  threads = conn.store["threads"]
@@ -692,7 +774,19 @@ class Threads(Authenticated):
692
774
  )
693
775
 
694
776
  # Apply filters
777
+ id_set: set[UUID] | None = None
778
+ if ids:
779
+ id_set = set()
780
+ for i in ids:
781
+ try:
782
+ id_set.add(_ensure_uuid(i))
783
+ except Exception:
784
+ raise HTTPException(
785
+ status_code=400, detail="Invalid thread ID " + str(i)
786
+ ) from None
695
787
  for thread in threads:
788
+ if id_set is not None and thread.get("thread_id") not in id_set:
789
+ continue
696
790
  if filters and not _check_filter_match(thread["metadata"], filters):
697
791
  continue
698
792
 
@@ -710,8 +804,6 @@ class Threads(Authenticated):
710
804
  continue
711
805
 
712
806
  filtered_threads.append(thread)
713
- # Get total count before pagination
714
- total_count = len(filtered_threads)
715
807
 
716
808
  if sort_by and sort_by in [
717
809
  "thread_id",
@@ -724,20 +816,26 @@ class Threads(Authenticated):
724
816
  filtered_threads, key=lambda x: x.get(sort_by), reverse=reverse
725
817
  )
726
818
  else:
819
+ sort_by = "created_at"
727
820
  # Default sorting by created_at in descending order
728
821
  sorted_threads = sorted(
729
- filtered_threads, key=lambda x: x["created_at"], reverse=True
822
+ filtered_threads, key=lambda x: x["updated_at"], reverse=True
730
823
  )
731
824
 
732
825
  # Apply limit and offset
733
826
  paginated_threads = sorted_threads[offset : offset + limit]
827
+ cursor = offset + limit if len(sorted_threads) > offset + limit else None
734
828
 
735
829
  async def thread_iterator() -> AsyncIterator[Thread]:
736
830
  for thread in paginated_threads:
737
- yield thread
831
+ if select:
832
+ # Filter to only selected fields
833
+ filtered_thread = {k: v for k, v in thread.items() if k in select}
834
+ yield filtered_thread
835
+ else:
836
+ yield thread
738
837
 
739
- # Return both the iterator and the total count
740
- return thread_iterator(), total_count
838
+ return thread_iterator(), cursor
741
839
 
742
840
  @staticmethod
743
841
  async def _get_with_filters(
@@ -799,7 +897,7 @@ class Threads(Authenticated):
799
897
  @staticmethod
800
898
  async def put(
801
899
  conn: InMemConnectionProto,
802
- thread_id: UUID,
900
+ thread_id: UUID | str,
803
901
  *,
804
902
  metadata: MetadataInput,
805
903
  if_exists: OnConflictBehavior,
@@ -866,6 +964,7 @@ class Threads(Authenticated):
866
964
  thread_id: UUID,
867
965
  *,
868
966
  metadata: MetadataValue,
967
+ ttl: ThreadTTLConfig | None = None,
869
968
  ctx: Auth.types.BaseAuthContext | None = None,
870
969
  ) -> AsyncIterator[Thread]:
871
970
  """Update a thread."""
@@ -978,6 +1077,7 @@ class Threads(Authenticated):
978
1077
  thread_id: UUID,
979
1078
  run_id: UUID,
980
1079
  run_status: RunStatus | Literal["rollback"],
1080
+ graph_id: str,
981
1081
  checkpoint: CheckpointPayload | None = None,
982
1082
  exception: BaseException | None = None,
983
1083
  ) -> None:
@@ -1057,6 +1157,7 @@ class Threads(Authenticated):
1057
1157
  final_thread_status = "busy"
1058
1158
  else:
1059
1159
  final_thread_status = base_thread_status
1160
+ thread["metadata"]["graph_id"] = graph_id
1060
1161
  thread.update(
1061
1162
  {
1062
1163
  "updated_at": now,
@@ -1136,13 +1237,23 @@ class Threads(Authenticated):
1136
1237
  """Create a copy of an existing thread."""
1137
1238
  thread_id = _ensure_uuid(thread_id)
1138
1239
  new_thread_id = uuid4()
1139
- filters = await Threads.handle_event(
1240
+ read_filters = await Threads.handle_event(
1140
1241
  ctx,
1141
1242
  "read",
1142
1243
  Auth.types.ThreadsRead(
1244
+ thread_id=thread_id,
1245
+ ),
1246
+ )
1247
+ # Assert that the user has permissions to create a new thread.
1248
+ # (We don't actually need the filters.)
1249
+ await Threads.handle_event(
1250
+ ctx,
1251
+ "create",
1252
+ Auth.types.ThreadsCreate(
1143
1253
  thread_id=new_thread_id,
1144
1254
  ),
1145
1255
  )
1256
+
1146
1257
  async with conn.pipeline():
1147
1258
  # Find the original thread in our store
1148
1259
  original_thread = next(
@@ -1151,8 +1262,8 @@ class Threads(Authenticated):
1151
1262
 
1152
1263
  if not original_thread:
1153
1264
  return _empty_generator()
1154
- if filters and not _check_filter_match(
1155
- original_thread["metadata"], filters
1265
+ if read_filters and not _check_filter_match(
1266
+ original_thread["metadata"], read_filters
1156
1267
  ):
1157
1268
  return _empty_generator()
1158
1269
 
@@ -1161,7 +1272,7 @@ class Threads(Authenticated):
1161
1272
  "thread_id": new_thread_id,
1162
1273
  "created_at": datetime.now(tz=UTC),
1163
1274
  "updated_at": datetime.now(tz=UTC),
1164
- "metadata": deepcopy(original_thread["metadata"]),
1275
+ "metadata": copy.deepcopy(original_thread["metadata"]),
1165
1276
  "status": "idle",
1166
1277
  "config": {},
1167
1278
  }
@@ -1248,9 +1359,24 @@ class Threads(Authenticated):
1248
1359
  )
1249
1360
 
1250
1361
  metadata = thread.get("metadata", {})
1251
- thread_config = thread.get("config", {})
1362
+ thread_config = cast(dict[str, Any], thread.get("config", {}))
1363
+ thread_config = {
1364
+ **thread_config,
1365
+ "configurable": {
1366
+ **thread_config.get("configurable", {}),
1367
+ **config.get("configurable", {}),
1368
+ },
1369
+ }
1252
1370
 
1253
- if graph_id := metadata.get("graph_id"):
1371
+ # Fallback to graph_id from run if not in thread metadata
1372
+ graph_id = metadata.get("graph_id")
1373
+ if not graph_id:
1374
+ for run in conn.store["runs"]:
1375
+ if run["thread_id"] == thread_id:
1376
+ graph_id = run["kwargs"]["config"]["configurable"]["graph_id"]
1377
+ break
1378
+
1379
+ if graph_id:
1254
1380
  # format latest checkpoint for response
1255
1381
  checkpointer.latest_iter = checkpoint
1256
1382
  async with get_graph(
@@ -1290,6 +1416,7 @@ class Threads(Authenticated):
1290
1416
  """Add state to a thread."""
1291
1417
  from langgraph_api.graph import get_graph
1292
1418
  from langgraph_api.schema import ThreadUpdateResponse
1419
+ from langgraph_api.state import state_snapshot_to_thread_state
1293
1420
  from langgraph_api.store import get_store
1294
1421
  from langgraph_api.utils import fetchone
1295
1422
 
@@ -1327,8 +1454,23 @@ class Threads(Authenticated):
1327
1454
  status_code=409,
1328
1455
  detail=f"Thread {thread_id} has in-flight runs: {pending_runs}",
1329
1456
  )
1457
+ thread_config = {
1458
+ **thread_config,
1459
+ "configurable": {
1460
+ **thread_config.get("configurable", {}),
1461
+ **config.get("configurable", {}),
1462
+ },
1463
+ }
1330
1464
 
1331
- if graph_id := metadata.get("graph_id"):
1465
+ # Fallback to graph_id from run if not in thread metadata
1466
+ graph_id = metadata.get("graph_id")
1467
+ if not graph_id:
1468
+ for run in conn.store["runs"]:
1469
+ if run["thread_id"] == thread_id:
1470
+ graph_id = run["kwargs"]["config"]["configurable"]["graph_id"]
1471
+ break
1472
+
1473
+ if graph_id:
1332
1474
  config["configurable"].setdefault("graph_id", graph_id)
1333
1475
 
1334
1476
  checkpointer.latest_iter = checkpoint
@@ -1359,6 +1501,19 @@ class Threads(Authenticated):
1359
1501
  thread["values"] = state.values
1360
1502
  break
1361
1503
 
1504
+ # Publish state update event
1505
+ from langgraph_api.serde import json_dumpb
1506
+
1507
+ event_data = {
1508
+ "state": state_snapshot_to_thread_state(state),
1509
+ "thread_id": str(thread_id),
1510
+ }
1511
+ await Threads.Stream.publish(
1512
+ thread_id,
1513
+ "state_update",
1514
+ json_dumpb(event_data),
1515
+ )
1516
+
1362
1517
  return ThreadUpdateResponse(
1363
1518
  checkpoint=next_config["configurable"],
1364
1519
  # Including deprecated fields
@@ -1366,7 +1521,11 @@ class Threads(Authenticated):
1366
1521
  checkpoint_id=next_config["configurable"]["checkpoint_id"],
1367
1522
  )
1368
1523
  else:
1369
- raise HTTPException(status_code=400, detail="Thread has no graph ID.")
1524
+ raise HTTPException(
1525
+ status_code=400,
1526
+ detail=f"Thread '{thread['thread_id']}' has no assigned graph ID. This usually occurs when no runs have been made on this particular thread."
1527
+ " This operation requires a graph ID. Please ensure a run has been made for the thread or manually update the thread metadata (by setting the 'graph_id' field) before running this operation.",
1528
+ )
1370
1529
 
1371
1530
  @staticmethod
1372
1531
  async def bulk(
@@ -1397,7 +1556,14 @@ class Threads(Authenticated):
1397
1556
  thread_iter, not_found_detail=f"Thread {thread_id} not found."
1398
1557
  )
1399
1558
 
1400
- thread_config = thread["config"]
1559
+ thread_config = cast(dict[str, Any], thread["config"])
1560
+ thread_config = {
1561
+ **thread_config,
1562
+ "configurable": {
1563
+ **thread_config.get("configurable", {}),
1564
+ **config.get("configurable", {}),
1565
+ },
1566
+ }
1401
1567
  metadata = thread["metadata"]
1402
1568
 
1403
1569
  if not thread:
@@ -1444,18 +1610,35 @@ class Threads(Authenticated):
1444
1610
  thread["values"] = state.values
1445
1611
  break
1446
1612
 
1613
+ # Publish state update event
1614
+ from langgraph_api.serde import json_dumpb
1615
+
1616
+ event_data = {
1617
+ "state": state,
1618
+ "thread_id": str(thread_id),
1619
+ }
1620
+ await Threads.Stream.publish(
1621
+ thread_id,
1622
+ "state_update",
1623
+ json_dumpb(event_data),
1624
+ )
1625
+
1447
1626
  return ThreadUpdateResponse(
1448
1627
  checkpoint=next_config["configurable"],
1449
1628
  )
1450
1629
  else:
1451
- raise HTTPException(status_code=400, detail="Thread has no graph ID")
1630
+ raise HTTPException(
1631
+ status_code=400,
1632
+ detail=f"Thread '{thread['thread_id']}' has no assigned graph ID. This usually occurs when no runs have been made on this particular thread."
1633
+ " This operation requires a graph ID. Please ensure a run has been made for the thread or manually update the thread metadata (by setting the 'graph_id' field) before running this operation.",
1634
+ )
1452
1635
 
1453
1636
  @staticmethod
1454
1637
  async def list(
1455
1638
  conn: InMemConnectionProto,
1456
1639
  *,
1457
1640
  config: Config,
1458
- limit: int = 10,
1641
+ limit: int = 1,
1459
1642
  before: str | Checkpoint | None = None,
1460
1643
  metadata: MetadataInput = None,
1461
1644
  ctx: Auth.types.BaseAuthContext | None = None,
@@ -1481,7 +1664,14 @@ class Threads(Authenticated):
1481
1664
  if not _check_filter_match(thread_metadata, filters):
1482
1665
  return []
1483
1666
 
1484
- thread_config = thread["config"]
1667
+ thread_config = cast(dict[str, Any], thread["config"])
1668
+ thread_config = {
1669
+ **thread_config,
1670
+ "configurable": {
1671
+ **thread_config.get("configurable", {}),
1672
+ **config.get("configurable", {}),
1673
+ },
1674
+ }
1485
1675
  # If graph_id exists, get state history
1486
1676
  if graph_id := thread_metadata.get("graph_id"):
1487
1677
  async with get_graph(
@@ -1510,6 +1700,265 @@ class Threads(Authenticated):
1510
1700
 
1511
1701
  return []
1512
1702
 
1703
+ class Stream(Authenticated):
1704
+ resource = "threads"
1705
+
1706
+ @staticmethod
1707
+ async def subscribe(
1708
+ conn: InMemConnectionProto | AsyncConnectionProto,
1709
+ thread_id: UUID,
1710
+ seen_runs: set[UUID],
1711
+ ) -> list[tuple[UUID, asyncio.Queue]]:
1712
+ """Subscribe to the thread stream, creating queues for unseen runs."""
1713
+ stream_manager = get_stream_manager()
1714
+ queues = []
1715
+
1716
+ # Create new queues only for runs not yet seen
1717
+ thread_id = _ensure_uuid(thread_id)
1718
+
1719
+ # Add thread stream queue
1720
+ if thread_id not in seen_runs:
1721
+ queue = await stream_manager.add_thread_stream(thread_id)
1722
+ queues.append((thread_id, queue))
1723
+ seen_runs.add(thread_id)
1724
+
1725
+ for run in conn.store["runs"]:
1726
+ if run["thread_id"] == thread_id:
1727
+ run_id = run["run_id"]
1728
+ if run_id not in seen_runs:
1729
+ queue = await stream_manager.add_queue(run_id, thread_id)
1730
+ queues.append((run_id, queue))
1731
+ seen_runs.add(run_id)
1732
+
1733
+ return queues
1734
+
1735
+ @staticmethod
1736
+ async def join(
1737
+ thread_id: UUID,
1738
+ *,
1739
+ last_event_id: str | None = None,
1740
+ stream_modes: list[ThreadStreamMode],
1741
+ ctx: Auth.types.BaseAuthContext | None = None,
1742
+ ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
1743
+ """Stream the thread output."""
1744
+ await Threads.Stream.check_thread_stream_auth(thread_id, ctx)
1745
+
1746
+ from langgraph_api.utils.stream_codec import (
1747
+ decode_stream_message,
1748
+ )
1749
+
1750
+ def should_filter_event(event_name: str, message_bytes: bytes) -> bool:
1751
+ """Check if an event should be filtered out based on stream_modes."""
1752
+ if "run_modes" in stream_modes and event_name != "state_update":
1753
+ return False
1754
+ if "state_update" in stream_modes and event_name == "state_update":
1755
+ return False
1756
+ if "lifecycle" in stream_modes and event_name == "metadata":
1757
+ try:
1758
+ message_data = orjson.loads(message_bytes)
1759
+ if message_data.get("status") == "run_done":
1760
+ return False
1761
+ if "attempt" in message_data and "run_id" in message_data:
1762
+ return False
1763
+ except (orjson.JSONDecodeError, TypeError):
1764
+ pass
1765
+ return True
1766
+
1767
+ stream_manager = get_stream_manager()
1768
+ seen_runs: set[UUID] = set()
1769
+ created_queues: list[tuple[UUID, asyncio.Queue]] = []
1770
+
1771
+ try:
1772
+ async with connect() as conn:
1773
+ await logger.ainfo(
1774
+ "Joined thread stream",
1775
+ thread_id=str(thread_id),
1776
+ )
1777
+
1778
+ # Restore messages if resuming from a specific event
1779
+ if last_event_id is not None:
1780
+ # Collect all events from all message stores for this thread
1781
+ all_events = []
1782
+ for run_id in stream_manager.message_stores.get(
1783
+ str(thread_id), []
1784
+ ):
1785
+ for message in stream_manager.restore_messages(
1786
+ run_id, thread_id, last_event_id
1787
+ ):
1788
+ all_events.append((message, run_id))
1789
+
1790
+ # Sort by message ID (which is ms-seq format)
1791
+ all_events.sort(key=lambda x: x[0].id.decode())
1792
+
1793
+ # Yield sorted events
1794
+ for message, run_id in all_events:
1795
+ decoded = decode_stream_message(
1796
+ message.data, channel=message.topic
1797
+ )
1798
+ event_bytes = decoded.event_bytes
1799
+ message_bytes = decoded.message_bytes
1800
+
1801
+ if event_bytes == b"control":
1802
+ if message_bytes == b"done":
1803
+ event_bytes = b"metadata"
1804
+ message_bytes = orjson.dumps(
1805
+ {"status": "run_done", "run_id": run_id}
1806
+ )
1807
+ if not should_filter_event(
1808
+ event_bytes.decode("utf-8"), message_bytes
1809
+ ):
1810
+ yield (
1811
+ event_bytes,
1812
+ message_bytes,
1813
+ message.id,
1814
+ )
1815
+
1816
+ # Listen for live messages from all queues
1817
+ while True:
1818
+ # Refresh queues to pick up any new runs that joined this thread
1819
+ new_queue_tuples = await Threads.Stream.subscribe(
1820
+ conn, thread_id, seen_runs
1821
+ )
1822
+ # Track new queues for cleanup
1823
+ for run_id, queue in new_queue_tuples:
1824
+ created_queues.append((run_id, queue))
1825
+
1826
+ for run_id, queue in created_queues:
1827
+ try:
1828
+ message = await asyncio.wait_for(
1829
+ queue.get(), timeout=0.2
1830
+ )
1831
+ decoded = decode_stream_message(
1832
+ message.data, channel=message.topic
1833
+ )
1834
+ event = decoded.event_bytes
1835
+ event_name = event.decode("utf-8")
1836
+ payload = decoded.message_bytes
1837
+
1838
+ if event == b"control" and payload == b"done":
1839
+ topic = message.topic.decode()
1840
+ run_id = topic.split("run:")[1].split(":")[0]
1841
+ meta_event = b"metadata"
1842
+ meta_payload = orjson.dumps(
1843
+ {"status": "run_done", "run_id": run_id}
1844
+ )
1845
+ if not should_filter_event(
1846
+ "metadata", meta_payload
1847
+ ):
1848
+ yield (meta_event, meta_payload, message.id)
1849
+ else:
1850
+ if not should_filter_event(event_name, payload):
1851
+ yield (event, payload, message.id)
1852
+
1853
+ except TimeoutError:
1854
+ continue
1855
+ except (ValueError, KeyError):
1856
+ continue
1857
+
1858
+ # Yield execution to other tasks to prevent event loop starvation
1859
+ await asyncio.sleep(0)
1860
+
1861
+ except WrappedHTTPException as e:
1862
+ raise e.http_exception from None
1863
+ except asyncio.CancelledError:
1864
+ await logger.awarning(
1865
+ "Thread stream client disconnected",
1866
+ thread_id=str(thread_id),
1867
+ )
1868
+ raise
1869
+ except:
1870
+ raise
1871
+ finally:
1872
+ # Clean up all created queues
1873
+ for run_id, queue in created_queues:
1874
+ try:
1875
+ await stream_manager.remove_queue(run_id, thread_id, queue)
1876
+ except Exception:
1877
+ # Ignore cleanup errors
1878
+ pass
1879
+
1880
+ @staticmethod
1881
+ async def publish(
1882
+ thread_id: UUID | str,
1883
+ event: str,
1884
+ message: bytes,
1885
+ ) -> None:
1886
+ """Publish a thread-level event to the thread stream."""
1887
+ from langgraph_api.utils.stream_codec import STREAM_CODEC
1888
+
1889
+ topic = f"thread:{thread_id}:stream".encode()
1890
+
1891
+ stream_manager = get_stream_manager()
1892
+ payload = STREAM_CODEC.encode(event, message)
1893
+ await stream_manager.put_thread(
1894
+ str(thread_id), Message(topic=topic, data=payload)
1895
+ )
1896
+
1897
+ @staticmethod
1898
+ async def check_thread_stream_auth(
1899
+ thread_id: UUID,
1900
+ ctx: Auth.types.BaseAuthContext | None = None,
1901
+ ) -> None:
1902
+ async with connect() as conn:
1903
+ filters = await Threads.Stream.handle_event(
1904
+ ctx,
1905
+ "read",
1906
+ Auth.types.ThreadsRead(thread_id=thread_id),
1907
+ )
1908
+ if filters:
1909
+ thread = await Threads._get_with_filters(
1910
+ cast(InMemConnectionProto, conn), thread_id, filters
1911
+ )
1912
+ if not thread:
1913
+ raise HTTPException(status_code=404, detail="Thread not found")
1914
+
1915
+ @staticmethod
1916
+ async def count(
1917
+ conn: InMemConnectionProto,
1918
+ *,
1919
+ metadata: MetadataInput = None,
1920
+ values: MetadataInput = None,
1921
+ status: ThreadStatus | None = None,
1922
+ ctx: Auth.types.BaseAuthContext | None = None,
1923
+ ) -> int:
1924
+ """Get count of threads."""
1925
+ threads = conn.store["threads"]
1926
+ metadata = metadata if metadata is not None else {}
1927
+ values = values if values is not None else {}
1928
+ filters = await Threads.handle_event(
1929
+ ctx,
1930
+ "search",
1931
+ Auth.types.ThreadsSearch(
1932
+ metadata=metadata,
1933
+ values=values,
1934
+ status=status,
1935
+ limit=0,
1936
+ offset=0,
1937
+ ),
1938
+ )
1939
+
1940
+ count = 0
1941
+ for thread in threads:
1942
+ if filters and not _check_filter_match(thread["metadata"], filters):
1943
+ continue
1944
+
1945
+ if metadata and not is_jsonb_contained(thread["metadata"], metadata):
1946
+ continue
1947
+
1948
+ if (
1949
+ values
1950
+ and "values" in thread
1951
+ and not is_jsonb_contained(thread["values"], values)
1952
+ ):
1953
+ continue
1954
+
1955
+ if status and thread.get("status") != status:
1956
+ continue
1957
+
1958
+ count += 1
1959
+
1960
+ return count
1961
+
1513
1962
 
1514
1963
  RUN_LOCK = asyncio.Lock()
1515
1964
 
@@ -1526,38 +1975,37 @@ class Runs(Authenticated):
1526
1975
  if not pending_runs and not running_runs:
1527
1976
  return {
1528
1977
  "n_pending": 0,
1529
- "max_age_secs": None,
1530
- "med_age_secs": None,
1978
+ "pending_runs_wait_time_max_secs": None,
1979
+ "pending_runs_wait_time_med_secs": None,
1531
1980
  "n_running": 0,
1532
1981
  }
1533
1982
 
1534
- # Get all creation timestamps
1535
- created_times = [run.get("created_at") for run in (pending_runs + running_runs)]
1536
- created_times = [
1537
- t for t in created_times if t is not None
1538
- ] # Filter out None values
1539
-
1540
- if not created_times:
1541
- return {
1542
- "n_pending": len(pending_runs),
1543
- "n_running": len(running_runs),
1544
- "max_age_secs": None,
1545
- "med_age_secs": None,
1546
- }
1547
-
1548
- # Find oldest (max age)
1549
- oldest_time = min(created_times) # Earliest timestamp = oldest run
1550
-
1551
- # Find median age
1552
- sorted_times = sorted(created_times)
1553
- median_idx = len(sorted_times) // 2
1554
- median_time = sorted_times[median_idx]
1983
+ now = datetime.now(UTC)
1984
+ pending_waits: list[float] = []
1985
+ for run in pending_runs:
1986
+ created_at = run.get("created_at")
1987
+ if not isinstance(created_at, datetime):
1988
+ continue
1989
+ if created_at.tzinfo is None:
1990
+ created_at = created_at.replace(tzinfo=UTC)
1991
+ pending_waits.append((now - created_at).total_seconds())
1992
+
1993
+ max_pending_wait = max(pending_waits) if pending_waits else None
1994
+ if pending_waits:
1995
+ sorted_waits = sorted(pending_waits)
1996
+ half = len(sorted_waits) // 2
1997
+ if len(sorted_waits) % 2 == 1:
1998
+ med_pending_wait = sorted_waits[half]
1999
+ else:
2000
+ med_pending_wait = (sorted_waits[half - 1] + sorted_waits[half]) / 2
2001
+ else:
2002
+ med_pending_wait = None
1555
2003
 
1556
2004
  return {
1557
2005
  "n_pending": len(pending_runs),
1558
2006
  "n_running": len(running_runs),
1559
- "max_age_secs": oldest_time,
1560
- "med_age_secs": median_time,
2007
+ "pending_runs_wait_time_max_secs": max_pending_wait,
2008
+ "pending_runs_wait_time_med_secs": med_pending_wait,
1561
2009
  }
1562
2010
 
1563
2011
  @staticmethod
@@ -1621,38 +2069,51 @@ class Runs(Authenticated):
1621
2069
  @asynccontextmanager
1622
2070
  @staticmethod
1623
2071
  async def enter(
1624
- run_id: UUID, loop: asyncio.AbstractEventLoop
2072
+ run_id: UUID,
2073
+ thread_id: UUID | None,
2074
+ loop: asyncio.AbstractEventLoop,
2075
+ resumable: bool,
1625
2076
  ) -> AsyncIterator[ValueEvent]:
1626
2077
  """Enter a run, listen for cancellation while running, signal when done."
1627
2078
  This method should be called as a context manager by a worker executing a run.
1628
2079
  """
1629
2080
  from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
2081
+ from langgraph_api.utils.stream_codec import STREAM_CODEC
1630
2082
 
1631
2083
  stream_manager = get_stream_manager()
1632
- # Get queue for this run
1633
- queue = await Runs.Stream.subscribe(run_id)
2084
+ # Get control queue for this run (normal queue is created during run creation)
2085
+ control_queue = await stream_manager.add_control_queue(run_id, thread_id)
1634
2086
 
1635
2087
  async with SimpleTaskGroup(cancel=True, taskgroup_name="Runs.enter") as tg:
1636
2088
  done = ValueEvent()
1637
- tg.create_task(listen_for_cancellation(queue, run_id, done))
2089
+ tg.create_task(
2090
+ listen_for_cancellation(control_queue, run_id, thread_id, done)
2091
+ )
1638
2092
 
1639
2093
  # Give done event to caller
1640
2094
  yield done
1641
- # Signal done to all subscribers
2095
+ # Store the control message for late subscribers
1642
2096
  control_message = Message(
1643
2097
  topic=f"run:{run_id}:control".encode(), data=b"done"
1644
2098
  )
2099
+ await stream_manager.put(run_id, thread_id, control_message)
1645
2100
 
1646
- # Store the control message for late subscribers
1647
- await stream_manager.put(run_id, control_message)
1648
- stream_manager.control_queues[run_id].append(control_message)
1649
- # Clean up this queue
1650
- await stream_manager.remove_queue(run_id, queue)
2101
+ # Signal done to all subscribers using stream codec
2102
+ stream_message = Message(
2103
+ topic=f"run:{run_id}:stream".encode(),
2104
+ data=STREAM_CODEC.encode("control", b"done"),
2105
+ )
2106
+ await stream_manager.put(
2107
+ run_id, thread_id, stream_message, resumable=resumable
2108
+ )
2109
+
2110
+ # Remove the control_queue (normal queue is cleaned up during run deletion)
2111
+ await stream_manager.remove_control_queue(run_id, thread_id, control_queue)
1651
2112
 
1652
2113
  @staticmethod
1653
- async def sweep(conn: InMemConnectionProto) -> list[UUID]:
2114
+ async def sweep() -> None:
1654
2115
  """Sweep runs that are no longer running"""
1655
- return []
2116
+ pass
1656
2117
 
1657
2118
  @staticmethod
1658
2119
  def _merge_jsonb(*objects: dict) -> dict:
@@ -1670,7 +2131,7 @@ class Runs(Authenticated):
1670
2131
 
1671
2132
  @staticmethod
1672
2133
  async def put(
1673
- conn: InMemConnectionProto,
2134
+ conn: InMemConnectionProto | AsyncConnectionProto,
1674
2135
  assistant_id: UUID,
1675
2136
  kwargs: dict,
1676
2137
  *,
@@ -1701,6 +2162,7 @@ class Runs(Authenticated):
1701
2162
  run_id = _ensure_uuid(run_id) if run_id else None
1702
2163
  metadata = metadata if metadata is not None else {}
1703
2164
  config = kwargs.get("config", {})
2165
+ temporary = kwargs.get("temporary", False)
1704
2166
 
1705
2167
  # Handle thread creation/update
1706
2168
  existing_thread = next(
@@ -1710,7 +2172,7 @@ class Runs(Authenticated):
1710
2172
  ctx,
1711
2173
  "create_run",
1712
2174
  Auth.types.RunsCreate(
1713
- thread_id=thread_id,
2175
+ thread_id=None if temporary else thread_id,
1714
2176
  assistant_id=assistant_id,
1715
2177
  run_id=run_id,
1716
2178
  status=status,
@@ -1731,6 +2193,7 @@ class Runs(Authenticated):
1731
2193
  # Create new thread
1732
2194
  if thread_id is None:
1733
2195
  thread_id = uuid4()
2196
+
1734
2197
  thread = Thread(
1735
2198
  thread_id=thread_id,
1736
2199
  status="busy",
@@ -1746,7 +2209,6 @@ class Runs(Authenticated):
1746
2209
  {
1747
2210
  "configurable": Runs._merge_jsonb(
1748
2211
  Runs._get_configurable(assistant["config"]),
1749
- Runs._get_configurable(config),
1750
2212
  )
1751
2213
  },
1752
2214
  ),
@@ -1754,6 +2216,7 @@ class Runs(Authenticated):
1754
2216
  updated_at=datetime.now(UTC),
1755
2217
  values=b"",
1756
2218
  )
2219
+
1757
2220
  await logger.ainfo("Creating thread", thread_id=thread_id)
1758
2221
  conn.store["threads"].append(thread)
1759
2222
  elif existing_thread:
@@ -1775,7 +2238,6 @@ class Runs(Authenticated):
1775
2238
  "configurable": Runs._merge_jsonb(
1776
2239
  Runs._get_configurable(assistant["config"]),
1777
2240
  Runs._get_configurable(existing_thread["config"]),
1778
- Runs._get_configurable(config),
1779
2241
  )
1780
2242
  },
1781
2243
  )
@@ -1920,6 +2382,7 @@ class Runs(Authenticated):
1920
2382
  if not thread:
1921
2383
  return _empty_generator()
1922
2384
  _delete_checkpoints_for_thread(thread_id, conn, run_id=run_id)
2385
+
1923
2386
  found = False
1924
2387
  for i, run in enumerate(conn.store["runs"]):
1925
2388
  if run["run_id"] == run_id and run["thread_id"] == thread_id:
@@ -1935,54 +2398,10 @@ class Runs(Authenticated):
1935
2398
 
1936
2399
  return _yield_deleted()
1937
2400
 
1938
- @staticmethod
1939
- async def join(
1940
- run_id: UUID,
1941
- *,
1942
- thread_id: UUID,
1943
- ctx: Auth.types.BaseAuthContext | None = None,
1944
- ) -> Fragment:
1945
- """Wait for a run to complete. If already done, return immediately.
1946
-
1947
- Returns:
1948
- the final state of the run.
1949
- """
1950
- from langgraph_api.serde import Fragment
1951
- from langgraph_api.utils import fetchone
1952
-
1953
- async with connect() as conn:
1954
- # Validate ownership
1955
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1956
- await fetchone(thread_iter)
1957
- last_chunk: bytes | None = None
1958
- # wait for the run to complete
1959
- # Rely on this join's auth
1960
- async for mode, chunk, _ in Runs.Stream.join(
1961
- run_id, thread_id=thread_id, ctx=ctx, ignore_404=True
1962
- ):
1963
- if mode == b"values":
1964
- last_chunk = chunk
1965
- elif mode == b"error":
1966
- last_chunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
1967
- # if we received a final chunk, return it
1968
- if last_chunk is not None:
1969
- # ie. if the run completed while we were waiting for it
1970
- return Fragment(last_chunk)
1971
- else:
1972
- # otherwise, the run had already finished, so fetch the state from thread
1973
- async with connect() as conn:
1974
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
1975
- thread = await fetchone(thread_iter)
1976
- if thread["status"] == "error":
1977
- return Fragment(
1978
- orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
1979
- )
1980
- return thread["values"]
1981
-
1982
2401
  @staticmethod
1983
2402
  async def cancel(
1984
- conn: InMemConnectionProto,
1985
- run_ids: Sequence[UUID] | None = None,
2403
+ conn: InMemConnectionProto | AsyncConnectionProto,
2404
+ run_ids: Sequence[UUID | str] | None = None,
1986
2405
  *,
1987
2406
  action: Literal["interrupt", "rollback"] = "interrupt",
1988
2407
  thread_id: UUID | None = None,
@@ -2086,9 +2505,9 @@ class Runs(Authenticated):
2086
2505
  topic=f"run:{run_id}:control".encode(),
2087
2506
  data=action.encode(),
2088
2507
  )
2089
- coros.append(stream_manager.put(run_id, control_message))
2508
+ coros.append(stream_manager.put(run_id, thread_id, control_message))
2090
2509
 
2091
- queues = stream_manager.get_queues(run_id)
2510
+ queues = stream_manager.get_queues(run_id, thread_id)
2092
2511
 
2093
2512
  if run["status"] in ("pending", "running"):
2094
2513
  cancelable_runs.append(run)
@@ -2146,6 +2565,7 @@ class Runs(Authenticated):
2146
2565
  limit: int = 10,
2147
2566
  offset: int = 0,
2148
2567
  status: RunStatus | None = None,
2568
+ select: list[RunSelectField] | None = None,
2149
2569
  ctx: Auth.types.BaseAuthContext | None = None,
2150
2570
  ) -> AsyncIterator[Run]:
2151
2571
  """List all runs by thread."""
@@ -2173,7 +2593,12 @@ class Runs(Authenticated):
2173
2593
 
2174
2594
  async def _return():
2175
2595
  for run in sliced_runs:
2176
- yield run
2596
+ if select:
2597
+ # Filter to only selected fields
2598
+ filtered_run = {k: v for k, v in run.items() if k in select}
2599
+ yield filtered_run
2600
+ else:
2601
+ yield run
2177
2602
 
2178
2603
  return _return()
2179
2604
 
@@ -2197,73 +2622,81 @@ class Runs(Authenticated):
2197
2622
  @staticmethod
2198
2623
  async def subscribe(
2199
2624
  run_id: UUID,
2200
- *,
2201
- stream_mode: StreamMode | None = None,
2202
- ) -> asyncio.Queue:
2625
+ thread_id: UUID | None = None,
2626
+ ) -> ContextQueue:
2203
2627
  """Subscribe to the run stream, returning a queue."""
2204
2628
  stream_manager = get_stream_manager()
2205
- queue = await stream_manager.add_queue(_ensure_uuid(run_id))
2629
+ queue = await stream_manager.add_queue(_ensure_uuid(run_id), thread_id)
2206
2630
 
2207
2631
  # If there's a control message already stored, send it to the new subscriber
2208
- if control_messages := stream_manager.control_queues.get(run_id):
2209
- for control_msg in control_messages:
2210
- await queue.put(control_msg)
2632
+ if thread_id is None:
2633
+ thread_id = THREADLESS_KEY
2634
+ if control_queues := stream_manager.control_queues.get(thread_id, {}).get(
2635
+ run_id
2636
+ ):
2637
+ for control_queue in control_queues:
2638
+ try:
2639
+ while True:
2640
+ control_msg = control_queue.get()
2641
+ await queue.put(control_msg)
2642
+ except asyncio.QueueEmpty:
2643
+ pass
2211
2644
  return queue
2212
2645
 
2213
2646
  @staticmethod
2214
2647
  async def join(
2215
2648
  run_id: UUID,
2216
2649
  *,
2650
+ stream_channel: asyncio.Queue,
2217
2651
  thread_id: UUID,
2218
2652
  ignore_404: bool = False,
2219
2653
  cancel_on_disconnect: bool = False,
2220
- stream_mode: StreamMode | asyncio.Queue | None = None,
2654
+ stream_mode: list[StreamMode] | StreamMode | None = None,
2221
2655
  last_event_id: str | None = None,
2222
2656
  ctx: Auth.types.BaseAuthContext | None = None,
2223
2657
  ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
2224
2658
  """Stream the run output."""
2225
2659
  from langgraph_api.asyncio import create_task
2660
+ from langgraph_api.serde import json_dumpb
2661
+ from langgraph_api.utils.stream_codec import decode_stream_message
2226
2662
 
2227
- queue = (
2228
- stream_mode
2229
- if isinstance(stream_mode, asyncio.Queue)
2230
- else await Runs.Stream.subscribe(run_id, stream_mode=stream_mode)
2231
- )
2232
-
2663
+ queue = stream_channel
2233
2664
  try:
2234
2665
  async with connect() as conn:
2235
- filters = await Runs.handle_event(
2236
- ctx,
2237
- "read",
2238
- Auth.types.ThreadsRead(thread_id=thread_id),
2239
- )
2240
- if filters:
2241
- thread = await Threads._get_with_filters(
2242
- cast(InMemConnectionProto, conn), thread_id, filters
2243
- )
2244
- if not thread:
2245
- raise WrappedHTTPException(
2246
- HTTPException(
2247
- status_code=404, detail="Thread not found"
2248
- )
2249
- )
2250
- channel_prefix = f"run:{run_id}:stream:"
2251
- len_prefix = len(channel_prefix.encode())
2666
+ try:
2667
+ await Runs.Stream.check_run_stream_auth(run_id, thread_id, ctx)
2668
+ except HTTPException as e:
2669
+ raise WrappedHTTPException(e) from None
2670
+ run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
2252
2671
 
2253
2672
  for message in get_stream_manager().restore_messages(
2254
- run_id, last_event_id
2673
+ run_id, thread_id, last_event_id
2255
2674
  ):
2256
- topic, data, id = message.topic, message.data, message.id
2257
- if topic.decode() == f"run:{run_id}:control":
2258
- if data == b"done":
2675
+ data, id = message.data, message.id
2676
+ decoded = decode_stream_message(data, channel=message.topic)
2677
+ mode = decoded.event_bytes.decode("utf-8")
2678
+ payload = decoded.message_bytes
2679
+
2680
+ if mode == "control":
2681
+ if payload == b"done":
2259
2682
  return
2260
- else:
2261
- yield topic[len_prefix:], data, id
2683
+ elif (
2684
+ not stream_mode
2685
+ or mode in stream_mode
2686
+ or (
2687
+ (
2688
+ "messages" in stream_mode
2689
+ or "messages-tuple" in stream_mode
2690
+ )
2691
+ and mode.startswith("messages")
2692
+ )
2693
+ ):
2694
+ yield mode.encode(), payload, id
2262
2695
  logger.debug(
2263
2696
  "Replayed run event",
2264
2697
  run_id=str(run_id),
2265
2698
  message_id=id,
2266
- stream_mode=topic[len_prefix:],
2699
+ stream_mode=mode,
2267
2700
  data=data,
2268
2701
  )
2269
2702
 
@@ -2271,20 +2704,32 @@ class Runs(Authenticated):
2271
2704
  try:
2272
2705
  # Wait for messages with a timeout
2273
2706
  message = await asyncio.wait_for(queue.get(), timeout=0.5)
2274
- topic, data, id = message.topic, message.data, message.id
2707
+ data, id = message.data, message.id
2708
+ decoded = decode_stream_message(data, channel=message.topic)
2709
+ mode = decoded.event_bytes.decode("utf-8")
2710
+ payload = decoded.message_bytes
2275
2711
 
2276
- if topic.decode() == f"run:{run_id}:control":
2277
- if data == b"done":
2712
+ if mode == "control":
2713
+ if payload == b"done":
2278
2714
  break
2279
- else:
2280
- # Extract mode from topic
2281
- yield topic[len_prefix:], data, id
2715
+ elif (
2716
+ not stream_mode
2717
+ or mode in stream_mode
2718
+ or (
2719
+ (
2720
+ "messages" in stream_mode
2721
+ or "messages-tuple" in stream_mode
2722
+ )
2723
+ and mode.startswith("messages")
2724
+ )
2725
+ ):
2726
+ yield mode.encode(), payload, id
2282
2727
  logger.debug(
2283
2728
  "Streamed run event",
2284
2729
  run_id=str(run_id),
2285
- stream_mode=topic[len_prefix:],
2730
+ stream_mode=mode,
2286
2731
  message_id=id,
2287
- data=data,
2732
+ data=payload,
2288
2733
  )
2289
2734
  except TimeoutError:
2290
2735
  # Check if the run is still pending
@@ -2298,8 +2743,10 @@ class Runs(Authenticated):
2298
2743
  elif run is None:
2299
2744
  yield (
2300
2745
  b"error",
2301
- HTTPException(
2302
- status_code=404, detail="Run not found"
2746
+ json_dumpb(
2747
+ HTTPException(
2748
+ status_code=404, detail="Run not found"
2749
+ )
2303
2750
  ),
2304
2751
  None,
2305
2752
  )
@@ -2314,45 +2761,68 @@ class Runs(Authenticated):
2314
2761
  raise
2315
2762
  finally:
2316
2763
  stream_manager = get_stream_manager()
2317
- await stream_manager.remove_queue(run_id, queue)
2764
+ await stream_manager.remove_queue(run_id, thread_id, queue)
2318
2765
 
2319
2766
  @staticmethod
2320
- async def publish(
2767
+ async def check_run_stream_auth(
2321
2768
  run_id: UUID,
2769
+ thread_id: UUID,
2770
+ ctx: Auth.types.BaseAuthContext | None = None,
2771
+ ) -> None:
2772
+ async with connect() as conn:
2773
+ filters = await Runs.handle_event(
2774
+ ctx,
2775
+ "read",
2776
+ Auth.types.ThreadsRead(thread_id=thread_id),
2777
+ )
2778
+ if filters:
2779
+ thread = await Threads._get_with_filters(
2780
+ cast(InMemConnectionProto, conn), thread_id, filters
2781
+ )
2782
+ if not thread:
2783
+ raise HTTPException(status_code=404, detail="Thread not found")
2784
+
2785
+ @staticmethod
2786
+ async def publish(
2787
+ run_id: UUID | str,
2322
2788
  event: str,
2323
2789
  message: bytes,
2324
2790
  *,
2791
+ thread_id: UUID | str | None = None,
2325
2792
  resumable: bool = False,
2326
2793
  ) -> None:
2327
2794
  """Publish a message to all subscribers of the run stream."""
2328
- topic = f"run:{run_id}:stream:{event}".encode()
2795
+ from langgraph_api.utils.stream_codec import STREAM_CODEC
2796
+
2797
+ topic = f"run:{run_id}:stream".encode()
2329
2798
 
2330
2799
  stream_manager = get_stream_manager()
2331
- # Send to all queues subscribed to this run_id
2800
+ # Send to all queues subscribed to this run_id using protocol frame
2801
+ payload = STREAM_CODEC.encode(event, message)
2332
2802
  await stream_manager.put(
2333
- run_id, Message(topic=topic, data=message), resumable
2803
+ run_id, thread_id, Message(topic=topic, data=payload), resumable
2334
2804
  )
2335
2805
 
2336
2806
 
2337
- async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: ValueEvent):
2807
+ async def listen_for_cancellation(
2808
+ queue: asyncio.Queue, run_id: UUID, thread_id: UUID | None, done: ValueEvent
2809
+ ):
2338
2810
  """Listen for cancellation messages and set the done event accordingly."""
2339
2811
  from langgraph_api.errors import UserInterrupt, UserRollback
2340
2812
 
2341
2813
  stream_manager = get_stream_manager()
2342
- control_key = f"run:{run_id}:control"
2343
2814
 
2344
- if existing_queue := stream_manager.control_queues.get(run_id):
2345
- for message in existing_queue:
2346
- payload = message.data
2347
- if payload == b"rollback":
2348
- done.set(UserRollback())
2349
- elif payload == b"interrupt":
2350
- done.set(UserInterrupt())
2815
+ if control_key := stream_manager.get_control_key(run_id, thread_id):
2816
+ payload = control_key.data
2817
+ if payload == b"rollback":
2818
+ done.set(UserRollback())
2819
+ elif payload == b"interrupt":
2820
+ done.set(UserInterrupt())
2351
2821
 
2352
2822
  while not done.is_set():
2353
2823
  try:
2354
2824
  # This task gets cancelled when Runs.enter exits anyway,
2355
- # so we can have a pretty length timeout here
2825
+ # so we can have a pretty lengthy timeout here
2356
2826
  message = await asyncio.wait_for(queue.get(), timeout=240)
2357
2827
  payload = message.data
2358
2828
  if payload == b"rollback":
@@ -2362,10 +2832,6 @@ async def listen_for_cancellation(queue: asyncio.Queue, run_id: UUID, done: Valu
2362
2832
  elif payload == b"done":
2363
2833
  done.set()
2364
2834
  break
2365
-
2366
- # Store control messages for late subscribers
2367
- if message.topic.decode() == control_key:
2368
- stream_manager.control_queues[run_id].append(message)
2369
2835
  except TimeoutError:
2370
2836
  break
2371
2837
 
@@ -2379,6 +2845,7 @@ class Crons:
2379
2845
  schedule: str,
2380
2846
  cron_id: UUID | None = None,
2381
2847
  thread_id: UUID | None = None,
2848
+ on_run_completed: Literal["delete", "keep"] | None = None,
2382
2849
  end_time: datetime | None = None,
2383
2850
  ctx: Auth.types.BaseAuthContext | None = None,
2384
2851
  ) -> AsyncIterator[Cron]:
@@ -2417,10 +2884,24 @@ class Crons:
2417
2884
  thread_id: UUID | None,
2418
2885
  limit: int,
2419
2886
  offset: int,
2887
+ select: list[CronSelectField] | None = None,
2420
2888
  ctx: Auth.types.BaseAuthContext | None = None,
2421
- ) -> AsyncIterator[Cron]:
2889
+ sort_by: str | None = None,
2890
+ sort_order: Literal["asc", "desc"] | None = None,
2891
+ ) -> tuple[AsyncIterator[Cron], int]:
2422
2892
  raise NotImplementedError
2423
2893
 
2894
+ @staticmethod
2895
+ async def count(
2896
+ conn: InMemConnectionProto,
2897
+ *,
2898
+ assistant_id: UUID | None = None,
2899
+ thread_id: UUID | None = None,
2900
+ ctx: Auth.types.BaseAuthContext | None = None,
2901
+ ) -> int:
2902
+ """Get count of crons."""
2903
+ raise NotImplementedError("The in-mem server does not implement Crons.")
2904
+
2424
2905
 
2425
2906
  async def cancel_run(
2426
2907
  thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
@@ -2478,11 +2959,18 @@ def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -
2478
2959
  if key not in metadata or metadata[key] != filter_value:
2479
2960
  return False
2480
2961
  elif op == "$contains":
2481
- if (
2482
- key not in metadata
2483
- or not isinstance(metadata[key], list)
2484
- or filter_value not in metadata[key]
2485
- ):
2962
+ if key not in metadata or not isinstance(metadata[key], list):
2963
+ return False
2964
+
2965
+ if isinstance(filter_value, list):
2966
+ # Mimick Postgres containment operator behavior.
2967
+ # It would be more efficient to use set operations here,
2968
+ # but we can't assume that elements are hashable.
2969
+ # The Postgres algorithm is also O(n^2).
2970
+ for filter_element in filter_value:
2971
+ if filter_element not in metadata[key]:
2972
+ return False
2973
+ elif filter_value not in metadata[key]:
2486
2974
  return False
2487
2975
  else:
2488
2976
  # Direct equality
@@ -2498,6 +2986,7 @@ async def _empty_generator():
2498
2986
 
2499
2987
 
2500
2988
  __all__ = [
2989
+ "StreamHandler",
2501
2990
  "Assistants",
2502
2991
  "Crons",
2503
2992
  "Runs",