langgraph-api 0.4.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135)
  1. langgraph_api/__init__.py +1 -1
  2. langgraph_api/api/__init__.py +111 -51
  3. langgraph_api/api/a2a.py +1610 -0
  4. langgraph_api/api/assistants.py +212 -89
  5. langgraph_api/api/mcp.py +3 -3
  6. langgraph_api/api/meta.py +52 -28
  7. langgraph_api/api/openapi.py +27 -17
  8. langgraph_api/api/profile.py +108 -0
  9. langgraph_api/api/runs.py +342 -195
  10. langgraph_api/api/store.py +19 -2
  11. langgraph_api/api/threads.py +209 -27
  12. langgraph_api/asgi_transport.py +14 -9
  13. langgraph_api/asyncio.py +14 -4
  14. langgraph_api/auth/custom.py +52 -37
  15. langgraph_api/auth/langsmith/backend.py +4 -3
  16. langgraph_api/auth/langsmith/client.py +13 -8
  17. langgraph_api/cli.py +230 -133
  18. langgraph_api/command.py +5 -3
  19. langgraph_api/config/__init__.py +532 -0
  20. langgraph_api/config/_parse.py +58 -0
  21. langgraph_api/config/schemas.py +431 -0
  22. langgraph_api/cron_scheduler.py +17 -1
  23. langgraph_api/encryption/__init__.py +15 -0
  24. langgraph_api/encryption/aes_json.py +158 -0
  25. langgraph_api/encryption/context.py +35 -0
  26. langgraph_api/encryption/custom.py +280 -0
  27. langgraph_api/encryption/middleware.py +632 -0
  28. langgraph_api/encryption/shared.py +63 -0
  29. langgraph_api/errors.py +12 -1
  30. langgraph_api/executor_entrypoint.py +11 -6
  31. langgraph_api/feature_flags.py +29 -0
  32. langgraph_api/graph.py +176 -76
  33. langgraph_api/grpc/client.py +313 -0
  34. langgraph_api/grpc/config_conversion.py +231 -0
  35. langgraph_api/grpc/generated/__init__.py +29 -0
  36. langgraph_api/grpc/generated/checkpointer_pb2.py +63 -0
  37. langgraph_api/grpc/generated/checkpointer_pb2.pyi +99 -0
  38. langgraph_api/grpc/generated/checkpointer_pb2_grpc.py +329 -0
  39. langgraph_api/grpc/generated/core_api_pb2.py +216 -0
  40. langgraph_api/grpc/generated/core_api_pb2.pyi +905 -0
  41. langgraph_api/grpc/generated/core_api_pb2_grpc.py +1621 -0
  42. langgraph_api/grpc/generated/engine_common_pb2.py +219 -0
  43. langgraph_api/grpc/generated/engine_common_pb2.pyi +722 -0
  44. langgraph_api/grpc/generated/engine_common_pb2_grpc.py +24 -0
  45. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.py +37 -0
  46. langgraph_api/grpc/generated/enum_cancel_run_action_pb2.pyi +12 -0
  47. langgraph_api/grpc/generated/enum_cancel_run_action_pb2_grpc.py +24 -0
  48. langgraph_api/grpc/generated/enum_control_signal_pb2.py +37 -0
  49. langgraph_api/grpc/generated/enum_control_signal_pb2.pyi +16 -0
  50. langgraph_api/grpc/generated/enum_control_signal_pb2_grpc.py +24 -0
  51. langgraph_api/grpc/generated/enum_durability_pb2.py +37 -0
  52. langgraph_api/grpc/generated/enum_durability_pb2.pyi +16 -0
  53. langgraph_api/grpc/generated/enum_durability_pb2_grpc.py +24 -0
  54. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.py +37 -0
  55. langgraph_api/grpc/generated/enum_multitask_strategy_pb2.pyi +16 -0
  56. langgraph_api/grpc/generated/enum_multitask_strategy_pb2_grpc.py +24 -0
  57. langgraph_api/grpc/generated/enum_run_status_pb2.py +37 -0
  58. langgraph_api/grpc/generated/enum_run_status_pb2.pyi +22 -0
  59. langgraph_api/grpc/generated/enum_run_status_pb2_grpc.py +24 -0
  60. langgraph_api/grpc/generated/enum_stream_mode_pb2.py +37 -0
  61. langgraph_api/grpc/generated/enum_stream_mode_pb2.pyi +28 -0
  62. langgraph_api/grpc/generated/enum_stream_mode_pb2_grpc.py +24 -0
  63. langgraph_api/grpc/generated/enum_thread_status_pb2.py +37 -0
  64. langgraph_api/grpc/generated/enum_thread_status_pb2.pyi +16 -0
  65. langgraph_api/grpc/generated/enum_thread_status_pb2_grpc.py +24 -0
  66. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.py +37 -0
  67. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.pyi +16 -0
  68. langgraph_api/grpc/generated/enum_thread_stream_mode_pb2_grpc.py +24 -0
  69. langgraph_api/grpc/generated/errors_pb2.py +39 -0
  70. langgraph_api/grpc/generated/errors_pb2.pyi +21 -0
  71. langgraph_api/grpc/generated/errors_pb2_grpc.py +24 -0
  72. langgraph_api/grpc/ops/__init__.py +370 -0
  73. langgraph_api/grpc/ops/assistants.py +424 -0
  74. langgraph_api/grpc/ops/runs.py +792 -0
  75. langgraph_api/grpc/ops/threads.py +1013 -0
  76. langgraph_api/http.py +16 -5
  77. langgraph_api/http_metrics.py +15 -35
  78. langgraph_api/http_metrics_utils.py +38 -0
  79. langgraph_api/js/build.mts +1 -1
  80. langgraph_api/js/client.http.mts +13 -7
  81. langgraph_api/js/client.mts +2 -5
  82. langgraph_api/js/package.json +29 -28
  83. langgraph_api/js/remote.py +56 -30
  84. langgraph_api/js/src/graph.mts +20 -0
  85. langgraph_api/js/sse.py +2 -2
  86. langgraph_api/js/ui.py +1 -1
  87. langgraph_api/js/yarn.lock +1204 -1006
  88. langgraph_api/logging.py +29 -2
  89. langgraph_api/metadata.py +99 -28
  90. langgraph_api/middleware/http_logger.py +7 -2
  91. langgraph_api/middleware/private_network.py +7 -7
  92. langgraph_api/models/run.py +54 -93
  93. langgraph_api/otel_context.py +205 -0
  94. langgraph_api/patch.py +5 -3
  95. langgraph_api/queue_entrypoint.py +154 -65
  96. langgraph_api/route.py +47 -5
  97. langgraph_api/schema.py +88 -10
  98. langgraph_api/self_hosted_logs.py +124 -0
  99. langgraph_api/self_hosted_metrics.py +450 -0
  100. langgraph_api/serde.py +79 -37
  101. langgraph_api/server.py +138 -60
  102. langgraph_api/state.py +4 -3
  103. langgraph_api/store.py +25 -16
  104. langgraph_api/stream.py +80 -29
  105. langgraph_api/thread_ttl.py +31 -13
  106. langgraph_api/timing/__init__.py +25 -0
  107. langgraph_api/timing/profiler.py +200 -0
  108. langgraph_api/timing/timer.py +318 -0
  109. langgraph_api/utils/__init__.py +53 -8
  110. langgraph_api/utils/cache.py +47 -10
  111. langgraph_api/utils/config.py +2 -1
  112. langgraph_api/utils/errors.py +77 -0
  113. langgraph_api/utils/future.py +10 -6
  114. langgraph_api/utils/headers.py +76 -2
  115. langgraph_api/utils/retriable_client.py +74 -0
  116. langgraph_api/utils/stream_codec.py +315 -0
  117. langgraph_api/utils/uuids.py +29 -62
  118. langgraph_api/validation.py +9 -0
  119. langgraph_api/webhook.py +120 -6
  120. langgraph_api/worker.py +55 -24
  121. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/METADATA +16 -8
  122. langgraph_api-0.7.3.dist-info/RECORD +168 -0
  123. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/WHEEL +1 -1
  124. langgraph_runtime/__init__.py +1 -0
  125. langgraph_runtime/routes.py +11 -0
  126. logging.json +1 -3
  127. openapi.json +839 -478
  128. langgraph_api/config.py +0 -387
  129. langgraph_api/js/isolate-0x130008000-46649-46649-v8.log +0 -4430
  130. langgraph_api/js/isolate-0x138008000-44681-44681-v8.log +0 -4430
  131. langgraph_api/js/package-lock.json +0 -3308
  132. langgraph_api-0.4.1.dist-info/RECORD +0 -107
  133. /langgraph_api/{utils.py → grpc/__init__.py} +0 -0
  134. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/entry_points.txt +0 -0
  135. {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1013 @@
1
+ """gRPC-based threads operations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from contextlib import AsyncExitStack
7
+ from datetime import UTC
8
+ from http import HTTPStatus
9
+ from typing import TYPE_CHECKING, Any, Literal
10
+ from uuid import UUID
11
+
12
+ if TYPE_CHECKING:
13
+ from collections.abc import Sequence
14
+
15
+ import orjson
16
+ import structlog
17
+ from langgraph.checkpoint.serde.jsonplus import _msgpack_ext_hook_to_json
18
+ from langgraph.types import StateSnapshot, StateUpdate
19
+ from langgraph_sdk import Auth
20
+ from starlette.exceptions import HTTPException
21
+
22
+ from langgraph_api import store as api_store
23
+ from langgraph_api.command import map_cmd
24
+ from langgraph_api.graph import get_graph
25
+ from langgraph_api.grpc.client import get_shared_client
26
+ from langgraph_api.grpc.generated import checkpointer_pb2
27
+ from langgraph_api.grpc.generated import core_api_pb2 as pb
28
+ from langgraph_api.grpc.generated import enum_thread_status_pb2 as enum_thread_status
29
+ from langgraph_api.grpc.ops import (
30
+ Authenticated,
31
+ _map_sort_order,
32
+ grpc_error_guard,
33
+ map_if_exists,
34
+ )
35
+ from langgraph_api.grpc.ops.runs import Runs
36
+ from langgraph_api.schema import ThreadUpdateResponse
37
+ from langgraph_api.serde import json_dumpb, json_dumpb_optional, json_loads
38
+ from langgraph_api.state import patch_interrupt, state_snapshot_to_thread_state
39
+ from langgraph_api.utils import fetchone
40
+ from langgraph_runtime.checkpoint import Checkpointer
41
+
42
+ if TYPE_CHECKING:
43
+ from collections.abc import AsyncIterator
44
+
45
+ from langgraph_api.schema import (
46
+ MetadataInput,
47
+ OnConflictBehavior,
48
+ Thread,
49
+ ThreadSelectField,
50
+ ThreadStatus,
51
+ )
52
+
53
+ logger = structlog.stdlib.get_logger(__name__)
54
+
55
+
56
def _snapshot_defaults():
    """Support older versions of langgraph that don't have interrupts field."""
    if hasattr(StateSnapshot, "interrupts"):
        # Modern langgraph: include an empty interrupts tuple.
        return {"interrupts": tuple()}
    # Older langgraph without the field: contribute nothing.
    return {}
61
+
62
+
63
# Mapping from API thread-status strings to protobuf ThreadStatus enum values.
THREAD_STATUS_TO_PB = {
    "idle": enum_thread_status.idle,
    "busy": enum_thread_status.busy,
    "interrupted": enum_thread_status.interrupted,
    "error": enum_thread_status.error,
}

# Inverse mapping: protobuf enum value -> API thread-status string.
THREAD_STATUS_FROM_PB = {v: k for k, v in THREAD_STATUS_TO_PB.items()}

# Mapping from API sort-by field names to protobuf ThreadsSortBy enum values.
THREAD_SORT_BY_MAP = {
    "unspecified": pb.ThreadsSortBy.THREADS_SORT_BY_UNSPECIFIED,  # for enum completeness, never sent
    "thread_id": pb.ThreadsSortBy.THREADS_SORT_BY_THREAD_ID,
    "created_at": pb.ThreadsSortBy.THREADS_SORT_BY_CREATED_AT,
    "updated_at": pb.ThreadsSortBy.THREADS_SORT_BY_UPDATED_AT,
    "status": pb.ThreadsSortBy.THREADS_SORT_BY_STATUS,
}

# Mapping from API TTL strategy strings to protobuf ThreadTTLStrategy enum values.
THREAD_TTL_STRATEGY_MAP = {
    "delete": pb.ThreadTTLStrategy.THREAD_TTL_STRATEGY_DELETE,
    "keep_latest": pb.ThreadTTLStrategy.THREAD_TTL_STRATEGY_KEEP_LATEST,
}
84
+
85
+
86
def _map_thread_status(
    status: ThreadStatus | None,
) -> enum_thread_status.ThreadStatus | None:
    """Translate an API thread status to its protobuf enum (None passes through).

    Returns None both for a None input and for an unrecognized status string.
    """
    return None if status is None else THREAD_STATUS_TO_PB.get(status)
92
+
93
+
94
def _map_threads_sort_by(sort_by: str | None) -> pb.ThreadsSortBy:
    """Resolve a sort-by string to the protobuf enum, defaulting to created_at."""
    fallback = pb.ThreadsSortBy.THREADS_SORT_BY_CREATED_AT
    if not sort_by:
        return fallback
    key = sort_by.lower()
    # "unspecified" is mapped to the default rather than the UNSPECIFIED enum.
    if key == "unspecified":
        return fallback
    return THREAD_SORT_BY_MAP.get(key, fallback)
100
+
101
+
102
def _map_thread_ttl(ttl: dict[str, Any] | None) -> pb.ThreadTTLConfig | None:
    """Build a ThreadTTLConfig proto from a TTL dict, or None when empty.

    Raises:
        HTTPException: 422 when the strategy is not a recognized value.
    """
    if not ttl:
        return None

    proto = pb.ThreadTTLConfig()

    raw_strategy = ttl.get("strategy")
    if raw_strategy:
        mapped = THREAD_TTL_STRATEGY_MAP.get(str(raw_strategy).lower())
        if mapped is None:
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail=f"Invalid thread TTL strategy: {raw_strategy}. Expected one of {list(THREAD_TTL_STRATEGY_MAP.keys())}",
            )
        proto.strategy = mapped

    # "ttl" takes precedence over the legacy "default_ttl" key.
    seconds = ttl.get("ttl", ttl.get("default_ttl"))
    if seconds is not None:
        proto.default_ttl = float(seconds)

    sweep = ttl.get("sweep_interval_minutes")
    if sweep is not None:
        proto.sweep_interval_minutes = int(sweep)

    # Note: sweep_limit is a server-side configuration for the TTL sweep loop,
    # not a per-thread setting, so we don't send it via gRPC

    return proto
