langgraph-runtime-inmem 0.9.0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,7 +11,6 @@ import uuid
11
11
  from collections import defaultdict
12
12
  from collections.abc import AsyncIterator, Sequence
13
13
  from contextlib import asynccontextmanager
14
- from copy import deepcopy
15
14
  from datetime import UTC, datetime, timedelta
16
15
  from typing import Any, Literal, cast
17
16
  from uuid import UUID, uuid4
@@ -29,6 +28,7 @@ from langgraph_runtime_inmem.checkpoint import Checkpointer
29
28
  from langgraph_runtime_inmem.database import InMemConnectionProto, connect
30
29
  from langgraph_runtime_inmem.inmem_stream import (
31
30
  THREADLESS_KEY,
31
+ ContextQueue,
32
32
  Message,
33
33
  get_stream_manager,
34
34
  )
@@ -58,12 +58,13 @@ if typing.TYPE_CHECKING:
58
58
  Thread,
59
59
  ThreadSelectField,
60
60
  ThreadStatus,
61
+ ThreadStreamMode,
61
62
  ThreadUpdateResponse,
62
63
  )
63
64
  from langgraph_api.schema import Interrupt as InterruptSchema
64
- from langgraph_api.serde import Fragment
65
65
  from langgraph_api.utils import AsyncConnectionProto
66
66
 
67
+ StreamHandler = ContextQueue
67
68
 
68
69
  logger = structlog.stdlib.get_logger(__name__)
69
70
 
@@ -140,6 +141,7 @@ class Assistants(Authenticated):
140
141
  conn: InMemConnectionProto,
141
142
  *,
142
143
  graph_id: str | None,
144
+ name: str | None,
143
145
  metadata: MetadataInput,
144
146
  limit: int,
145
147
  offset: int,
@@ -163,6 +165,7 @@ class Assistants(Authenticated):
163
165
  assistant
164
166
  for assistant in assistants
165
167
  if (not graph_id or assistant["graph_id"] == graph_id)
168
+ and (not name or name.lower() in assistant["name"].lower())
166
169
  and (not metadata or is_jsonb_contained(assistant["metadata"], metadata))
167
170
  and (not filters or _check_filter_match(assistant["metadata"], filters))
168
171
  ]
@@ -228,7 +231,7 @@ class Assistants(Authenticated):
228
231
  if assistant["assistant_id"] == assistant_id and (
229
232
  not filters or _check_filter_match(assistant["metadata"], filters)
230
233
  ):
231
- yield assistant
234
+ yield copy.deepcopy(assistant)
232
235
 
233
236
  return _yield_result()
234
237
 
@@ -247,6 +250,8 @@ class Assistants(Authenticated):
247
250
  description: str | None = None,
248
251
  ) -> AsyncIterator[Assistant]:
249
252
  """Insert an assistant."""
253
+ from langgraph_api.graph import GRAPHS
254
+
250
255
  assistant_id = _ensure_uuid(assistant_id)
251
256
  metadata = metadata if metadata is not None else {}
252
257
  filters = await Assistants.handle_event(
@@ -268,6 +273,9 @@ class Assistants(Authenticated):
268
273
  detail="Cannot specify both configurable and context. Prefer setting context alone. Context was introduced in LangGraph 0.6.0 and is the long term planned replacement for configurable.",
269
274
  )
270
275
 
276
+ if graph_id not in GRAPHS:
277
+ raise HTTPException(status_code=404, detail=f"Graph {graph_id} not found")
278
+
271
279
  # Keep config and context up to date with one another
272
280
  if config.get("configurable"):
273
281
  context = config["configurable"]
@@ -555,6 +563,8 @@ class Assistants(Authenticated):
555
563
  "metadata": version_data["metadata"],
556
564
  "version": version_data["version"],
557
565
  "updated_at": datetime.now(UTC),
566
+ "name": version_data["name"],
567
+ "description": version_data["description"],
558
568
  }
559
569
  )
560
570
 
@@ -618,6 +628,7 @@ class Assistants(Authenticated):
618
628
  conn: InMemConnectionProto,
619
629
  *,
620
630
  graph_id: str | None = None,
631
+ name: str | None = None,
621
632
  metadata: MetadataInput = None,
622
633
  ctx: Auth.types.BaseAuthContext | None = None,
623
634
  ) -> int:
@@ -635,6 +646,7 @@ class Assistants(Authenticated):
635
646
  for assistant in conn.store["assistants"]:
636
647
  if (
637
648
  (not graph_id or assistant["graph_id"] == graph_id)
649
+ and (not name or name.lower() in assistant["name"].lower())
638
650
  and (
639
651
  not metadata or is_jsonb_contained(assistant["metadata"], metadata)
640
652
  )
@@ -738,6 +750,7 @@ class Threads(Authenticated):
738
750
  async def search(
739
751
  conn: InMemConnectionProto,
740
752
  *,
753
+ ids: list[str] | list[UUID] | None = None,
741
754
  metadata: MetadataInput,
742
755
  values: MetadataInput,
743
756
  status: ThreadStatus | None,
@@ -765,7 +778,19 @@ class Threads(Authenticated):
765
778
  )
766
779
 
767
780
  # Apply filters
781
+ id_set: set[UUID] | None = None
782
+ if ids:
783
+ id_set = set()
784
+ for i in ids:
785
+ try:
786
+ id_set.add(_ensure_uuid(i))
787
+ except Exception:
788
+ raise HTTPException(
789
+ status_code=400, detail="Invalid thread ID " + str(i)
790
+ ) from None
768
791
  for thread in threads:
792
+ if id_set is not None and thread.get("thread_id") not in id_set:
793
+ continue
769
794
  if filters and not _check_filter_match(thread["metadata"], filters):
770
795
  continue
771
796
 
@@ -943,6 +968,7 @@ class Threads(Authenticated):
943
968
  thread_id: UUID,
944
969
  *,
945
970
  metadata: MetadataValue,
971
+ ttl: ThreadTTLConfig | None = None,
946
972
  ctx: Auth.types.BaseAuthContext | None = None,
947
973
  ) -> AsyncIterator[Thread]:
