langgraph-runtime-inmem 0.6.13__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from langgraph_runtime_inmem import (
9
9
  store,
10
10
  )
11
11
 
12
- __version__ = "0.6.13"
12
+ __version__ = "0.8.1"
13
13
  __all__ = [
14
14
  "ops",
15
15
  "database",
@@ -34,10 +34,12 @@ if typing.TYPE_CHECKING:
34
34
  from langgraph_api.config import ThreadTTLConfig
35
35
  from langgraph_api.schema import (
36
36
  Assistant,
37
+ AssistantSelectField,
37
38
  Checkpoint,
38
39
  Config,
39
40
  Context,
40
41
  Cron,
42
+ CronSelectField,
41
43
  DeprecatedInterrupt,
42
44
  IfNotExists,
43
45
  MetadataInput,
@@ -46,14 +48,17 @@ if typing.TYPE_CHECKING:
46
48
  OnConflictBehavior,
47
49
  QueueStats,
48
50
  Run,
51
+ RunSelectField,
49
52
  RunStatus,
50
53
  StreamMode,
51
54
  Thread,
55
+ ThreadSelectField,
52
56
  ThreadStatus,
53
57
  ThreadUpdateResponse,
54
58
  )
55
59
  from langgraph_api.schema import Interrupt as InterruptSchema
56
60
  from langgraph_api.serde import Fragment
61
+ from langgraph_api.utils import AsyncConnectionProto
57
62
 
58
63
 
59
64
  logger = structlog.stdlib.get_logger(__name__)
@@ -136,6 +141,7 @@ class Assistants(Authenticated):
136
141
  offset: int,
137
142
  sort_by: str | None = None,
138
143
  sort_order: str | None = None,
144
+ select: list[AssistantSelectField] | None = None,
139
145
  ctx: Auth.types.BaseAuthContext | None = None,
140
146
  ) -> tuple[AsyncIterator[Assistant], int]:
141
147
  metadata = metadata if metadata is not None else {}
@@ -178,6 +184,7 @@ class Assistants(Authenticated):
178
184
  else:
179
185
  filtered_assistants.sort(key=lambda x: x.get(sort_by), reverse=reverse)
180
186
  else:
187
+ sort_by = "created_at"
181
188
  # Default sorting by created_at in descending order
182
189
  filtered_assistants.sort(key=lambda x: x["created_at"], reverse=True)
183
190
 
@@ -187,14 +194,21 @@ class Assistants(Authenticated):
187
194
 
188
195
  async def assistant_iterator() -> AsyncIterator[Assistant]:
189
196
  for assistant in paginated_assistants:
190
- yield assistant
197
+ if select:
198
+ # Filter to only selected fields
199
+ filtered_assistant = {
200
+ k: v for k, v in assistant.items() if k in select
201
+ }
202
+ yield filtered_assistant
203
+ else:
204
+ yield assistant
191
205
 
192
206
  return assistant_iterator(), cur
193
207
 
194
208
  @staticmethod
195
209
  async def get(
196
210
  conn: InMemConnectionProto,
197
- assistant_id: UUID,
211
+ assistant_id: UUID | str,
198
212
  ctx: Auth.types.BaseAuthContext | None = None,
199
213
  ) -> AsyncIterator[Assistant]:
200
214
  """Get an assistant by ID."""
@@ -217,7 +231,7 @@ class Assistants(Authenticated):
217
231
  @staticmethod