129
+
130
+
131
def fragment_to_value(fragment: pb.Fragment | None) -> Any:
    """Decode a Fragment's JSON payload; missing, empty, or '{}' yields None."""
    if fragment is None:
        return None
    raw = fragment.value
    if not raw or raw == b"{}":
        return None
    try:
        return json_loads(raw)
    except orjson.JSONDecodeError:
        # Malformed payloads are logged and treated as absent.
        logger.warning("Failed to decode fragment", fragment=raw)
        return None
139
+
140
+
141
def _proto_interrupts_to_dict(
    interrupts_map: dict[str, pb.Interrupts],
) -> dict[str, list[dict[str, Any]]]:
    """Convert a proto interrupts map into plain JSON-ready dictionaries.

    Optional proto fields (when/resumable/ns) are included only when truthy.
    """
    result: dict[str, list[dict[str, Any]]] = {}
    for task_id, group in interrupts_map.items():
        converted: list[dict[str, Any]] = []
        for item in group.interrupts:
            payload: dict[str, Any] = {
                "id": item.id or None,
                "value": json_loads(item.value),
            }
            if item.when:
                payload["when"] = item.when
            if item.resumable:
                payload["resumable"] = item.resumable
            if item.ns:
                payload["ns"] = list(item.ns)
            converted.append(payload)
        result[task_id] = converted
    return result