948
974
  """Update a thread."""
@@ -1215,13 +1241,23 @@ class Threads(Authenticated):
1215
1241
  """Create a copy of an existing thread."""
1216
1242
  thread_id = _ensure_uuid(thread_id)
1217
1243
  new_thread_id = uuid4()
1218
- filters = await Threads.handle_event(
1244
+ read_filters = await Threads.handle_event(
1219
1245
  ctx,
1220
1246
  "read",
1221
1247
  Auth.types.ThreadsRead(
1248
+ thread_id=thread_id,
1249
+ ),
1250
+ )
1251
+ # Assert that the user has permissions to create a new thread.
1252
+ # (We don't actually need the filters.)
1253
+ await Threads.handle_event(
1254
+ ctx,
1255
+ "create",
1256
+ Auth.types.ThreadsCreate(
1222
1257
  thread_id=new_thread_id,
1223
1258
  ),
1224
1259
  )
1260
+
1225
1261
  async with conn.pipeline():
1226
1262
  # Find the original thread in our store
1227
1263
  original_thread = next(
@@ -1230,8 +1266,8 @@ class Threads(Authenticated):
1230
1266
 
1231
1267
  if not original_thread:
1232
1268
  return _empty_generator()
1233
- if filters and not _check_filter_match(
1234
- original_thread["metadata"], filters
1269
+ if read_filters and not _check_filter_match(
1270
+ original_thread["metadata"], read_filters
1235
1271
  ):
1236
1272
  return _empty_generator()
1237
1273
 
@@ -1240,7 +1276,7 @@ class Threads(Authenticated):
1240
1276
  "thread_id": new_thread_id,
1241
1277
  "created_at": datetime.now(tz=UTC),
1242
1278
  "updated_at": datetime.now(tz=UTC),
1243
- "metadata": deepcopy(original_thread["metadata"]),
1279
+ "metadata": copy.deepcopy(original_thread["metadata"]),
1244
1280
  "status": "idle",
1245
1281
  "config": {},
1246
1282
  }
@@ -1327,7 +1363,14 @@ class Threads(Authenticated):
1327
1363
  )
1328
1364
 
1329
1365
  metadata = thread.get("metadata", {})
1330
- thread_config = thread.get("config", {})
1366
+ thread_config = cast(dict[str, Any], thread.get("config", {}))
1367
+ thread_config = {
1368
+ **thread_config,
1369
+ "configurable": {
1370
+ **thread_config.get("configurable", {}),
1371
+ **config.get("configurable", {}),
1372
+ },
1373
+ }
1331
1374
 
1332
1375
  # Fallback to graph_id from run if not in thread metadata
1333
1376
  graph_id = metadata.get("graph_id")
@@ -1377,6 +1420,7 @@ class Threads(Authenticated):
1377
1420
  """Add state to a thread."""
1378
1421
  from langgraph_api.graph import get_graph
1379
1422
  from langgraph_api.schema import ThreadUpdateResponse
1423
+ from langgraph_api.state import state_snapshot_to_thread_state
1380
1424
  from langgraph_api.store import get_store
1381
1425
  from langgraph_api.utils import fetchone
1382
1426
 
@@ -1414,6 +1458,13 @@ class Threads(Authenticated):
1414
1458
  status_code=409,
1415
1459
  detail=f"Thread {thread_id} has in-flight runs: {pending_runs}",
1416
1460
  )
1461
+ thread_config = {
1462
+ **thread_config,
1463
+ "configurable": {
1464
+ **thread_config.get("configurable", {}),
1465
+ **config.get("configurable", {}),
1466
+ },
1467
+ }
1417
1468
 
1418
1469
  # Fallback to graph_id from run if not in thread metadata
1419
1470
  graph_id = metadata.get("graph_id")
@@ -1454,6 +1505,19 @@ class Threads(Authenticated):
1454
1505
  thread["values"] = state.values
1455
1506
  break
1456
1507
 
1508
+ # Publish state update event
1509
+ from langgraph_api.serde import json_dumpb
1510
+
1511
+ event_data = {
1512
+ "state": state_snapshot_to_thread_state(state),
1513
+ "thread_id": str(thread_id),
1514
+ }
1515
+ await Threads.Stream.publish(
1516
+ thread_id,
1517
+ "state_update",
1518
+ json_dumpb(event_data),
1519
+ )
1520
+
1457
1521
  return ThreadUpdateResponse(
1458
1522
  checkpoint=next_config["configurable"],
1459
1523
  # Including deprecated fields
@@ -1496,7 +1560,14 @@ class Threads(Authenticated):
1496
1560
  thread_iter, not_found_detail=f"Thread {thread_id} not found."
1497
1561
  )
1498
1562
 
1499
- thread_config = thread["config"]
1563
+ thread_config = cast(dict[str, Any], thread["config"])
1564
+ thread_config = {
1565
+ **thread_config,
1566
+ "configurable": {
1567
+ **thread_config.get("configurable", {}),
1568
+ **config.get("configurable", {}),
1569
+ },
1570
+ }
1500
1571
  metadata = thread["metadata"]
1501
1572
 
1502
1573
  if not thread:
@@ -1543,6 +1614,19 @@ class Threads(Authenticated):
1543
1614
  thread["values"] = state.values
1544
1615
  break
1545
1616
 
1617
+ # Publish state update event
1618
+ from langgraph_api.serde import json_dumpb
1619
+
1620
+ event_data = {
1621
+ "state": state,
1622
+ "thread_id": str(thread_id),
1623
+ }
1624
+ await Threads.Stream.publish(
1625
+ thread_id,
1626
+ "state_update",
1627
+ json_dumpb(event_data),
1628
+ )
1629
+
1546
1630
  return ThreadUpdateResponse(
1547
1631
  checkpoint=next_config["configurable"],
1548
1632
  )
