langgraph-runtime-inmem 0.6.12__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
9
9
  store,
10
10
  )
11
11
 
12
- __version__ = "0.6.12"
12
+ __version__ = "0.8.0"
13
13
  __all__ = [
14
14
  "ops",
15
15
  "database",
@@ -34,10 +34,12 @@ if typing.TYPE_CHECKING:
34
34
  from langgraph_api.config import ThreadTTLConfig
35
35
  from langgraph_api.schema import (
36
36
  Assistant,
37
+ AssistantSelectField,
37
38
  Checkpoint,
38
39
  Config,
39
40
  Context,
40
41
  Cron,
42
+ CronSelectField,
41
43
  DeprecatedInterrupt,
42
44
  IfNotExists,
43
45
  MetadataInput,
@@ -46,14 +48,17 @@ if typing.TYPE_CHECKING:
46
48
  OnConflictBehavior,
47
49
  QueueStats,
48
50
  Run,
51
+ RunSelectField,
49
52
  RunStatus,
50
53
  StreamMode,
51
54
  Thread,
55
+ ThreadSelectField,
52
56
  ThreadStatus,
53
57
  ThreadUpdateResponse,
54
58
  )
55
59
  from langgraph_api.schema import Interrupt as InterruptSchema
56
60
  from langgraph_api.serde import Fragment
61
+ from langgraph_api.utils import AsyncConnectionProto
57
62
 
58
63
 
59
64
  logger = structlog.stdlib.get_logger(__name__)
@@ -136,6 +141,7 @@ class Assistants(Authenticated):
136
141
  offset: int,
137
142
  sort_by: str | None = None,
138
143
  sort_order: str | None = None,
144
+ select: list[AssistantSelectField] | None = None,
139
145
  ctx: Auth.types.BaseAuthContext | None = None,
140
146
  ) -> tuple[AsyncIterator[Assistant], int]:
141
147
  metadata = metadata if metadata is not None else {}
@@ -178,6 +184,7 @@ class Assistants(Authenticated):
178
184
  else:
179
185
  filtered_assistants.sort(key=lambda x: x.get(sort_by), reverse=reverse)
180
186
  else:
187
+ sort_by = "created_at"
181
188
  # Default sorting by created_at in descending order
182
189
  filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
183
190
 
@@ -187,14 +194,21 @@ class Assistants(Authenticated):
187
194
 
188
195
  async def assistant_iterator() -> AsyncIterator[Assistant]:
189
196
  for assistant in paginated_assistants:
190
- yield assistant
197
+ if select:
198
+ # Filter to only selected fields
199
+ filtered_assistant = {
200
+ k: v for k, v in assistant.items() if k in select
201
+ }
202
+ yield filtered_assistant
203
+ else:
204
+ yield assistant
191
205
 
192
206
  return assistant_iterator(), cur
193
207
 
194
208
  @staticmethod
195
209
  async def get(
196
210
  conn: InMemConnectionProto,
197
- assistant_id: UUID,
211
+ assistant_id: UUID | str,
198
212
  ctx: Auth.types.BaseAuthContext | None = None,
199
213
  ) -> AsyncIterator[Assistant]:
200
214
  """Get an assistant by ID."""
@@ -217,7 +231,7 @@ class Assistants(Authenticated):
217
231
  @staticmethod