161
+
162
+
163
def proto_to_thread(proto_thread: pb.Thread) -> Thread:
    """Convert protobuf Thread to API dictionary format.

    Raises:
        HTTPException: 500 when the response carries no thread_id.
    """
    if not proto_thread.HasField("thread_id"):
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="Thread response missing thread_id",
        )
    thread_id = UUID(proto_thread.thread_id.value)

    def _timestamp(field_name: str):
        # Convert an optional proto Timestamp to an aware UTC datetime.
        if proto_thread.HasField(field_name):
            return getattr(proto_thread, field_name).ToDatetime(tzinfo=UTC)
        return None

    return {
        "thread_id": thread_id,
        "created_at": _timestamp("created_at"),
        "updated_at": _timestamp("updated_at"),
        # Unlike other fields, metadata should never be `None`.
        "metadata": fragment_to_value(proto_thread.metadata) or {},
        "config": fragment_to_value(proto_thread.config) or {},
        "error": fragment_to_value(proto_thread.error),
        # Unknown enum values degrade to "idle".
        "status": THREAD_STATUS_FROM_PB.get(proto_thread.status, "idle"),  # type: ignore[typeddict-item]
        "values": fragment_to_value(proto_thread.values),
        "interrupts": _proto_interrupts_to_dict(dict(proto_thread.interrupts)),
    }
199
+
200
+
201
def _filter_thread_fields(
    thread: Thread, select: list[ThreadSelectField] | None
) -> dict[str, Any]:
    """Project a thread onto the requested fields (a full copy when none given)."""
    if select:
        # Silently drop requested fields the thread does not carry.
        return {name: thread[name] for name in select if name in thread}
    return dict(thread)