@@ -1584,7 +1668,14 @@ class Threads(Authenticated):
1584
1668
  if not _check_filter_match(thread_metadata, filters):
1585
1669
  return []
1586
1670
 
1587
- thread_config = thread["config"]
1671
+ thread_config = cast(dict[str, Any], thread["config"])
1672
+ thread_config = {
1673
+ **thread_config,
1674
+ "configurable": {
1675
+ **thread_config.get("configurable", {}),
1676
+ **config.get("configurable", {}),
1677
+ },
1678
+ }
1588
1679
  # If graph_id exists, get state history
1589
1680
  if graph_id := thread_metadata.get("graph_id"):
1590
1681
  async with get_graph(
@@ -1613,7 +1704,9 @@ class Threads(Authenticated):
1613
1704
 
1614
1705
  return []
1615
1706
 
1616
- class Stream:
1707
+ class Stream(Authenticated):
1708
+ resource = "threads"
1709
+
1617
1710
  @staticmethod
1618
1711
  async def subscribe(
1619
1712
  conn: InMemConnectionProto | AsyncConnectionProto,
@@ -1626,6 +1719,13 @@ class Threads(Authenticated):
1626
1719
 
1627
1720
  # Create new queues only for runs not yet seen
1628
1721
  thread_id = _ensure_uuid(thread_id)
1722
+
1723
+ # Add thread stream queue
1724
+ if thread_id not in seen_runs:
1725
+ queue = await stream_manager.add_thread_stream(thread_id)
1726
+ queues.append((thread_id, queue))
1727
+ seen_runs.add(thread_id)
1728
+
1629
1729
  for run in conn.store["runs"]:
1630
1730
  if run["thread_id"] == thread_id:
1631
1731
  run_id = run["run_id"]
@@ -1641,9 +1741,32 @@ class Threads(Authenticated):
1641
1741
  thread_id: UUID,
1642
1742
  *,
1643
1743
  last_event_id: str | None = None,
1744
+ stream_modes: list[ThreadStreamMode],
1745
+ ctx: Auth.types.BaseAuthContext | None = None,
1644
1746
  ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
1645
1747
  """Stream the thread output."""
1646
- from langgraph_api.serde import json_loads
1748
+ await Threads.Stream.check_thread_stream_auth(thread_id, ctx)
1749
+
1750
+ from langgraph_api.utils.stream_codec import (
1751
+ decode_stream_message,
1752
+ )
1753
+
1754
+ def should_filter_event(event_name: str, message_bytes: bytes) -> bool:
1755
+ """Check if an event should be filtered out based on stream_modes."""
1756
+ if "run_modes" in stream_modes and event_name != "state_update":
1757
+ return False
1758
+ if "state_update" in stream_modes and event_name == "state_update":
1759
+ return False
1760
+ if "lifecycle" in stream_modes and event_name == "metadata":
1761
+ try:
1762
+ message_data = orjson.loads(message_bytes)
1763
+ if message_data.get("status") == "run_done":
1764
+ return False
1765
+ if "attempt" in message_data and "run_id" in message_data:
1766
+ return False
1767
+ except (orjson.JSONDecodeError, TypeError):
1768
+ pass
1769
+ return True
1647
1770
 
1648
1771
  stream_manager = get_stream_manager()
1649
1772
  seen_runs: set[UUID] = set()
@@ -1673,23 +1796,24 @@ class Threads(Authenticated):
1673
1796
 
1674
1797
  # Yield sorted events
1675
1798
  for message, run_id in all_events:
1676
- data = json_loads(message.data)
1677
- event_name = data["event"]
1678
- message_content = data["message"]
1679
-
1680
- if event_name == "control":
1681
- if message_content == b"done":
1682
- yield (
1683
- b"metadata",
1684
- orjson.dumps(
1685
- {"status": "run_done", "run_id": run_id}
1686
- ),
1687
- message.id,
1799
+ decoded = decode_stream_message(
1800
+ message.data, channel=message.topic
1801
+ )
1802
+ event_bytes = decoded.event_bytes
1803
+ message_bytes = decoded.message_bytes
1804
+
1805
+ if event_bytes == b"control":
1806
+ if message_bytes == b"done":
1807
+ event_bytes = b"metadata"
1808
+ message_bytes = orjson.dumps(
1809
+ {"status": "run_done", "run_id": run_id}
1688
1810
  )
1689
- else:
1811
+ if not should_filter_event(
1812
+ event_bytes.decode("utf-8"), message_bytes
1813
+ ):
1690
1814
  yield (
1691
- event_name.encode(),
1692
- base64.b64decode(message_content),
1815
+ event_bytes,
1816
+ message_bytes,
1693
1817
  message.id,
1694
1818
  )
1695
1819
 
@@ -1708,28 +1832,27 @@ class Threads(Authenticated):
1708
1832
  message = await asyncio.wait_for(
1709
1833
  queue.get(), timeout=0.2
1710
1834
  )
1711
- data = json_loads(message.data)
1712
- event_name = data["event"]
1713
- message_content = data["message"]
1714
-
1715
- if event_name == "control":
1716
- if message_content == b"done":
1717
- # Extract run_id from topic
1718
- topic = message.topic.decode()
1719
- run_id = topic.split("run:")[1].split(":")[0]
1720
- yield (
1721
- b"metadata",
1722
- orjson.dumps(
1723
- {"status": "run_done", "run_id": run_id}
1724
- ),
1725
- message.id,
1726
- )
1727
- else:
1728
- yield (
1729
- event_name.encode(),
1730
- base64.b64decode(message_content),
1731
- message.id,
1835
+ decoded = decode_stream_message(
1836
+ message.data, channel=message.topic
1837
+ )
1838
+ event = decoded.event_bytes
1839
+ event_name = event.decode("utf-8")
1840
+ payload = decoded.message_bytes
1841
+
1842
+ if event == b"control" and payload == b"done":
1843
+ topic = message.topic.decode()
1844
+ run_id = topic.split("run:")[1].split(":")[0]
1845
+ meta_event = b"metadata"
1846
+ meta_payload = orjson.dumps(
1847
+ {"status": "run_done", "run_id": run_id}
1732
1848
  )
1849
+ if not should_filter_event(
1850
+ "metadata", meta_payload
1851
+ ):
1852
+ yield (meta_event, meta_payload, message.id)
1853
+ else:
1854
+ if not should_filter_event(event_name, payload):
1855
+ yield (event, payload, message.id)
1733
1856
 
1734
1857
  except TimeoutError:
1735
1858
  continue
@@ -1758,6 +1881,41 @@ class Threads(Authenticated):
1758
1881
  # Ignore cleanup errors
1759
1882
  pass
1760
1883
 
1884
+ @staticmethod
1885
+ async def publish(
1886
+ thread_id: UUID | str,
1887
+ event: str,
1888
+ message: bytes,
1889
+ ) -> None:
1890
+ """Publish a thread-level event to the thread stream."""
1891
+ from langgraph_api.utils.stream_codec import STREAM_CODEC
1892
+
1893
+ topic = f"thread:{thread_id}:stream".encode()
1894
+
1895
+ stream_manager = get_stream_manager()
1896
+ payload = STREAM_CODEC.encode(event, message)
1897
+ await stream_manager.put_thread(
1898
+ str(thread_id), Message(topic=topic, data=payload)
1899
+ )
1900
+
1901
+ @staticmethod
1902
+ async def check_thread_stream_auth(
1903
+ thread_id: UUID,
1904
+ ctx: Auth.types.BaseAuthContext | None = None,
1905
+ ) -> None:
1906
+ async with connect() as conn:
1907
+ filters = await Threads.Stream.handle_event(
1908
+ ctx,
1909
+ "read",
1910
+ Auth.types.ThreadsRead(thread_id=thread_id),
1911
+ )
1912
+ if filters:
1913
+ thread = await Threads._get_with_filters(
1914
+ cast(InMemConnectionProto, conn), thread_id, filters
1915
+ )
1916
+ if not thread:
1917
+ raise HTTPException(status_code=404, detail="Thread not found")
1918
+
1761
1919
  @staticmethod
1762
1920
  async def count(
1763
1921
  conn: InMemConnectionProto,
@@ -1821,38 +1979,37 @@ class Runs(Authenticated):
1821
1979
  if not pending_runs and not running_runs:
1822
1980
  return {
1823
1981
  "n_pending": 0,
1824
- "max_age_secs": None,
1825
- "med_age_secs": None,
1982
+ "pending_runs_wait_time_max_secs": None,
1983
+ "pending_runs_wait_time_med_secs": None,
1826
1984
  "n_running": 0,
1827
1985
  }
1828
1986
 
1829
- # Get all creation timestamps
1830
- created_times = [run.get("created_at") for run in (pending_runs + running_runs)]
1831
- created_times = [
1832
- t for t in created_times if t is not None
1833
- ] # Filter out None values
1834
-
1835
- if not created_times:
1836
- return {
1837
- "n_pending": len(pending_runs),
1838
- "n_running": len(running_runs),
1839
- "max_age_secs": None,
1840
- "med_age_secs": None,
1841
- }
1842
-
1843
- # Find oldest (max age)
1844
- oldest_time = min(created_times) # Earliest timestamp = oldest run
1845
-
1846
- # Find median age
1847
- sorted_times = sorted(created_times)
1848
- median_idx = len(sorted_times) // 2
1849
- median_time = sorted_times[median_idx]
1987
+ now = datetime.now(UTC)
1988
+ pending_waits: list[float] = []
1989
+ for run in pending_runs:
1990
+ created_at = run.get("created_at")
1991
+ if not isinstance(created_at, datetime):
1992
+ continue
1993
+ if created_at.tzinfo is None:
1994
+ created_at = created_at.replace(tzinfo=UTC)
1995
+ pending_waits.append((now - created_at).total_seconds())
1996
+
1997
+ max_pending_wait = max(pending_waits) if pending_waits else None
1998
+ if pending_waits:
1999
+ sorted_waits = sorted(pending_waits)
2000
+ half = len(sorted_waits) // 2
2001
+ if len(sorted_waits) % 2 == 1:
2002
+ med_pending_wait = sorted_waits[half]
2003
+ else:
2004
+ med_pending_wait = (sorted_waits[half - 1] + sorted_waits[half]) / 2
2005
+ else:
2006
+ med_pending_wait = None
1850
2007
 
1851
2008
  return {
1852
2009
  "n_pending": len(pending_runs),
1853
2010
  "n_running": len(running_runs),
1854
- "max_age_secs": oldest_time,
1855
- "med_age_secs": median_time,
2011
+ "pending_runs_wait_time_max_secs": max_pending_wait,
2012
+ "pending_runs_wait_time_med_secs": med_pending_wait,
1856
2013
  }
1857
2014
 
1858
2015
  @staticmethod
@@ -1916,12 +2073,16 @@ class Runs(Authenticated):
1916
2073
  @asynccontextmanager
1917
2074
  @staticmethod
1918
2075
  async def enter(
1919
- run_id: UUID, thread_id: UUID | None, loop: asyncio.AbstractEventLoop
2076
+ run_id: UUID,
2077
+ thread_id: UUID | None,
2078
+ loop: asyncio.AbstractEventLoop,
2079
+ resumable: bool,
1920
2080
  ) -> AsyncIterator[ValueEvent]:
1921
2081
  """Enter a run, listen for cancellation while running, signal when done."
1922
2082
  This method should be called as a context manager by a worker executing a run.
1923
2083
  """
1924
2084
  from langgraph_api.asyncio import SimpleTaskGroup, ValueEvent
2085
+ from langgraph_api.utils.stream_codec import STREAM_CODEC
1925
2086
 
1926
2087
  stream_manager = get_stream_manager()
1927
2088
  # Get control queue for this run (normal queue is created during run creation)
@@ -1941,12 +2102,14 @@ class Runs(Authenticated):
1941
2102
  )
1942
2103
  await stream_manager.put(run_id, thread_id, control_message)
1943
2104
 
1944
- # Signal done to all subscribers
2105
+ # Signal done to all subscribers using stream codec
1945
2106
  stream_message = Message(
1946
2107
  topic=f"run:{run_id}:stream".encode(),
1947
- data={"event": "control", "message": b"done"},
2108
+ data=STREAM_CODEC.encode("control", b"done"),
2109
+ )
2110
+ await stream_manager.put(
2111
+ run_id, thread_id, stream_message, resumable=resumable
1948
2112
  )
1949
- await stream_manager.put(run_id, thread_id, stream_message)
1950
2113
 
1951
2114
  # Remove the control_queue (normal queue is cleaned up during run deletion)
1952
2115
  await stream_manager.remove_control_queue(run_id, thread_id, control_queue)
@@ -1988,7 +2151,6 @@ class Runs(Authenticated):
1988
2151
  ctx: Auth.types.BaseAuthContext | None = None,
1989
2152
  ) -> AsyncIterator[Run]:
1990
2153
  """Create a run."""
1991
- from langgraph_api.config import FF_RICH_THREADS
1992
2154
  from langgraph_api.schema import Run, Thread
1993
2155
 
1994
2156
  assistant_id = _ensure_uuid(assistant_id)
@@ -2004,6 +2166,7 @@ class Runs(Authenticated):
2004
2166
  run_id = _ensure_uuid(run_id) if run_id else None
2005
2167
  metadata = metadata if metadata is not None else {}
2006
2168
  config = kwargs.get("config", {})
2169
+ temporary = kwargs.get("temporary", False)
2007
2170
 
2008
2171
  # Handle thread creation/update
2009
2172
  existing_thread = next(
@@ -2013,7 +2176,7 @@ class Runs(Authenticated):
2013
2176
  ctx,
2014
2177
  "create_run",
2015
2178
  Auth.types.RunsCreate(
2016
- thread_id=thread_id,
2179
+ thread_id=None if temporary else thread_id,
2017
2180
  assistant_id=assistant_id,
2018
2181
  run_id=run_id,
2019
2182
  status=status,
@@ -2034,49 +2197,35 @@ class Runs(Authenticated):
2034
2197
  # Create new thread
2035
2198
  if thread_id is None:
2036
2199
  thread_id = uuid4()
2037
- if FF_RICH_THREADS:
2038
- thread = Thread(
2039
- thread_id=thread_id,
2040
- status="busy",
2041
- metadata={
2042
- "graph_id": assistant["graph_id"],
2043
- "assistant_id": str(assistant_id),
2044
- **(config.get("metadata") or {}),
2045
- **metadata,
2046
- },
2047
- config=Runs._merge_jsonb(
2048
- assistant["config"],
2049
- config,
2050
- {
2051
- "configurable": Runs._merge_jsonb(
2052
- Runs._get_configurable(assistant["config"]),
2053
- )
2054
- },
2055
- ),
2056
- created_at=datetime.now(UTC),
2057
- updated_at=datetime.now(UTC),
2058
- values=b"",
2059
- )
2060
- else:
2061
- thread = Thread(
2062
- thread_id=thread_id,
2063
- status="idle",
2064
- metadata={
2065
- "graph_id": assistant["graph_id"],
2066
- "assistant_id": str(assistant_id),
2067
- **(config.get("metadata") or {}),
2068
- **metadata,
2200
+
2201
+ thread = Thread(
2202
+ thread_id=thread_id,
2203
+ status="busy",
2204
+ metadata={
2205
+ "graph_id": assistant["graph_id"],
2206
+ "assistant_id": str(assistant_id),
2207
+ **(config.get("metadata") or {}),
2208
+ **metadata,
2209
+ },
2210
+ config=Runs._merge_jsonb(
2211
+ assistant["config"],
2212
+ config,
2213
+ {
2214
+ "configurable": Runs._merge_jsonb(
2215
+ Runs._get_configurable(assistant["config"]),
2216
+ )
2069
2217
  },
2070
- config={},
2071
- created_at=datetime.now(UTC),
2072
- updated_at=datetime.now(UTC),
2073
- values=b"",
2074
- )
2218
+ ),
2219
+ created_at=datetime.now(UTC),
2220
+ updated_at=datetime.now(UTC),
2221
+ values=b"",
2222
+ )
2223
+
2075
2224
  await logger.ainfo("Creating thread", thread_id=thread_id)
2076
2225
  conn.store["threads"].append(thread)
2077
2226
  elif existing_thread:
2078
2227
  # Update existing thread
2079
- if FF_RICH_THREADS and existing_thread["status"] != "busy":
2228
+ if existing_thread["status"] != "busy":
2080
2229
  existing_thread["status"] = "busy"
2081
2230
  existing_thread["metadata"] = Runs._merge_jsonb(
2082
2231
  existing_thread["metadata"],
@@ -2253,66 +2402,6 @@ class Runs(Authenticated):
2253
2402
 
2254
2403
  return _yield_deleted()
2255
2404
 
2256
- @staticmethod
2257
- async def join(
2258
- run_id: UUID,
2259
- *,
2260
- thread_id: UUID,
2261
- ctx: Auth.types.BaseAuthContext | None = None,
2262
- ) -> Fragment:
2263
- """Wait for a run to complete. If already done, return immediately.
2264
-
2265
- Returns:
2266
- the final state of the run.
2267
- """
2268
- from langgraph_api.serde import Fragment
2269
- from langgraph_api.utils import fetchone
2270
-
2271
- async with connect() as conn:
2272
- # Validate ownership
2273
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
2274
- await fetchone(thread_iter)
2275
- last_chunk: bytes | None = None
2276
- # wait for the run to complete
2277
- # Rely on this join's auth
2278
- async for mode, chunk, _ in Runs.Stream.join(
2279
- run_id,
2280
- thread_id=thread_id,
2281
- ctx=ctx,
2282
- ignore_404=True,
2283
- stream_mode=["values", "updates", "error"],
2284
- ):
2285
- if mode == b"values":
2286
- last_chunk = chunk
2287
- elif mode == b"updates" and b"__interrupt__" in chunk:
2288
- last_chunk = chunk
2289
- elif mode == b"error":
2290
- last_chunk = orjson.dumps({"__error__": orjson.Fragment(chunk)})
2291
- # if we received a final chunk, return it
2292
- if last_chunk is not None:
2293
- # ie. if the run completed while we were waiting for it
2294
- return Fragment(last_chunk)
2295
- else:
2296
- # otherwise, the run had already finished, so fetch the state from thread
2297
- async with connect() as conn:
2298
- thread_iter = await Threads.get(conn, thread_id, ctx=ctx)
2299
- thread = await fetchone(thread_iter)
2300
- if thread["status"] == "error":
2301
- return Fragment(
2302
- orjson.dumps({"__error__": orjson.Fragment(thread["error"])})
2303
- )
2304
- if thread["status"] == "interrupted":
2305
- # Get an interrupt for the thread. There is the case where there are multiple interrupts for the same run and we may not show the same
2306
- # interrupt, but we'll always show one. Long term we should show all of them.
2307
- try:
2308
- interrupt_map = thread["interrupts"]
2309
- interrupt = [next(iter(interrupt_map.values()))[0]]
2310
- return Fragment(orjson.dumps({"__interrupt__": interrupt}))
2311
- except Exception:
2312
- # No interrupt, but status is interrupted from a before/after block. Default back to values.
2313
- pass
2314
- return thread["values"]
2315
-
2316
2405
  @staticmethod
2317
2406
  async def cancel(
2318
2407
  conn: InMemConnectionProto | AsyncConnectionProto,
@@ -2538,7 +2627,7 @@ class Runs(Authenticated):
2538
2627
  async def subscribe(
2539
2628
  run_id: UUID,
2540
2629
  thread_id: UUID | None = None,
2541
- ) -> asyncio.Queue:
2630
+ ) -> ContextQueue:
2542
2631
  """Subscribe to the run stream, returning a queue."""
2543
2632
  stream_manager = get_stream_manager()
2544
2633
  queue = await stream_manager.add_queue(_ensure_uuid(run_id), thread_id)
@@ -2562,54 +2651,38 @@ class Runs(Authenticated):
2562
2651
  async def join(
2563
2652
  run_id: UUID,
2564
2653
  *,
2654
+ stream_channel: asyncio.Queue,
2565
2655
  thread_id: UUID,
2566
2656
  ignore_404: bool = False,
2567
2657
  cancel_on_disconnect: bool = False,
2568
- stream_channel: asyncio.Queue | None = None,
2569
2658
  stream_mode: list[StreamMode] | StreamMode | None = None,
2570
2659
  last_event_id: str | None = None,
2571
2660
  ctx: Auth.types.BaseAuthContext | None = None,
2572
2661
  ) -> AsyncIterator[tuple[bytes, bytes, bytes | None]]:
2573
2662
  """Stream the run output."""
2574
2663
  from langgraph_api.asyncio import create_task
2575
- from langgraph_api.serde import json_loads
2576
-
2577
- queue = (
2578
- stream_channel
2579
- if stream_channel
2580
- else await Runs.Stream.subscribe(run_id, thread_id)
2581
- )
2664
+ from langgraph_api.serde import json_dumpb
2665
+ from langgraph_api.utils.stream_codec import decode_stream_message
2582
2666
 
2667
+ queue = stream_channel
2583
2668
  try:
2584
2669
  async with connect() as conn:
2585
- filters = await Runs.handle_event(
2586
- ctx,
2587
- "read",
2588
- Auth.types.ThreadsRead(thread_id=thread_id),
2589
- )
2590
- if filters:
2591
- thread = await Threads._get_with_filters(
2592
- cast(InMemConnectionProto, conn), thread_id, filters
2593
- )
2594
- if not thread:
2595
- raise WrappedHTTPException(
2596
- HTTPException(
2597
- status_code=404, detail="Thread not found"
2598
- )
2599
- )
2670
+ try:
2671
+ await Runs.Stream.check_run_stream_auth(run_id, thread_id, ctx)
2672
+ except HTTPException as e:
2673
+ raise WrappedHTTPException(e) from None
2600
2674
  run = await Runs.get(conn, run_id, thread_id=thread_id, ctx=ctx)
2601
2675
 
2602
2676
  for message in get_stream_manager().restore_messages(
2603
2677
  run_id, thread_id, last_event_id
2604
2678
  ):
2605
2679
  data, id = message.data, message.id
2606
-
2607
- data = json_loads(data)
2608
- mode = data["event"]
2609
- message = data["message"]
2680
+ decoded = decode_stream_message(data, channel=message.topic)
2681
+ mode = decoded.event_bytes.decode("utf-8")
2682
+ payload = decoded.message_bytes
2610
2683
 
2611
2684
  if mode == "control":
2612
- if message == b"done":
2685
+ if payload == b"done":
2613
2686
  return
2614
2687
  elif (
2615
2688
  not stream_mode
@@ -2622,7 +2695,7 @@ class Runs(Authenticated):
2622
2695
  and mode.startswith("messages")
2623
2696
  )
2624
2697
  ):
2625
- yield mode.encode(), base64.b64decode(message), id
2698
+ yield mode.encode(), payload, id
2626
2699
  logger.debug(
2627
2700
  "Replayed run event",
2628
2701
  run_id=str(run_id),
@@ -2636,13 +2709,12 @@ class Runs(Authenticated):
2636
2709
  # Wait for messages with a timeout
2637
2710
  message = await asyncio.wait_for(queue.get(), timeout=0.5)
2638
2711
  data, id = message.data, message.id
2639
-
2640
- data = json_loads(data)
2641
- mode = data["event"]
2642
- message = data["message"]
2712
+ decoded = decode_stream_message(data, channel=message.topic)
2713
+ mode = decoded.event_bytes.decode("utf-8")
2714
+ payload = decoded.message_bytes
2643
2715
 
2644
2716
  if mode == "control":
2645
- if message == b"done":
2717
+ if payload == b"done":
2646
2718
  break
2647
2719
  elif (
2648
2720
  not stream_mode
@@ -2655,13 +2727,13 @@ class Runs(Authenticated):
2655
2727
  and mode.startswith("messages")
2656
2728
  )
2657
2729
  ):
2658
- yield mode.encode(), base64.b64decode(message), id
2730
+ yield mode.encode(), payload, id
2659
2731
  logger.debug(
2660
2732
  "Streamed run event",
2661
2733
  run_id=str(run_id),
2662
2734
  stream_mode=mode,
2663
2735
  message_id=id,
2664
- data=message,
2736
+ data=payload,
2665
2737
  )
2666
2738
  except TimeoutError:
2667
2739
  # Check if the run is still pending
@@ -2675,8 +2747,10 @@ class Runs(Authenticated):
2675
2747
  elif run is None:
2676
2748
  yield (
2677
2749
  b"error",
2678
- HTTPException(
2679
- status_code=404, detail="Run not found"
2750
+ json_dumpb(
2751
+ HTTPException(
2752
+ status_code=404, detail="Run not found"
2753
+ )
2680
2754
  ),
2681
2755
  None,
2682
2756
  )
@@ -2693,6 +2767,25 @@ class Runs(Authenticated):
2693
2767
  stream_manager = get_stream_manager()
2694
2768
  await stream_manager.remove_queue(run_id, thread_id, queue)
2695
2769
 
2770
+ @staticmethod
2771
+ async def check_run_stream_auth(
2772
+ run_id: UUID,
2773
+ thread_id: UUID,
2774
+ ctx: Auth.types.BaseAuthContext | None = None,
2775
+ ) -> None:
2776
+ async with connect() as conn:
2777
+ filters = await Runs.handle_event(
2778
+ ctx,
2779
+ "read",
2780
+ Auth.types.ThreadsRead(thread_id=thread_id),
2781
+ )
2782
+ if filters:
2783
+ thread = await Threads._get_with_filters(
2784
+ cast(InMemConnectionProto, conn), thread_id, filters
2785
+ )
2786
+ if not thread:
2787
+ raise HTTPException(status_code=404, detail="Thread not found")
2788
+
2696
2789
  @staticmethod
2697
2790
  async def publish(
2698
2791
  run_id: UUID | str,
@@ -2703,18 +2796,13 @@ class Runs(Authenticated):
2703
2796
  resumable: bool = False,
2704
2797
  ) -> None:
2705
2798
  """Publish a message to all subscribers of the run stream."""
2706
- from langgraph_api.serde import json_dumpb
2799
+ from langgraph_api.utils.stream_codec import STREAM_CODEC
2707
2800
 
2708
2801
  topic = f"run:{run_id}:stream".encode()
2709
2802
 
2710
2803
  stream_manager = get_stream_manager()
2711
- # Send to all queues subscribed to this run_id
2712
- payload = json_dumpb(
2713
- {
2714
- "event": event,
2715
- "message": message,
2716
- }
2717
- )
2804
+ # Send to all queues subscribed to this run_id using protocol frame
2805
+ payload = STREAM_CODEC.encode(event, message)
2718
2806
  await stream_manager.put(
2719
2807
  run_id, thread_id, Message(topic=topic, data=payload), resumable
2720
2808
  )
@@ -2761,7 +2849,9 @@ class Crons:
2761
2849
  schedule: str,
2762
2850
  cron_id: UUID | None = None,
2763
2851
  thread_id: UUID | None = None,
2852
+ on_run_completed: Literal["delete", "keep"] | None = None,
2764
2853
  end_time: datetime | None = None,
2854
+ metadata: dict | None = None,
2765
2855
  ctx: Auth.types.BaseAuthContext | None = None,
2766
2856
  ) -> AsyncIterator[Cron]:
2767
2857
  raise NotImplementedError
@@ -2852,19 +2942,154 @@ def _delete_checkpoints_for_thread(
2852
2942
  )
2853
2943
 
2854
2944
 
2855
- def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -> bool:
2945
def _validate_filter_structure(
    filters: Auth.types.FilterType | None,
    nesting_level: int = 0,
) -> None:
    """Check that a filter is structurally sound, without matching any data.

    Only the ``$or`` / ``$and`` operators are inspected. Each must map to a
    list holding at least two groups; every group is validated recursively
    one level deeper. Keys that sit next to an operator (the implicit AND
    part) are also validated one level deeper. When both operators appear in
    the same mapping, ``$or`` is processed first and each operator is then
    also reached through the other's sibling keys.

    Args:
        filters: The filter conditions to validate. Falsy values are accepted
            as-is.
        nesting_level: Current depth of nested operators. Anything above 2 is
            rejected.

    Raises:
        HTTPException: 500 when the nesting depth exceeds 2, or when an
            operator value is not a list of at least two entries.
    """
    if nesting_level > 2:
        raise HTTPException(
            status_code=500,
            detail="Your auth handler returned a filter with too many nested operators. The maximum depth for nested operators is 2. Please simplify your filter.",
        )

    if not filters:
        return

    # Operators are handled in this fixed order; the message literals are kept
    # whole so the emitted text is exactly what callers already see.
    operator_errors = (
        (
            "$or",
            "Your auth handler returned a filter with an invalid $or operator. The $or operator must be a list of at least 2 filter objects. Check the filter returned by your auth handler.",
        ),
        (
            "$and",
            "Your auth handler returned a filter with an invalid $and operator. The $and operator must be a list of at least 2 filter objects. Check the filter returned by your auth handler.",
        ),
    )
    child_level = nesting_level + 1

    for operator, invalid_detail in operator_errors:
        if operator not in filters:
            continue

        groups = filters[operator]
        if not (isinstance(groups, list) and len(groups) >= 2):
            raise HTTPException(status_code=500, detail=invalid_detail)

        # Every alternative under the operator must itself be well-formed.
        for sub_filter in groups:
            _validate_filter_structure(sub_filter, nesting_level=child_level)

        # Sibling keys are implicitly AND-ed with the operator; validate them
        # as their own filter, one level down.
        siblings = {
            key: condition
            for key, condition in filters.items()
            if key != operator
        }
        if siblings:
            _validate_filter_structure(siblings, nesting_level=child_level)
3006
+
3007
+
3008
+ def _check_filter_match(
3009
+ metadata: dict,
3010
+ filters: Auth.types.FilterType | None,
3011
+ nesting_level: int = 0,
3012
+ ) -> bool:
2856
3013
  """Check if metadata matches the filter conditions.
2857
3014
 
2858
3015
  Args:
2859
3016
  metadata: The metadata to check
2860
3017
  filters: The filter conditions to apply
3018
+ nesting_level: Current depth of nested operators (max 2)
2861
3019
 
2862
3020
  Returns:
2863
3021
  True if the metadata matches all filter conditions, False otherwise
2864
3022
  """
3023
+ if nesting_level > 2:
3024
+ raise HTTPException(
3025
+ status_code=500,
3026
+ detail="Your auth handler returned a filter with too many nested operators. The maximum depth for nested operators is 2. Please simplify your filter.",
3027
+ )
3028
+
2865
3029
  if not filters:
2866
3030
  return True
2867
3031
 
3032
+ # Handle $or operator
3033
+ if "$or" in filters:
3034
+ or_groups = filters["$or"]
3035
+ if not isinstance(or_groups, list) or not len(or_groups) >= 2:
3036
+ raise HTTPException(
3037
+ status_code=500,
3038
+ detail="Your auth handler returned a filter with an invalid $or operator. The $or operator must be a list of at least 2 filter objects. Check the filter returned by your auth handler.",
3039
+ )
3040
+
3041
+ # Validate all groups first to ensure nesting limits are respected
3042
+ # (even if we short-circuit during matching)
3043
+ for group in or_groups:
3044
+ _validate_filter_structure(group, nesting_level=nesting_level + 1)
3045
+
3046
+ # At least one group must match
3047
+ or_match = False
3048
+ for group in or_groups:
3049
+ if _check_filter_match(metadata, group, nesting_level=nesting_level + 1):
3050
+ or_match = True
3051
+ break
3052
+
3053
+ if not or_match:
3054
+ return False
3055
+
3056
+ # Check remaining filters (implicit AND with the $or)
3057
+ remaining_filters = {k: v for k, v in filters.items() if k != "$or"}
3058
+ if remaining_filters:
3059
+ return _check_filter_match(
3060
+ metadata, remaining_filters, nesting_level=nesting_level + 1
3061
+ )
3062
+ return True
3063
+
3064
+ # Handle $and operator
3065
+ if "$and" in filters:
3066
+ and_groups = filters["$and"]
3067
+ if not isinstance(and_groups, list) or not len(and_groups) >= 2:
3068
+ raise HTTPException(
3069
+ status_code=500,
3070
+ detail="Your auth handler returned a filter with an invalid $and operator. The $and operator must be a list of at least 2 filter objects. Check the filter returned by your auth handler.",
3071
+ )
3072
+
3073
+ # Validate all groups first to ensure nesting limits are respected
3074
+ for group in and_groups:
3075
+ _validate_filter_structure(group, nesting_level=nesting_level + 1)
3076
+
3077
+ # All groups must match
3078
+ for group in and_groups:
3079
+ if not _check_filter_match(
3080
+ metadata, group, nesting_level=nesting_level + 1
3081
+ ):
3082
+ return False
3083
+
3084
+ # Check remaining filters (implicit AND with the $and)
3085
+ remaining_filters = {k: v for k, v in filters.items() if k != "$and"}
3086
+ if remaining_filters:
3087
+ return _check_filter_match(
3088
+ metadata, remaining_filters, nesting_level=nesting_level + 1
3089
+ )
3090
+ return True
3091
+
3092
+ # Regular filter logic (implicit AND)
2868
3093
  for key, value in filters.items():
2869
3094
  if isinstance(value, dict):
2870
3095
  op = next(iter(value))
@@ -2874,11 +3099,18 @@ def _check_filter_match(metadata: dict, filters: Auth.types.FilterType | None) -
2874
3099
  if key not in metadata or metadata[key] != filter_value:
2875
3100
  return False
2876
3101
  elif op == "$contains":
2877
- if (
2878
- key not in metadata
2879
- or not isinstance(metadata[key], list)
2880
- or filter_value not in metadata[key]
2881
- ):
3102
+ if key not in metadata or not isinstance(metadata[key], list):
3103
+ return False
3104
+
3105
+ if isinstance(filter_value, list):
3106
+ # Mimic Postgres containment operator behavior.
3107
+ # It would be more efficient to use set operations here,
3108
+ # but we can't assume that elements are hashable.
3109
+ # The Postgres algorithm is also O(n^2).
3110
+ for filter_element in filter_value:
3111
+ if filter_element not in metadata[key]:
3112
+ return False
3113
+ elif filter_value not in metadata[key]:
2882
3114
  return False
2883
3115
  else:
2884
3116
  # Direct equality
@@ -2894,6 +3126,7 @@ async def _empty_generator():
2894
3126
 
2895
3127
 
2896
3128
  __all__ = [
3129
+ "StreamHandler",
2897
3130
  "Assistants",
2898
3131
  "Crons",
2899
3132
  "Runs",