218
232
  async def put(
219
233
  conn: InMemConnectionProto,
220
- assistant_id: UUID,
234
+ assistant_id: UUID | str,
221
235
  *,
222
236
  graph_id: str,
223
237
  config: Config,
@@ -597,6 +611,37 @@ class Assistants(Authenticated):
597
611
 
598
612
  return _yield_versions()
599
613
 
614
+ @staticmethod
615
+ async def count(
616
+ conn: InMemConnectionProto,
617
+ *,
618
+ graph_id: str | None = None,
619
+ metadata: MetadataInput = None,
620
+ ctx: Auth.types.BaseAuthContext | None = None,
621
+ ) -> int:
622
+ """Get count of assistants."""
623
+ metadata = metadata if metadata is not None else {}
624
+ filters = await Assistants.handle_event(
625
+ ctx,
626
+ "search",
627
+ Auth.types.AssistantsSearch(
628
+ graph_id=graph_id, metadata=metadata, limit=0, offset=0
629
+ ),
630
+ )
631
+
632
+ count = 0
633
+ for assistant in conn.store["assistants"]:
634
+ if (
635
+ (not graph_id or assistant["graph_id"] == graph_id)
636
+ and (
637
+ not metadata or is_jsonb_contained(assistant["metadata"], metadata)
638
+ )
639
+ and (not filters or _check_filter_match(assistant["metadata"], filters))
640
+ ):
641
+ count += 1
642
+
643
+ return count
644
+
600
645
 
601
646
  def is_jsonb_contained(superset: dict[str, Any], subset: dict[str, Any]) -> bool:
602
647
  """
@@ -680,7 +725,7 @@ def _patch_interrupt(
680
725
  "value": interrupt.value,
681
726
  "resumable": interrupt.resumable,
682
727
  "ns": interrupt.ns,
683
- "when": interrupt.when,
728
+ "when": interrupt.when, # type: ignore[unresolved-attribute]
684
729
  }
685
730
 
686
731
 
@@ -698,6 +743,7 @@ class Threads(Authenticated):
698
743
  offset: int,
699
744
  sort_by: str | None = None,
700
745
  sort_order: str | None = None,
746
+ select: list[ThreadSelectField] | None = None,
701
747
  ctx: Auth.types.BaseAuthContext | None = None,
702
748
  ) -> tuple[AsyncIterator[Thread], int]:
703
749
  threads = conn.store["threads"]
@@ -747,6 +793,7 @@ class Threads(Authenticated):
747
793
  filtered_threads, key=lambda x: x.get(sort_by), reverse=reverse
748
794
  )
749
795
  else:
796
+ sort_by = "created_at"
750
797
  # Default sorting by created_at in descending order
751
798
  sorted_threads = sorted(
752
799
  filtered_threads, key=lambda x: x["updated_at"], reverse=True
@@ -758,7 +805,12 @@ class Threads(Authenticated):
758
805
 
759
806
  async def thread_iterator() -> AsyncIterator[Thread]:
760
807
  for thread in paginated_threads:
761
- yield thread
808
+ if select:
809
+ # Filter to only selected fields
810
+ filtered_thread = {k: v for k, v in thread.items() if k in select}
811
+ yield filtered_thread
812
+ else:
813
+ yield thread
762
814
 
763
815
  return thread_iterator(), cursor
764
816
 
@@ -822,7 +874,7 @@ class Threads(Authenticated):
822
874
  @staticmethod
823
875
  async def put(
824
876
  conn: InMemConnectionProto,
825
- thread_id: UUID,
877
+ thread_id: UUID | str,
826
878
  *,
827
879
  metadata: MetadataInput,
828
880
  if_exists: OnConflictBehavior,
@@ -1001,6 +1053,7 @@ class Threads(Authenticated):
1001
1053
  thread_id: UUID,
1002
1054
  run_id: UUID,
1003
1055
  run_status: RunStatus | Literal["rollback"],
1056
+ graph_id: str,
1004
1057
  checkpoint: CheckpointPayload | None = None,
1005
1058
  exception: BaseException | None = None,
1006
1059
  ) -> None:
@@ -1080,6 +1133,7 @@ class Threads(Authenticated):
1080
1133
  final_thread_status = "busy"
1081
1134
  else:
1082
1135
  final_thread_status = base_thread_status
1136
+ thread["metadata"]["graph_id"] = graph_id
1083
1137
  thread.update(
1084
1138
  {
1085
1139
  "updated_at": now,
@@ -1273,7 +1327,15 @@ class Threads(Authenticated):
1273
1327
  metadata = thread.get("metadata", {})
1274
1328
  thread_config = thread.get("config", {})
1275
1329
 
1276
- if graph_id := metadata.get("graph_id"):
1330
+ # Fallback to graph_id from run if not in thread metadata
1331
+ graph_id = metadata.get("graph_id")
1332
+ if not graph_id:
1333
+ for run in conn.store["runs"]:
1334
+ if run["thread_id"] == thread_id:
1335
+ graph_id = run["kwargs"]["config"]["configurable"]["graph_id"]
1336
+ break
1337
+
1338
+ if graph_id:
1277
1339
  # format latest checkpoint for response
1278
1340
  checkpointer.latest_iter = checkpoint
1279
1341
  async with get_graph(
@@ -1351,7 +1413,15 @@ class Threads(Authenticated):
1351
1413
  detail=f"Thread {thread_id} has in-flight runs: {pending_runs}",
1352
1414
  )
1353
1415
 
1354
- if graph_id := metadata.get("graph_id"):
1416
+ # Fallback to graph_id from run if not in thread metadata
1417
+ graph_id = metadata.get("graph_id")
1418
+ if not graph_id:
1419
+ for run in conn.store["runs"]:
1420
+ if run["thread_id"] == thread_id:
1421
+ graph_id = run["kwargs"]["config"]["configurable"]["graph_id"]
1422
+ break
1423
+
1424
+ if graph_id:
1355
1425
  config["configurable"].setdefault("graph_id", graph_id)
1356
1426
 
1357
1427
  checkpointer.latest_iter = checkpoint
@@ -1389,7 +1459,11 @@ class Threads(Authenticated):
1389
1459
  checkpoint_id=next_config["configurable"]["checkpoint_id"],
1390
1460
  )
1391
1461
  else:
1392
- raise HTTPException(status_code=400, detail="Thread has no graph ID.")
1462
+ raise HTTPException(
1463
+ status_code=400,
1464
+ detail=f"Thread '{thread['thread_id']}' has no assigned graph ID. This usually occurs when no runs have been made on this particular thread."
1465
+ " This operation requires a graph ID. Please ensure a run has been made for the thread or manually update the thread metadata (by setting the 'graph_id' field) before running this operation.",
1466
+ )
1393
1467
 
1394
1468
  @staticmethod
1395
1469
  async def bulk(
@@ -1471,7 +1545,11 @@ class Threads(Authenticated):
1471
1545
  checkpoint=next_config["configurable"],
1472
1546
  )
1473
1547
  else:
1474
- raise HTTPException(status_code=400, detail="Thread has no graph ID")
1548
+ raise HTTPException(
1549
+ status_code=400,
1550
+ detail=f"Thread '{thread['thread_id']}' has no assigned graph ID. This usually occurs when no runs have been made on this particular thread."
1551
+ " This operation requires a graph ID. Please ensure a run has been made for the thread or manually update the thread metadata (by setting the 'graph_id' field) before running this operation.",
1552
+ )
1475
1553
 
1476
1554
  @staticmethod
1477
1555
  async def list(
@@ -1533,6 +1611,53 @@ class Threads(Authenticated):
1533
1611
 
1534
1612
  return []
1535
1613
 
1614
+ @staticmethod
1615
+ async def count(
1616
+ conn: InMemConnectionProto,
1617
+ *,
1618
+ metadata: MetadataInput = None,
1619
+ values: MetadataInput = None,
1620
+ status: ThreadStatus | None = None,
1621
+ ctx: Auth.types.BaseAuthContext | None = None,
1622
+ ) -> int:
1623
+ """Get count of threads."""
1624
+ threads = conn.store["threads"]
1625
+ metadata = metadata if metadata is not None else {}
1626
+ values = values if values is not None else {}
1627
+ filters = await Threads.handle_event(
1628
+ ctx,
1629
+ "search",
1630
+ Auth.types.ThreadsSearch(
1631
+ metadata=metadata,
1632
+ values=values,
1633
+ status=status,
1634
+ limit=0,
1635
+ offset=0,
1636
+ ),
1637
+ )
1638
+
1639
+ count = 0
1640
+ for thread in threads:
1641
+ if filters and not _check_filter_match(thread["metadata"], filters):
1642
+ continue
1643
+
1644
+ if metadata and not is_jsonb_contained(thread["metadata"], metadata):
1645
+ continue
1646
+
1647
+ if (
1648
+ values
1649
+ and "values" in thread
1650
+ and not is_jsonb_contained(thread["values"], values)
1651
+ ):
1652
+ continue
1653
+
1654
+ if status and thread.get("status") != status:
1655
+ continue
1656
+
1657
+ count += 1
1658
+
1659
+ return count
1660
+
1536
1661
 
1537
1662
  RUN_LOCK = asyncio.Lock()
1538
1663
 
@@ -1678,9 +1803,9 @@ class Runs(Authenticated):
1678
1803
  await stream_manager.remove_control_queue(run_id, queue)
1679
1804
 
1680
1805
  @staticmethod
1681
- async def sweep(conn: InMemConnectionProto) -> list[UUID]:
1806
+ async def sweep() -> None:
1682
1807
  """Sweep runs that are no longer running"""
1683
- return []
1808
+ pass
1684
1809
 
1685
1810
  @staticmethod
1686
1811
  def _merge_jsonb(*objects: dict) -> dict:
@@ -1698,7 +1823,7 @@ class Runs(Authenticated):
1698
1823
 
1699
1824
  @staticmethod
1700
1825
  async def put(
1701
- conn: InMemConnectionProto,
1826
+ conn: InMemConnectionProto | AsyncConnectionProto,
1702
1827
  assistant_id: UUID,
1703
1828
  kwargs: dict,
1704
1829
  *,
@@ -1714,6 +1839,7 @@ class Runs(Authenticated):
1714
1839
  ctx: Auth.types.BaseAuthContext | None = None,
1715
1840
  ) -> AsyncIterator[Run]:
1716
1841
  """Create a run."""
1842
+ from langgraph_api.config import FF_RICH_THREADS
1717
1843
  from langgraph_api.schema import Run, Thread
1718
1844
 
1719
1845
  assistant_id = _ensure_uuid(assistant_id)
@@ -1759,34 +1885,50 @@ class Runs(Authenticated):
1759
1885
  # Create new thread
1760
1886
  if thread_id is None:
1761
1887
  thread_id = uuid4()
1762
- thread = Thread(
1763
- thread_id=thread_id,
1764
- status="busy",
1765
- metadata={
1766
- "graph_id": assistant["graph_id"],
1767
- "assistant_id": str(assistant_id),
1768
- **(config.get("metadata") or {}),
1769
- **metadata,
1770
- },
1771
- config=Runs._merge_jsonb(
1772
- assistant["config"],
1773
- config,
1774
- {
1775
- "configurable": Runs._merge_jsonb(
1776
- Runs._get_configurable(assistant["config"]),
1777
- Runs._get_configurable(config),
1778
- )
1888
+ if FF_RICH_THREADS:
1889
+ thread = Thread(
1890
+ thread_id=thread_id,
1891
+ status="busy",
1892
+ metadata={
1893
+ "graph_id": assistant["graph_id"],
1894
+ "assistant_id": str(assistant_id),
1895
+ **(config.get("metadata") or {}),
1896
+ **metadata,
1779
1897
  },
1780
- ),
1781
- created_at=datetime.now(UTC),
1782
- updated_at=datetime.now(UTC),
1783
- values=b"",
1784
- )
1898
+ config=Runs._merge_jsonb(
1899
+ assistant["config"],
1900
+ config,
1901
+ {
1902
+ "configurable": Runs._merge_jsonb(
1903
+ Runs._get_configurable(assistant["config"]),
1904
+ Runs._get_configurable(config),
1905
+ )
1906
+ },
1907
+ ),
1908
+ created_at=datetime.now(UTC),
1909
+ updated_at=datetime.now(UTC),
1910
+ values=b"",
1911
+ )
1912
+ else:
1913
+ thread = Thread(
1914
+ thread_id=thread_id,
1915
+ status="idle",
1916
+ metadata={
1917
+ "graph_id": assistant["graph_id"],
1918
+ "assistant_id": str(assistant_id),
1919
+ **(config.get("metadata") or {}),
1920
+ **metadata,
1921
+ },
1922
+ config={},
1923
+ created_at=datetime.now(UTC),
1924
+ updated_at=datetime.now(UTC),
1925
+ values=b"",
1926
+ )
1785
1927
  await logger.ainfo("Creating thread", thread_id=thread_id)
1786
1928
  conn.store["threads"].append(thread)
1787
1929
  elif existing_thread:
1788
1930
  # Update existing thread
1789
- if existing_thread["status"] != "busy":
1931
+ if FF_RICH_THREADS and existing_thread["status"] != "busy":
1790
1932
  existing_thread["status"] = "busy"
1791
1933
  existing_thread["metadata"] = Runs._merge_jsonb(
1792
1934
  existing_thread["metadata"],
@@ -2025,8 +2167,8 @@ class Runs(Authenticated):
2025
2167
 
2026
2168
  @staticmethod
2027
2169
  async def cancel(
2028
- conn: InMemConnectionProto,
2029
- run_ids: Sequence[UUID] | None = None,
2170
+ conn: InMemConnectionProto | AsyncConnectionProto,
2171
+ run_ids: Sequence[UUID | str] | None = None,
2030
2172
  *,
2031
2173
  action: Literal["interrupt", "rollback"] = "interrupt",
2032
2174
  thread_id: UUID | None = None,
@@ -2190,6 +2332,7 @@ class Runs(Authenticated):
2190
2332
  limit: int = 10,
2191
2333
  offset: int = 0,
2192
2334
  status: RunStatus | None = None,
2335
+ select: list[RunSelectField] | None = None,
2193
2336
  ctx: Auth.types.BaseAuthContext | None = None,
2194
2337
  ) -> AsyncIterator[Run]:
2195
2338
  """List all runs by thread."""
@@ -2217,7 +2360,12 @@ class Runs(Authenticated):
2217
2360
 
2218
2361
  async def _return():
2219
2362
  for run in sliced_runs:
2220
- yield run
2363
+ if select:
2364
+ # Filter to only selected fields
2365
+ filtered_run = {k: v for k, v in run.items() if k in select}
2366
+ yield filtered_run
2367
+ else:
2368
+ yield run
2221
2369
 
2222
2370
  return _return()
2223
2371
 
@@ -2389,7 +2537,7 @@ class Runs(Authenticated):
2389
2537
 
2390
2538
  @staticmethod
2391
2539
  async def publish(
2392
- run_id: UUID,
2540
+ run_id: UUID | str,
2393
2541
  event: str,
2394
2542
  message: bytes,
2395
2543
  *,
@@ -2490,10 +2638,24 @@ class Crons:
2490
2638
  thread_id: UUID | None,
2491
2639
  limit: int,
2492
2640
  offset: int,
2641
+ select: list[CronSelectField] | None = None,
2493
2642
  ctx: Auth.types.BaseAuthContext | None = None,
2494
- ) -> AsyncIterator[Cron]:
2643
+ sort_by: str | None = None,
2644
+ sort_order: Literal["asc", "desc"] | None = None,
2645
+ ) -> tuple[AsyncIterator[Cron], int]:
2495
2646
  raise NotImplementedError
2496
2647
 
2648
+ @staticmethod
2649
+ async def count(
2650
+ conn: InMemConnectionProto,
2651
+ *,
2652
+ assistant_id: UUID | None = None,
2653
+ thread_id: UUID | None = None,
2654
+ ctx: Auth.types.BaseAuthContext | None = None,
2655
+ ) -> int:
2656
+ """Get count of crons."""
2657
+ raise NotImplementedError("The in-mem server does not implement Crons.")
2658
+
2497
2659
 
2498
2660
  async def cancel_run(
2499
2661
  thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
@@ -128,8 +128,7 @@ async def queue():
128
128
  # sweep runs if needed
129
129
  if do_sweep:
130
130
  last_sweep_secs = loop.time()
131
- run_ids = await ops.Runs.sweep(conn)
132
- logger.info("Swept runs", run_ids=run_ids)
131
+ await ops.Runs.sweep()
133
132
  except Exception as exc:
134
133
  # keep trying to run the scheduler indefinitely
135
134
  logger.exception("Background worker scheduler failed", exc_info=exc)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langgraph-runtime-inmem
3
- Version: 0.6.12
3
+ Version: 0.8.0
4
4
  Summary: Inmem implementation for the LangGraph API server.
5
5
  Author-email: Will Fu-Hinthorn <will@langchain.dev>
6
6
  License: Elastic-2.0
@@ -1,13 +1,13 @@
1
- langgraph_runtime_inmem/__init__.py,sha256=Y4pMrNRANpl0WNkm4rJgbGxM0sdgeJNJ17xTUA5Do0w,311
1
+ langgraph_runtime_inmem/__init__.py,sha256=ZCa8w1LAQPb9BIGUaaEwW1hak0vidcMAUArcD_C1lyU,310
2
2
  langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
3
3
  langgraph_runtime_inmem/database.py,sha256=G_6L2khpRDSpS2Vs_SujzHayODcwG5V2IhFP7LLBXgw,6349
4
4
  langgraph_runtime_inmem/inmem_stream.py,sha256=UWk1srLF44HZPPbRdArGGhsy0MY0UOJKSIxBSO7Hosc,5138
5
5
  langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
6
6
  langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
7
- langgraph_runtime_inmem/ops.py,sha256=lhvPaQRUmEKumeWJUMOcSmYpZkLbOWvAJT-wYzu5NrM,91701
8
- langgraph_runtime_inmem/queue.py,sha256=nqfgz7j_Jkh5Ek5-RsHB2Uvwbxguu9IUPkGXIxvFPns,10037
7
+ langgraph_runtime_inmem/ops.py,sha256=pV2IHiis4ydYFtroOfX-FcQjhXvk5O1ESuy2kk7mbyI,98064
8
+ langgraph_runtime_inmem/queue.py,sha256=33qfFKPhQicZ1qiibllYb-bTFzUNSN2c4bffPACP5es,9952
9
9
  langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
10
10
  langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
11
- langgraph_runtime_inmem-0.6.12.dist-info/METADATA,sha256=kVDdIT1O23mKzQMYtfFwpy9F4IYZzLYt60T12AczxeQ,566
12
- langgraph_runtime_inmem-0.6.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
- langgraph_runtime_inmem-0.6.12.dist-info/RECORD,,
11
+ langgraph_runtime_inmem-0.8.0.dist-info/METADATA,sha256=hg_MVwCcdeCjZrYul77k4J5y5ByIUDOtTqqUOQbyqDs,565
12
+ langgraph_runtime_inmem-0.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
+ langgraph_runtime_inmem-0.8.0.dist-info/RECORD,,