207
+
208
+
209
def _normalize_uuid(value: UUID | str) -> str:
    """Return the canonical hyphenated string form of a UUID-like value."""
    if isinstance(value, UUID):
        return str(value)
    # Round-trip through UUID to validate and canonicalize the string.
    return str(UUID(str(value)))
211
+
212
+
213
def _thread_status_checkpoint_to_proto(
    checkpoint: dict[str, Any] | None,
) -> pb.ThreadStatusCheckpoint | None:
    """Convert checkpoint dict to ThreadStatusCheckpoint proto (None stays None)."""
    if checkpoint is None:
        return None

    # Compute interrupts map from tasks (same logic as storage_postgres/ops.py)
    interrupts: dict[str, Any] = {}
    for task in checkpoint.get("tasks", []):
        task_interrupts = task.get("interrupts")
        if task_interrupts:
            interrupts[task["id"]] = [patch_interrupt(i) for i in task_interrupts]

    return pb.ThreadStatusCheckpoint(
        values_json=json_dumpb(checkpoint.get("values", {})),
        next=list(checkpoint.get("next", [])),
        interrupts_json=json_dumpb(interrupts),
    )
232
+
233
+
234
def _json_contains(container: Any, subset: dict[str, Any]) -> bool:
    """Recursively check whether ``subset`` is contained in ``container``.

    An empty subset matches anything; a non-dict container matches only an
    empty subset. Nested dict values are compared by recursive containment,
    all other values by equality.
    """
    if not subset:
        return True
    if not isinstance(container, dict):
        return False
    for key, expected in subset.items():
        if key not in container:
            return False
        actual = container[key]
        if isinstance(expected, dict):
            if not _json_contains(actual, expected):
                return False
        elif actual != expected:
            return False
    return True
250
+
251
+
252
+ @grpc_error_guard
253
+ class Threads(Authenticated):
254
+ """gRPC-based threads operations."""
255
+
256
+ resource = "threads"
257
+
258
    @staticmethod
    async def search(
        conn,  # Not used in gRPC implementation
        *,
        ids: list[str] | list[UUID] | None = None,
        metadata: MetadataInput,
        values: MetadataInput,
        status: ThreadStatus | None,
        limit: int,
        offset: int,
        sort_by: str | None = None,
        sort_order: str | None = None,
        select: list[ThreadSelectField] | None = None,
        ctx: Any = None,
    ) -> tuple[AsyncIterator[Thread], int | None]:  # type: ignore[return-value]
        """Search threads via gRPC.

        Two paths: when explicit ``ids`` are given, each thread is fetched
        individually and filtered/paginated client-side; otherwise a single
        server-side Search request is issued.

        Returns a tuple of (async iterator over thread dicts, next-offset
        cursor or None when no further page is expected).
        """
        metadata = metadata or {}
        values = values or {}

        # Auth hook may return extra filters that constrain visibility.
        auth_filters = await Threads.handle_event(
            ctx,
            "search",
            {
                "metadata": metadata,
                "values": values,
                "status": status,
                "limit": limit,
                "offset": offset,
            },
        )

        if ids:
            # ID path: fetch each thread, then apply status/metadata/values
            # filters and pagination locally.
            normalized_ids = [_normalize_uuid(thread_id) for thread_id in ids]
            threads: list[Thread] = []
            client = await get_shared_client()
            for thread_id in normalized_ids:
                request = pb.GetThreadRequest(
                    # thread_id is already normalized above; re-normalizing
                    # here is a harmless no-op.
                    thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
                    filters=auth_filters,
                )
                response = await client.threads.Get(request)
                thread = proto_to_thread(response)

                if status and thread["status"] != status:
                    continue
                if metadata and not _json_contains(thread["metadata"], metadata):
                    continue
                if values and not _json_contains(thread.get("values") or {}, values):
                    continue
                threads.append(thread)

            total = len(threads)
            paginated = threads[offset : offset + limit]
            # Cursor only when more matches remain past this page.
            cursor = offset + limit if total > offset + limit else None

            async def generate_results():
                for thread in paginated:
                    yield _filter_thread_fields(thread, select)

            return generate_results(), cursor

        # Server-side search path.
        request_kwargs: dict[str, Any] = {
            "filters": auth_filters,
            "metadata_json": json_dumpb_optional(metadata),
            "values_json": json_dumpb_optional(values),
            "limit": limit,
            "offset": offset,
            "sort_by": _map_threads_sort_by(sort_by),
            "sort_order": _map_sort_order(sort_order),
        }

        if status:
            mapped_status = _map_thread_status(status)
            if mapped_status is None:
                raise HTTPException(
                    status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                    detail=f"Invalid thread status: {status}",
                )
            request_kwargs["status"] = mapped_status

        if select:
            request_kwargs["select"] = select

        client = await get_shared_client()
        response = await client.threads.Search(
            pb.SearchThreadsRequest(**request_kwargs)
        )

        threads = [proto_to_thread(thread) for thread in response.threads]
        # A full page implies there may be more results; offer the next offset.
        cursor = offset + limit if len(threads) == limit else None

        async def generate_results():
            for thread in threads:
                yield _filter_thread_fields(thread, select)

        return generate_results(), cursor