218
232
  async def put(
219
233
  conn: InMemConnectionProto,
220
- assistant_id: UUID,
234
+ assistant_id: UUID | str,
221
235
  *,
222
236
  graph_id: str,
223
237
  config: Config,
@@ -597,6 +611,37 @@ class Assistants(Authenticated):
597
611
 
598
612
  return _yield_versions()
599
613
 
614
+ @staticmethod
615
+ async def count(
616
+ conn: InMemConnectionProto,
617
+ *,
618
+ graph_id: str | None = None,
619
+ metadata: MetadataInput = None,
620
+ ctx: Auth.types.BaseAuthContext | None = None,
621
+ ) -> int:
622
+ """Get count of assistants."""
623
+ metadata = metadata if metadata is not None else {}
624
+ filters = await Assistants.handle_event(
625
+ ctx,
626
+ "search",
627
+ Auth.types.AssistantsSearch(
628
+ graph_id=graph_id, metadata=metadata, limit=0, offset=0
629
+ ),
630
+ )
631
+
632
+ count = 0
633
+ for assistant in conn.store["assistants"]:
634
+ if (
635
+ (not graph_id or assistant["graph_id"] == graph_id)
636
+ and (
637
+ not metadata or is_jsonb_contained(assistant["metadata"], metadata)
638
+ )
639
+ and (not filters or _check_filter_match(assistant["metadata"], filters))
640
+ ):
641
+ count += 1
642
+
643
+ return count
644
+
600
645
 
601
646
  def is_jsonb_contained(superset: dict[str, Any], subset: dict[str, Any]) -> bool:
602
647
  """
@@ -680,7 +725,7 @@ def _patch_interrupt(
680
725
  "value": interrupt.value,
681
726
  "resumable": interrupt.resumable,
682
727
  "ns": interrupt.ns,
683
- "when": interrupt.when,
728
+ "when": interrupt.when, # type: ignore[unresolved-attribute]
684
729
  }
685
730
 
686
731
 
@@ -698,6 +743,7 @@ class Threads(Authenticated):
698
743
  offset: int,
699
744
  sort_by: str | None = None,
700
745
  sort_order: str | None = None,
746
+ select: list[ThreadSelectField] | None = None,
701
747
  ctx: Auth.types.BaseAuthContext | None = None,
702
748
  ) -> tuple[AsyncIterator[Thread], int]:
703
749
  threads = conn.store["threads"]
@@ -747,6 +793,7 @@ class Threads(Authenticated):
747
793
  filtered_threads, key=lambda x: x.get(sort_by), reverse=reverse
748
794
  )
749
795
  else:
796
+ sort_by = "created_at"
750
797
  # Default sorting by created_at in descending order
751
798
  sorted_threads = sorted(
752
799
  filtered_threads, key=lambda x: x["updated_at"], reverse=True
@@ -758,7 +805,12 @@ class Threads(Authenticated):
758
805
 
759
806
  async def thread_iterator() -> AsyncIterator[Thread]:
760
807
  for thread in paginated_threads:
761
- yield thread
808
+ if select:
809
+ # Filter to only selected fields
810
+ filtered_thread = {k: v for k, v in thread.items() if k in select}
811
+ yield filtered_thread
812
+ else:
813
+ yield thread
762
814
 
763
815
  return thread_iterator(), cursor
764
816
 
@@ -822,7 +874,7 @@ class Threads(Authenticated):
822
874
  @staticmethod
823
875
  async def put(
824
876
  conn: InMemConnectionProto,
825
- thread_id: UUID,
877
+ thread_id: UUID | str,
826
878
  *,
827
879
  metadata: MetadataInput,
828
880
  if_exists: OnConflictBehavior,
@@ -1559,6 +1611,53 @@ class Threads(Authenticated):
1559
1611
 
1560
1612
  return []
1561
1613
 
1614
+ @staticmethod
1615
+ async def count(
1616
+ conn: InMemConnectionProto,
1617
+ *,
1618
+ metadata: MetadataInput = None,
1619
+ values: MetadataInput = None,
1620
+ status: ThreadStatus | None = None,
1621
+ ctx: Auth.types.BaseAuthContext | None = None,
1622
+ ) -> int:
1623
+ """Get count of threads."""
1624
+ threads = conn.store["threads"]
1625
+ metadata = metadata if metadata is not None else {}
1626
+ values = values if values is not None else {}
1627
+ filters = await Threads.handle_event(
1628
+ ctx,
1629
+ "search",
1630
+ Auth.types.ThreadsSearch(
1631
+ metadata=metadata,
1632
+ values=values,
1633
+ status=status,
1634
+ limit=0,
1635
+ offset=0,
1636
+ ),
1637
+ )
1638
+
1639
+ count = 0
1640
+ for thread in threads:
1641
+ if filters and not _check_filter_match(thread["metadata"], filters):
1642
+ continue
1643
+
1644
+ if metadata and not is_jsonb_contained(thread["metadata"], metadata):
1645
+ continue
1646
+
1647
+ if (
1648
+ values
1649
+ and "values" in thread
1650
+ and not is_jsonb_contained(thread["values"], values)
1651
+ ):
1652
+ continue
1653
+
1654
+ if status and thread.get("status") != status:
1655
+ continue
1656
+
1657
+ count += 1
1658
+
1659
+ return count
1660
+
1562
1661
 
1563
1662
  RUN_LOCK = asyncio.Lock()
1564
1663
 
@@ -1724,7 +1823,7 @@ class Runs(Authenticated):
1724
1823
 
1725
1824
  @staticmethod
1726
1825
  async def put(
1727
- conn: InMemConnectionProto,
1826
+ conn: InMemConnectionProto | AsyncConnectionProto,
1728
1827
  assistant_id: UUID,
1729
1828
  kwargs: dict,
1730
1829
  *,
@@ -1802,7 +1901,6 @@ class Runs(Authenticated):
1802
1901
  {
1803
1902
  "configurable": Runs._merge_jsonb(
1804
1903
  Runs._get_configurable(assistant["config"]),
1805
- Runs._get_configurable(config),
1806
1904
  )
1807
1905
  },
1808
1906
  ),
@@ -1846,7 +1944,6 @@ class Runs(Authenticated):
1846
1944
  "configurable": Runs._merge_jsonb(
1847
1945
  Runs._get_configurable(assistant["config"]),
1848
1946
  Runs._get_configurable(existing_thread["config"]),
1849
- Runs._get_configurable(config),
1850
1947
  )
1851
1948
  },
1852
1949
  )
@@ -2068,8 +2165,8 @@ class Runs(Authenticated):
2068
2165
 
2069
2166
  @staticmethod
2070
2167
  async def cancel(
2071
- conn: InMemConnectionProto,
2072
- run_ids: Sequence[UUID] | None = None,
2168
+ conn: InMemConnectionProto | AsyncConnectionProto,
2169
+ run_ids: Sequence[UUID | str] | None = None,
2073
2170
  *,
2074
2171
  action: Literal["interrupt", "rollback"] = "interrupt",
2075
2172
  thread_id: UUID | None = None,
@@ -2233,6 +2330,7 @@ class Runs(Authenticated):
2233
2330
  limit: int = 10,
2234
2331
  offset: int = 0,
2235
2332
  status: RunStatus | None = None,
2333
+ select: list[RunSelectField] | None = None,
2236
2334
  ctx: Auth.types.BaseAuthContext | None = None,
2237
2335
  ) -> AsyncIterator[Run]:
2238
2336
  """List all runs by thread."""
@@ -2260,7 +2358,12 @@ class Runs(Authenticated):
2260
2358
 
2261
2359
  async def _return():
2262
2360
  for run in sliced_runs:
2263
- yield run
2361
+ if select:
2362
+ # Filter to only selected fields
2363
+ filtered_run = {k: v for k, v in run.items() if k in select}
2364
+ yield filtered_run
2365
+ else:
2366
+ yield run
2264
2367
 
2265
2368
  return _return()
2266
2369
 
@@ -2432,7 +2535,7 @@ class Runs(Authenticated):
2432
2535
 
2433
2536
  @staticmethod
2434
2537
  async def publish(
2435
- run_id: UUID,
2538
+ run_id: UUID | str,
2436
2539
  event: str,
2437
2540
  message: bytes,
2438
2541
  *,
@@ -2533,10 +2636,24 @@ class Crons:
2533
2636
  thread_id: UUID | None,
2534
2637
  limit: int,
2535
2638
  offset: int,
2639
+ select: list[CronSelectField] | None = None,
2536
2640
  ctx: Auth.types.BaseAuthContext | None = None,
2537
- ) -> AsyncIterator[Cron]:
2641
+ sort_by: str | None = None,
2642
+ sort_order: Literal["asc", "desc"] | None = None,
2643
+ ) -> tuple[AsyncIterator[Cron], int]:
2538
2644
  raise NotImplementedError
2539
2645
 
2646
+ @staticmethod
2647
+ async def count(
2648
+ conn: InMemConnectionProto,
2649
+ *,
2650
+ assistant_id: UUID | None = None,
2651
+ thread_id: UUID | None = None,
2652
+ ctx: Auth.types.BaseAuthContext | None = None,
2653
+ ) -> int:
2654
+ """Get count of crons."""
2655
+ raise NotImplementedError("The in-mem server does not implement Crons.")
2656
+
2540
2657
 
2541
2658
  async def cancel_run(
2542
2659
  thread_id: UUID, run_id: UUID, ctx: Auth.types.BaseAuthContext | None = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langgraph-runtime-inmem
3
- Version: 0.6.13
3
+ Version: 0.8.1
4
4
  Summary: Inmem implementation for the LangGraph API server.
5
5
  Author-email: Will Fu-Hinthorn <will@langchain.dev>
6
6
  License: Elastic-2.0
@@ -1,13 +1,13 @@
1
- langgraph_runtime_inmem/__init__.py,sha256=mPj20Y7jJEr9XATlg7oD3Q0F4BJdwWnfN7JMijIiRHc,311
1
+ langgraph_runtime_inmem/__init__.py,sha256=HSPTGiVB69XNTkwTDcmNR5AmVYBGvgbwoW_RmOWec8g,310
2
2
  langgraph_runtime_inmem/checkpoint.py,sha256=nc1G8DqVdIu-ibjKTqXfbPfMbAsKjPObKqegrSzo6Po,4432
3
3
  langgraph_runtime_inmem/database.py,sha256=G_6L2khpRDSpS2Vs_SujzHayODcwG5V2IhFP7LLBXgw,6349
4
4
  langgraph_runtime_inmem/inmem_stream.py,sha256=UWk1srLF44HZPPbRdArGGhsy0MY0UOJKSIxBSO7Hosc,5138
5
5
  langgraph_runtime_inmem/lifespan.py,sha256=t0w2MX2dGxe8yNtSX97Z-d2pFpllSLS4s1rh2GJDw5M,3557
6
6
  langgraph_runtime_inmem/metrics.py,sha256=HhO0RC2bMDTDyGBNvnd2ooLebLA8P1u5oq978Kp_nAA,392
7
- langgraph_runtime_inmem/ops.py,sha256=ffxYACydXPmWKt89AY22M1j8s75JbFMtzXI1SKC3R50,93949
7
+ langgraph_runtime_inmem/ops.py,sha256=rtO-dgPQnJEymF_yvzxpynUNse-lq1flb0B112pg6pk,97940
8
8
  langgraph_runtime_inmem/queue.py,sha256=33qfFKPhQicZ1qiibllYb-bTFzUNSN2c4bffPACP5es,9952
9
9
  langgraph_runtime_inmem/retry.py,sha256=XmldOP4e_H5s264CagJRVnQMDFcEJR_dldVR1Hm5XvM,763
10
10
  langgraph_runtime_inmem/store.py,sha256=rTfL1JJvd-j4xjTrL8qDcynaWF6gUJ9-GDVwH0NBD_I,3506
11
- langgraph_runtime_inmem-0.6.13.dist-info/METADATA,sha256=GUxT1ClUkx_1b0nR-tiuetNRDfSoB5iejMXS2ieX9HA,566
12
- langgraph_runtime_inmem-0.6.13.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
- langgraph_runtime_inmem-0.6.13.dist-info/RECORD,,
11
+ langgraph_runtime_inmem-0.8.1.dist-info/METADATA,sha256=WfRHwBTIUfr1Ux1T1gYgGE5QojW_83T91KELEwub2Bg,565
12
+ langgraph_runtime_inmem-0.8.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
+ langgraph_runtime_inmem-0.8.1.dist-info/RECORD,,