353
+
354
+ @staticmethod
355
+ async def count(
356
+ conn, # Not used
357
+ *,
358
+ metadata: MetadataInput,
359
+ values: MetadataInput,
360
+ status: ThreadStatus | None,
361
+ ctx: Any = None,
362
+ ) -> int: # type: ignore[override]
363
+ metadata = metadata or {}
364
+ values = values or {}
365
+
366
+ auth_filters = await Threads.handle_event(
367
+ ctx,
368
+ "search",
369
+ {
370
+ "metadata": metadata,
371
+ "values": values,
372
+ "status": status,
373
+ },
374
+ )
375
+
376
+ request_kwargs: dict[str, Any] = {
377
+ "filters": auth_filters,
378
+ "metadata_json": json_dumpb_optional(metadata),
379
+ "values_json": json_dumpb_optional(values),
380
+ }
381
+ if status:
382
+ mapped_status = _map_thread_status(status)
383
+ if mapped_status is None:
384
+ raise HTTPException(
385
+ status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
386
+ detail=f"Invalid thread status: {status}",
387
+ )
388
+ request_kwargs["status"] = mapped_status
389
+
390
+ client = await get_shared_client()
391
+ response = await client.threads.Count(pb.CountThreadsRequest(**request_kwargs))
392
+
393
+ return int(response.count)
394
+
395
+ @staticmethod
396
+ async def get(
397
+ conn, # Not used
398
+ thread_id: UUID | str,
399
+ ctx: Auth.types.BaseAuthContext | None = None,
400
+ filters: Auth.types.FilterType | None = None,
401
+ include_ttl: bool = False,
402
+ ) -> AsyncIterator[Thread]: # type: ignore[return-value]
403
+ """Get a thread by ID.
404
+
405
+ Args:
406
+ conn: Not used (required for interface compatibility)
407
+ thread_id: Thread ID
408
+ ctx: Auth context
409
+ filters: Additional auth filters to merge with auth context filters
410
+ include_ttl: Not yet supported in gRPC - parameter ignored.
411
+ """
412
+ auth_filters = await Threads.handle_event(
413
+ ctx, "read", {"thread_id": str(thread_id)}
414
+ )
415
+ # Merge auth filters with any additional parent filters provided.
416
+ # (Parent filters take precedence.)
417
+ if filters:
418
+ auth_filters = {**(auth_filters or {}), **(filters or {})}
419
+
420
+ request = pb.GetThreadRequest(
421
+ thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
422
+ filters=auth_filters,
423
+ )
424
+ client = await get_shared_client()
425
+ response = await client.threads.Get(request)
426
+
427
+ thread = proto_to_thread(response)
428
+
429
+ async def generate_result():
430
+ yield thread
431
+
432
+ return generate_result()
433
+
434
+ @staticmethod
435
+ async def put(
436
+ conn, # Not used
437
+ thread_id: UUID | str,
438
+ *,
439
+ metadata: MetadataInput,
440
+ if_exists: OnConflictBehavior,
441
+ ttl: dict[str, Any] | None = None,
442
+ ctx: Any = None,
443
+ ) -> AsyncIterator[Thread]: # type: ignore[return-value]
444
+ metadata = metadata or {}
445
+
446
+ auth_filters = await Threads.handle_event(
447
+ ctx,
448
+ "create",
449
+ {
450
+ "thread_id": str(thread_id),
451
+ "metadata": metadata,
452
+ "if_exists": if_exists,
453
+ },
454
+ )
455
+
456
+ request = pb.CreateThreadRequest(
457
+ thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
458
+ filters=auth_filters,
459
+ if_exists=map_if_exists(if_exists),
460
+ metadata_json=json_dumpb_optional(metadata),
461
+ )
462
+ ttl_config = _map_thread_ttl(ttl)
463
+ if ttl_config is not None:
464
+ request.ttl.CopyFrom(ttl_config)
465
+
466
+ client = await get_shared_client()
467
+ response = await client.threads.Create(request)
468
+ thread = proto_to_thread(response)
469
+
470
+ async def generate_result():
471
+ yield thread
472
+
473
+ return generate_result()
474
+
475
+ @staticmethod
476
+ async def patch(
477
+ conn, # Not used
478
+ thread_id: UUID | str,
479
+ *,
480
+ metadata: MetadataInput,
481
+ ttl: dict[str, Any] | None = None,
482
+ ctx: Any = None,
483
+ ) -> AsyncIterator[Thread]: # type: ignore[return-value]
484
+ metadata = metadata or {}
485
+
486
+ auth_filters = await Threads.handle_event(
487
+ ctx,
488
+ "update",
489
+ {
490
+ "thread_id": str(thread_id),
491
+ "metadata": metadata,
492
+ },
493
+ )
494
+
495
+ request = pb.PatchThreadRequest(
496
+ thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
497
+ filters=auth_filters,
498
+ metadata_json=json_dumpb_optional(metadata),
499
+ )
500
+
501
+ ttl_config = _map_thread_ttl(ttl)
502
+ if ttl_config is not None:
503
+ request.ttl.CopyFrom(ttl_config)
504
+
505
+ client = await get_shared_client()
506
+ response = await client.threads.Patch(request)
507
+
508
+ thread = proto_to_thread(response)
509
+
510
+ async def generate_result():
511
+ yield thread
512
+
513
+ return generate_result()
514
+
515
+ @staticmethod
516
+ async def delete(
517
+ conn, # Not used
518
+ thread_id: UUID | str,
519
+ ctx: Any = None,
520
+ ) -> AsyncIterator[UUID]: # type: ignore[return-value]
521
+ auth_filters = await Threads.handle_event(
522
+ ctx,
523
+ "delete",
524
+ {
525
+ "thread_id": str(thread_id),
526
+ },
527
+ )
528
+
529
+ request = pb.DeleteThreadRequest(
530
+ thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
531
+ filters=auth_filters,
532
+ )
533
+
534
+ client = await get_shared_client()
535
+ response = await client.threads.Delete(request)
536
+
537
+ deleted_id = UUID(response.value)
538
+
539
+ async def generate_result():
540
+ yield deleted_id
541
+
542
+ return generate_result()
543
+
544
+ @staticmethod
545
+ async def prune(
546
+ thread_ids: Sequence[str] | Sequence[UUID],
547
+ strategy: Literal["delete", "keep_latest"] = "delete",
548
+ batch_size: int = 100,
549
+ ctx: Any = None,
550
+ ) -> int:
551
+ """Prune threads via gRPC.
552
+
553
+ Args:
554
+ thread_ids: List of thread IDs to prune
555
+ strategy: "delete" to remove entirely, "keep_latest" to prune checkpoints
556
+ batch_size: Batch size for operations
557
+ ctx: Auth context for permission checks
558
+
559
+ Returns:
560
+ Number of threads successfully pruned
561
+ """
562
+
563
+ if not thread_ids:
564
+ return 0
565
+
566
+ str_ids = [str(tid) for tid in thread_ids]
567
+ client = await get_shared_client()
568
+
569
+ # Validate delete authorization for all threads before pruning.
570
+ # Auth filters are based on user/action, so we only need to get them once.
571
+ auth_filters = await Threads.handle_event(
572
+ ctx,
573
+ "delete",
574
+ {"thread_ids": str_ids},
575
+ )
576
+
577
+ # Only validate access if auth filters are present
578
+ if auth_filters:
579
+
580
+ async def validate_thread_access(thread_id: str) -> None:
581
+ request = pb.GetThreadRequest(
582
+ thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
583
+ filters=auth_filters,
584
+ )
585
+ await client.threads.Get(request)
586
+
587
+ await asyncio.gather(*[validate_thread_access(tid) for tid in str_ids])
588
+
589
+ if strategy == "delete":
590
+ strategy_proto = checkpointer_pb2.PruneRequest.PruneStrategy.DELETE_ALL
591
+ else:
592
+ strategy_proto = checkpointer_pb2.PruneRequest.PruneStrategy.KEEP_LATEST
593
+ stub = client.checkpointer
594
+
595
+ processed = 0
596
+ for i in range(0, len(str_ids), batch_size):
597
+ batch = str_ids[i : i + batch_size]
598
+ try:
599
+ request = checkpointer_pb2.PruneRequest(
600
+ thread_ids=batch,
601
+ strategy=strategy_proto,
602
+ )
603
+ await stub.Prune(request)
604
+ processed += len(batch)
605
+ except Exception:
606
+ await logger.aexception("Failed to prune thread. Skipping batch.")
607
+ pass
608
+
609
+ return processed
610
+
611
+ @staticmethod
612
+ async def copy(
613
+ conn, # Not used
614
+ thread_id: UUID | str,
615
+ ctx: Any = None,
616
+ ) -> AsyncIterator[Thread]: # type: ignore[return-value]
617
+ auth_filters = await Threads.handle_event(
618
+ ctx,
619
+ "read",
620
+ {
621
+ "thread_id": str(thread_id),
622
+ },
623
+ )
624
+ # Validate that the user also has create permissions
625
+ # Filters will be the same as the read filters, so we can toss these
626
+ await Threads.handle_event(
627
+ ctx,
628
+ "create",
629
+ {
630
+ "thread_id": str(thread_id),
631
+ },
632
+ )
633
+
634
+ request = pb.CopyThreadRequest(
635
+ thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
636
+ filters=auth_filters,
637
+ )
638
+
639
+ client = await get_shared_client()
640
+ response = await client.threads.Copy(request)
641
+
642
+ thread = proto_to_thread(response)
643
+
644
+ async def generate_result():
645
+ yield thread
646
+
647
+ return generate_result()
648
+
649
+ @staticmethod
650
+ async def set_status(
651
+ conn, # Not used in gRPC implementation
652
+ thread_id: UUID | str,
653
+ checkpoint: dict[str, Any] | None,
654
+ exception: BaseException | dict[str, Any] | None,
655
+ expected_status: ThreadStatus | Sequence[ThreadStatus] | None = None,
656
+ ) -> None:
657
+ """Set thread status via gRPC.
658
+
659
+ Args:
660
+ conn: Not used (required for interface compatibility)
661
+ thread_id: Thread ID
662
+ checkpoint: Checkpoint payload containing values, next, tasks, etc.
663
+ exception: Exception to store on thread (BaseException or serialized dict)
664
+ expected_status: Expected current status(es) for optimistic locking
665
+ """
666
+ request_kwargs: dict[str, Any] = {
667
+ "thread_id": pb.UUID(value=_normalize_uuid(thread_id)),
668
+ }
669
+
670
+ # Map checkpoint to proto
671
+ checkpoint_proto = _thread_status_checkpoint_to_proto(checkpoint)
672
+ if checkpoint_proto is not None:
673
+ request_kwargs["checkpoint"] = checkpoint_proto
674
+
675
+ # Map exception to JSON bytes
676
+ if exception is not None:
677
+ if isinstance(exception, BaseException):
678
+ exception_dict = {
679
+ "type": type(exception).__name__,
680
+ "message": str(exception),
681
+ }
682
+ else:
683
+ exception_dict = exception
684
+ request_kwargs["exception_json"] = json_dumpb(exception_dict)
685
+
686
+ # Map expected_status to enum values
687
+ if expected_status:
688
+ if isinstance(expected_status, str):
689
+ expected_status = [expected_status]
690
+ status_enums = []
691
+ for status in expected_status:
692
+ mapped = THREAD_STATUS_TO_PB.get(status)
693
+ if mapped is not None:
694
+ status_enums.append(mapped)
695
+ if status_enums:
696
+ request_kwargs["expected_status"] = status_enums
697
+
698
+ client = await get_shared_client()
699
+ await client.threads.SetStatus(pb.SetThreadStatusRequest(**request_kwargs))
700
+
701
+ @staticmethod
702
+ async def get_graph_id(
703
+ thread_id: UUID | str,
704
+ ) -> str | None:
705
+ """Get the graph ID for the latest run in a thread."""
706
+ request = pb.GetGraphIDRequest(
707
+ thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
708
+ )
709
+
710
+ client = await get_shared_client()
711
+ response = await client.threads.GetGraphID(request)
712
+
713
+ return response.graph_id if response.graph_id else None
714
+
715
+ class State(Authenticated):
716
+ # treat this like threads resource
717
+ resource = "threads"
718
+
719
        @staticmethod
        async def get(
            conn,  # Still needed for checkpointer
            config: dict[str, Any],
            subgraphs: bool,
            ctx: Any = None,
        ) -> StateSnapshot:
            """Get state snapshot for a thread (*internal only*, no auth).

            Fetches the thread, its latest checkpoint, and the graph id
            concurrently, then delegates snapshot formatting to the graph.
            When no graph id can be resolved, returns an empty snapshot.
            """
            checkpointer = Checkpointer(conn, unpack_hook=_msgpack_ext_hook_to_json)
            thread_id = config["configurable"]["thread_id"]

            # Batch the three lookups over one pipelined connection.
            async with conn.pipeline():
                thread, checkpoint, graph_id = await asyncio.gather(
                    Threads.get(conn, thread_id, ctx=ctx),
                    checkpointer.aget_iter(config),
                    Threads.get_graph_id(thread_id),
                )

            thread = await fetchone(thread)
            # NOTE(review): metadata/config are passed through json_loads here;
            # confirm the fetched row carries them as JSON — presumably
            # json_loads tolerates already-decoded values.
            metadata = json_loads(thread["metadata"])
            thread_config = json_loads(thread["config"])
            # Request configurable keys override the thread's stored ones.
            thread_config = {
                **thread_config,
                "configurable": {
                    **thread_config.get("configurable", {}),
                    **config.get("configurable", {}),
                },
            }
            # Prefer the graph id from the latest run; fall back to metadata.
            graph_id = graph_id or metadata.get("graph_id")

            if graph_id:
                # format latest checkpoint for response
                checkpointer.latest_iter = checkpoint
                async with get_graph(
                    graph_id,
                    thread_config,
                    checkpointer=checkpointer,
                    store=(await api_store.get_store()),
                ) as graph:
                    return await graph.aget_state(config, subgraphs=subgraphs)
            else:
                # No graph resolvable: return an empty snapshot, adding the
                # interrupts field only on langgraph versions that have it.
                _kwargs: dict[str, Any] = {
                    "values": {},
                    "next": tuple(),
                    "config": None,
                    "metadata": None,
                    "created_at": None,
                    "parent_config": None,
                    "tasks": tuple(),
                }
                _kwargs.update(_snapshot_defaults())
                return StateSnapshot(**_kwargs)  # type: ignore[missing-argument]
771
+
772
    @staticmethod
    async def post(
        conn,  # Still needed for checkpointer and run count check
        config: dict[str, Any],
        values: Any,
        as_node: str | None = None,
        ctx: Any = None,
    ) -> ThreadUpdateResponse:
        """Update thread state.

        Applies ``values`` to the thread's latest checkpoint (optionally
        attributed to ``as_node``) via ``graph.aupdate_state``, refreshes the
        thread's stored values/status, and publishes a ``state_update`` stream
        event.

        Args:
            conn: Database connection; used for the checkpointer, the
                pipelined reads, the run-count busy check, and the write
                transaction.
            config: Runnable config; ``config["configurable"]["thread_id"]``
                identifies the thread to update.
            values: State values passed through to ``graph.aupdate_state``.
            as_node: Optional node name the update should be attributed to.
            ctx: Auth context forwarded to ``handle_event`` and ``Threads.get``.

        Returns:
            Mapping containing the new ``checkpoint`` (the next config's
            ``configurable``), plus deprecated top-level keys kept for
            backward compatibility.

        Raises:
            HTTPException: 409 if the thread has pending/running runs;
                400 if no graph ID can be resolved for the thread.
        """
        thread_id = UUID(config["configurable"]["thread_id"])
        # Authorization hook: may raise, and may return row filters that
        # constrain which thread the caller is allowed to touch.
        filters = await Threads.State.handle_event(
            ctx,
            "update",
            Auth.types.ThreadsUpdate(thread_id=thread_id),
        )

        checkpointer = Checkpointer(conn, use_direct_connection=True)
        # Issue all four reads in a single pipelined round-trip.
        async with conn.pipeline():
            thread, checkpoint, graph_id, run_count = await asyncio.gather(
                Threads.get(conn, thread_id, ctx=ctx, filters=filters),
                checkpointer.aget_iter(config),
                Threads.get_graph_id(thread_id),
                Runs.count(thread_id=thread_id, statuses=["pending", "running"]),
            )

        thread = await fetchone(thread)
        metadata = json_loads(thread["metadata"])
        thread_config = json_loads(thread["config"])
        # Prefer the dedicated graph-id lookup; fall back to thread metadata.
        graph_id = graph_id or metadata.get("graph_id")

        # Check if thread is busy before allowing state update
        if run_count > 0:
            raise HTTPException(
                status_code=409,
                detail="Thread is busy with a running job. Cannot update state.",
            )

        if graph_id:
            # update state
            config["configurable"].setdefault("graph_id", graph_id)
            # Hand the checkpoint fetched above to the checkpointer —
            # presumably so aupdate_state need not re-read it; confirm against
            # Checkpointer.latest_iter semantics.
            checkpointer.latest_iter = checkpoint
            async with AsyncExitStack() as stack:
                graph = await stack.enter_async_context(
                    get_graph(
                        graph_id,
                        thread_config,
                        checkpointer=checkpointer,
                        store=(await api_store.get_store()),
                        is_for_execution=False,
                    )
                )
                # The state write, thread-values refresh, and status change
                # all happen inside one transaction.
                await stack.enter_async_context(conn.transaction())
                next_config = await graph.aupdate_state(
                    config, values, as_node=as_node
                )
                # update thread values
                state = await Threads.State.get(
                    conn, config, subgraphs=False, ctx=ctx
                )
                await Threads.set_status(
                    conn,
                    thread_id,
                    state_snapshot_to_thread_state(state),
                    None,
                    # Accept if NOT busy
                    expected_status=("interrupted", "idle", "error"),
                )

                # Publish state update event
                event_data = {
                    "state": state_snapshot_to_thread_state(state),
                    "thread_id": str(thread_id),
                }
                await Runs.Stream.publish(
                    "*",
                    "state_update",
                    json_dumpb(event_data),
                    thread_id=thread_id,
                    resumable=True,
                )

                return {
                    "checkpoint": next_config["configurable"],
                    # below are deprecated
                    **next_config,
                    "checkpoint_id": next_config["configurable"]["checkpoint_id"],
                }
        else:
            raise HTTPException(
                status_code=400,
                detail=f"Thread '{thread['thread_id']}' has no assigned graph ID. This usually occurs when no runs have been made on this particular thread."
                " This operation requires a graph ID. Please ensure a run has been made for the thread or manually update the thread metadata (by setting the 'graph_id' field) before running this operation.",
            )
865
+
866
    @staticmethod
    async def bulk(
        conn,  # Still needed for checkpointer
        config: dict[str, Any],
        supersteps: Any,
        ctx: Any = None,
    ) -> ThreadUpdateResponse:
        """Update a thread with a batch of state updates.

        Each element of ``supersteps`` is a mapping with an ``updates`` list;
        each update supplies either a ``command`` (translated via ``map_cmd``)
        or raw ``values``, plus an optional ``as_node``. The whole batch is
        applied through ``graph.abulk_update_state`` inside one transaction,
        after which the thread's values/status are refreshed and a
        ``state_update`` stream event is published.

        Args:
            conn: Database connection used for the checkpointer, pipelined
                reads, and the write transaction.
            config: Runnable config; ``config["configurable"]["thread_id"]``
                identifies the thread.
            supersteps: Sequence of superstep mappings as described above.
            ctx: Auth context forwarded to ``handle_event`` and ``Threads.get``.

        Returns:
            ThreadUpdateResponse carrying the resulting checkpoint config.

        Raises:
            HTTPException: 400 if no graph ID can be resolved for the thread.
        """
        thread_id = UUID(config["configurable"]["thread_id"])
        # Authorization hook: may raise, and may return row filters.
        filters = await Threads.State.handle_event(
            ctx,
            "update",
            Auth.types.ThreadsUpdate(thread_id=thread_id),
        )

        checkpointer = Checkpointer(conn)

        # Fetch the thread row and its graph id in one pipelined round-trip.
        async with conn.pipeline():
            thread, graph_id = await asyncio.gather(
                Threads.get(conn, thread_id, ctx=ctx, filters=filters),
                Threads.get_graph_id(config["configurable"]["thread_id"]),
            )
        thread = await fetchone(thread)
        thread_config = json_loads(thread["config"])
        metadata = json_loads(thread["metadata"])
        # Prefer the dedicated graph-id lookup; fall back to thread metadata.
        graph_id = graph_id or metadata.get("graph_id")

        if graph_id:
            # update state
            config["configurable"].setdefault("graph_id", graph_id)
            config["configurable"].setdefault("checkpoint_ns", "")

            async with AsyncExitStack() as stack:
                graph = await stack.enter_async_context(
                    get_graph(
                        graph_id,
                        thread_config,
                        checkpointer=checkpointer,
                        store=(await api_store.get_store()),
                        is_for_execution=False,
                    )
                )

                # Batch write + status refresh share a single transaction.
                await stack.enter_async_context(conn.transaction())
                next_config = await graph.abulk_update_state(
                    config,
                    [
                        [
                            StateUpdate(
                                (
                                    # A command takes precedence over values.
                                    map_cmd(update.get("command"))
                                    if update.get("command")
                                    else update.get("values")
                                ),
                                update.get("as_node"),
                            )
                            for update in superstep.get("updates", [])
                        ]
                        for superstep in supersteps
                    ],
                )

                # update thread values
                state = await Threads.State.get(
                    conn, config, subgraphs=False, ctx=ctx
                )

                await Threads.set_status(
                    conn,
                    thread_id,
                    state_snapshot_to_thread_state(state),
                    None,
                )

                # Publish state update event
                event_data = {
                    "state": state_snapshot_to_thread_state(state),
                    "thread_id": str(thread_id),
                }
                await Runs.Stream.publish(
                    "*",
                    "state_update",
                    json_dumpb(event_data),
                    thread_id=thread_id,
                    resumable=True,
                )

                return ThreadUpdateResponse(checkpoint=next_config["configurable"])
        else:
            raise HTTPException(
                status_code=400,
                detail=f"Thread '{thread['thread_id']}' has no assigned graph ID. This usually occurs when no runs have been made on this particular thread."
                " This operation requires a graph ID. Please ensure a run has been made for the thread or manually update the thread metadata (by setting the 'graph_id' field) before running this operation.",
            )
960
+
961
+ @staticmethod
962
+ async def list(
963
+ conn, # Still needed for checkpointer
964
+ *,
965
+ config: dict[str, Any],
966
+ limit: int = 1,
967
+ before: Any = None,
968
+ metadata: Any = None,
969
+ ctx: Any = None,
970
+ ) -> list[StateSnapshot]:
971
+ """Get the history of a thread."""
972
+ async with conn.pipeline():
973
+ thread, graph_id = await asyncio.gather(
974
+ Threads.get(conn, config["configurable"]["thread_id"], ctx=ctx),
975
+ Threads.get_graph_id(config["configurable"]["thread_id"]),
976
+ )
977
+ thread = await fetchone(thread)
978
+ thread_metadata = json_loads(thread["metadata"])
979
+ thread_config = json_loads(thread["config"])
980
+ thread_config = {
981
+ **thread_config,
982
+ "configurable": {
983
+ **thread_config.get("configurable", {}),
984
+ **config.get("configurable", {}),
985
+ },
986
+ }
987
+ graph_id = graph_id or thread_metadata.get("graph_id")
988
+
989
+ if graph_id:
990
+ async with get_graph(
991
+ graph_id,
992
+ thread_config,
993
+ checkpointer=Checkpointer(
994
+ conn, unpack_hook=_msgpack_ext_hook_to_json
995
+ ),
996
+ store=(await api_store.get_store()),
997
+ is_for_execution=False,
998
+ ) as graph:
999
+ return [
1000
+ c
1001
+ async for c in graph.aget_state_history(
1002
+ config,
1003
+ limit=limit,
1004
+ filter=metadata,
1005
+ before=(
1006
+ {"configurable": {"checkpoint_id": before}}
1007
+ if isinstance(before, str)
1008
+ else before
1009
+ ),
1010
+ )
1011
+ ]
1012
+ else:
1013
+ return []