langgraph-api 0.4.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_api/__init__.py +1 -1
- langgraph_api/api/__init__.py +111 -51
- langgraph_api/api/a2a.py +1610 -0
- langgraph_api/api/assistants.py +212 -89
- langgraph_api/api/mcp.py +3 -3
- langgraph_api/api/meta.py +52 -28
- langgraph_api/api/openapi.py +27 -17
- langgraph_api/api/profile.py +108 -0
- langgraph_api/api/runs.py +342 -195
- langgraph_api/api/store.py +19 -2
- langgraph_api/api/threads.py +209 -27
- langgraph_api/asgi_transport.py +14 -9
- langgraph_api/asyncio.py +14 -4
- langgraph_api/auth/custom.py +52 -37
- langgraph_api/auth/langsmith/backend.py +4 -3
- langgraph_api/auth/langsmith/client.py +13 -8
- langgraph_api/cli.py +230 -133
- langgraph_api/command.py +5 -3
- langgraph_api/config/__init__.py +532 -0
- langgraph_api/config/_parse.py +58 -0
- langgraph_api/config/schemas.py +431 -0
- langgraph_api/cron_scheduler.py +17 -1
- langgraph_api/encryption/__init__.py +15 -0
- langgraph_api/encryption/aes_json.py +158 -0
- langgraph_api/encryption/context.py +35 -0
- langgraph_api/encryption/custom.py +280 -0
- langgraph_api/encryption/middleware.py +632 -0
- langgraph_api/encryption/shared.py +63 -0
- langgraph_api/errors.py +12 -1
- langgraph_api/executor_entrypoint.py +11 -6
- langgraph_api/feature_flags.py +29 -0
- langgraph_api/graph.py +176 -76
- langgraph_api/grpc/client.py +313 -0
- langgraph_api/grpc/config_conversion.py +231 -0
- langgraph_api/grpc/generated/__init__.py +29 -0
- langgraph_api/grpc/generated/checkpointer_pb2.py +63 -0
- langgraph_api/grpc/generated/checkpointer_pb2.pyi +99 -0
- langgraph_api/grpc/generated/checkpointer_pb2_grpc.py +329 -0
- langgraph_api/grpc/generated/core_api_pb2.py +216 -0
- langgraph_api/grpc/generated/core_api_pb2.pyi +905 -0
- langgraph_api/grpc/generated/core_api_pb2_grpc.py +1621 -0
- langgraph_api/grpc/generated/engine_common_pb2.py +219 -0
- langgraph_api/grpc/generated/engine_common_pb2.pyi +722 -0
- langgraph_api/grpc/generated/engine_common_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_cancel_run_action_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_cancel_run_action_pb2.pyi +12 -0
- langgraph_api/grpc/generated/enum_cancel_run_action_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_control_signal_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_control_signal_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_control_signal_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_durability_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_durability_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_durability_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_multitask_strategy_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_multitask_strategy_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_multitask_strategy_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_run_status_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_run_status_pb2.pyi +22 -0
- langgraph_api/grpc/generated/enum_run_status_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_stream_mode_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_stream_mode_pb2.pyi +28 -0
- langgraph_api/grpc/generated/enum_stream_mode_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_thread_status_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_thread_status_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_thread_status_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_thread_stream_mode_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/errors_pb2.py +39 -0
- langgraph_api/grpc/generated/errors_pb2.pyi +21 -0
- langgraph_api/grpc/generated/errors_pb2_grpc.py +24 -0
- langgraph_api/grpc/ops/__init__.py +370 -0
- langgraph_api/grpc/ops/assistants.py +424 -0
- langgraph_api/grpc/ops/runs.py +792 -0
- langgraph_api/grpc/ops/threads.py +1013 -0
- langgraph_api/http.py +16 -5
- langgraph_api/http_metrics.py +15 -35
- langgraph_api/http_metrics_utils.py +38 -0
- langgraph_api/js/build.mts +1 -1
- langgraph_api/js/client.http.mts +13 -7
- langgraph_api/js/client.mts +2 -5
- langgraph_api/js/package.json +29 -28
- langgraph_api/js/remote.py +56 -30
- langgraph_api/js/src/graph.mts +20 -0
- langgraph_api/js/sse.py +2 -2
- langgraph_api/js/ui.py +1 -1
- langgraph_api/js/yarn.lock +1204 -1006
- langgraph_api/logging.py +29 -2
- langgraph_api/metadata.py +99 -28
- langgraph_api/middleware/http_logger.py +7 -2
- langgraph_api/middleware/private_network.py +7 -7
- langgraph_api/models/run.py +54 -93
- langgraph_api/otel_context.py +205 -0
- langgraph_api/patch.py +5 -3
- langgraph_api/queue_entrypoint.py +154 -65
- langgraph_api/route.py +47 -5
- langgraph_api/schema.py +88 -10
- langgraph_api/self_hosted_logs.py +124 -0
- langgraph_api/self_hosted_metrics.py +450 -0
- langgraph_api/serde.py +79 -37
- langgraph_api/server.py +138 -60
- langgraph_api/state.py +4 -3
- langgraph_api/store.py +25 -16
- langgraph_api/stream.py +80 -29
- langgraph_api/thread_ttl.py +31 -13
- langgraph_api/timing/__init__.py +25 -0
- langgraph_api/timing/profiler.py +200 -0
- langgraph_api/timing/timer.py +318 -0
- langgraph_api/utils/__init__.py +53 -8
- langgraph_api/utils/cache.py +47 -10
- langgraph_api/utils/config.py +2 -1
- langgraph_api/utils/errors.py +77 -0
- langgraph_api/utils/future.py +10 -6
- langgraph_api/utils/headers.py +76 -2
- langgraph_api/utils/retriable_client.py +74 -0
- langgraph_api/utils/stream_codec.py +315 -0
- langgraph_api/utils/uuids.py +29 -62
- langgraph_api/validation.py +9 -0
- langgraph_api/webhook.py +120 -6
- langgraph_api/worker.py +55 -24
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/METADATA +16 -8
- langgraph_api-0.7.3.dist-info/RECORD +168 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/WHEEL +1 -1
- langgraph_runtime/__init__.py +1 -0
- langgraph_runtime/routes.py +11 -0
- logging.json +1 -3
- openapi.json +839 -478
- langgraph_api/config.py +0 -387
- langgraph_api/js/isolate-0x130008000-46649-46649-v8.log +0 -4430
- langgraph_api/js/isolate-0x138008000-44681-44681-v8.log +0 -4430
- langgraph_api/js/package-lock.json +0 -3308
- langgraph_api-0.4.1.dist-info/RECORD +0 -107
- /langgraph_api/{utils.py → grpc/__init__.py} +0 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/entry_points.txt +0 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1013 @@
|
|
|
1
|
+
"""gRPC-based threads operations."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from contextlib import AsyncExitStack
|
|
7
|
+
from datetime import UTC
|
|
8
|
+
from http import HTTPStatus
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
10
|
+
from uuid import UUID
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from collections.abc import Sequence
|
|
14
|
+
|
|
15
|
+
import orjson
|
|
16
|
+
import structlog
|
|
17
|
+
from langgraph.checkpoint.serde.jsonplus import _msgpack_ext_hook_to_json
|
|
18
|
+
from langgraph.types import StateSnapshot, StateUpdate
|
|
19
|
+
from langgraph_sdk import Auth
|
|
20
|
+
from starlette.exceptions import HTTPException
|
|
21
|
+
|
|
22
|
+
from langgraph_api import store as api_store
|
|
23
|
+
from langgraph_api.command import map_cmd
|
|
24
|
+
from langgraph_api.graph import get_graph
|
|
25
|
+
from langgraph_api.grpc.client import get_shared_client
|
|
26
|
+
from langgraph_api.grpc.generated import checkpointer_pb2
|
|
27
|
+
from langgraph_api.grpc.generated import core_api_pb2 as pb
|
|
28
|
+
from langgraph_api.grpc.generated import enum_thread_status_pb2 as enum_thread_status
|
|
29
|
+
from langgraph_api.grpc.ops import (
|
|
30
|
+
Authenticated,
|
|
31
|
+
_map_sort_order,
|
|
32
|
+
grpc_error_guard,
|
|
33
|
+
map_if_exists,
|
|
34
|
+
)
|
|
35
|
+
from langgraph_api.grpc.ops.runs import Runs
|
|
36
|
+
from langgraph_api.schema import ThreadUpdateResponse
|
|
37
|
+
from langgraph_api.serde import json_dumpb, json_dumpb_optional, json_loads
|
|
38
|
+
from langgraph_api.state import patch_interrupt, state_snapshot_to_thread_state
|
|
39
|
+
from langgraph_api.utils import fetchone
|
|
40
|
+
from langgraph_runtime.checkpoint import Checkpointer
|
|
41
|
+
|
|
42
|
+
if TYPE_CHECKING:
|
|
43
|
+
from collections.abc import AsyncIterator
|
|
44
|
+
|
|
45
|
+
from langgraph_api.schema import (
|
|
46
|
+
MetadataInput,
|
|
47
|
+
OnConflictBehavior,
|
|
48
|
+
Thread,
|
|
49
|
+
ThreadSelectField,
|
|
50
|
+
ThreadStatus,
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
# Module-level structlog logger; used below to report fragment decode
# failures and failed prune batches.
logger = structlog.stdlib.get_logger(__name__)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _snapshot_defaults():
    """Return extra default fields for building a ``StateSnapshot``.

    Newer langgraph versions define an ``interrupts`` field on
    ``StateSnapshot``; older ones do not. For the older versions an empty
    mapping is returned so no unknown field is supplied.
    """
    supports_interrupts = hasattr(StateSnapshot, "interrupts")
    return {"interrupts": ()} if supports_interrupts else {}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# API thread-status string -> protobuf ThreadStatus enum value.
# Each status string is also the attribute name on the generated enum module.
THREAD_STATUS_TO_PB = {
    name: getattr(enum_thread_status, name)
    for name in ("idle", "busy", "interrupted", "error")
}

# Reverse lookup: protobuf ThreadStatus enum value -> API status string.
THREAD_STATUS_FROM_PB = dict(
    zip(THREAD_STATUS_TO_PB.values(), THREAD_STATUS_TO_PB.keys())
)

# API sort-by field name (lower-case) -> protobuf ThreadsSortBy enum value.
THREAD_SORT_BY_MAP = {
    # Present for enum completeness only; _map_threads_sort_by maps
    # "unspecified" to created_at, so this value is never sent.
    "unspecified": pb.ThreadsSortBy.THREADS_SORT_BY_UNSPECIFIED,
    "thread_id": pb.ThreadsSortBy.THREADS_SORT_BY_THREAD_ID,
    "created_at": pb.ThreadsSortBy.THREADS_SORT_BY_CREATED_AT,
    "updated_at": pb.ThreadsSortBy.THREADS_SORT_BY_UPDATED_AT,
    "status": pb.ThreadsSortBy.THREADS_SORT_BY_STATUS,
}

# API TTL strategy name (lower-case) -> protobuf ThreadTTLStrategy value.
# Key order matters: the keys are listed in the invalid-strategy error.
THREAD_TTL_STRATEGY_MAP = {
    "delete": pb.ThreadTTLStrategy.THREAD_TTL_STRATEGY_DELETE,
    "keep_latest": pb.ThreadTTLStrategy.THREAD_TTL_STRATEGY_KEEP_LATEST,
}
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _map_thread_status(
    status: ThreadStatus | None,
) -> enum_thread_status.ThreadStatus | None:
    """Translate an API thread status into its protobuf enum value.

    Returns ``None`` when *status* is ``None`` or is not a known status
    string. Callers decide how to report the invalid input.
    """
    return None if status is None else THREAD_STATUS_TO_PB.get(status)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _map_threads_sort_by(sort_by: str | None) -> pb.ThreadsSortBy:
    """Translate an API sort-by field name into a ``ThreadsSortBy`` value.

    Matching is case-insensitive. Missing, ``"unspecified"`` and unknown
    names all fall back to sorting by ``created_at``.
    """
    fallback = pb.ThreadsSortBy.THREADS_SORT_BY_CREATED_AT
    key = (sort_by or "").lower()
    if key in ("", "unspecified"):
        return fallback
    return THREAD_SORT_BY_MAP.get(key, fallback)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _map_thread_ttl(ttl: dict[str, Any] | None) -> pb.ThreadTTLConfig | None:
    """Build a ``ThreadTTLConfig`` proto from an API-level TTL dict.

    Keys read from *ttl*:

    - ``strategy``: a key of ``THREAD_TTL_STRATEGY_MAP`` (case-insensitive).
    - ``ttl``, falling back to ``default_ttl`` when ``ttl`` is absent:
      converted with ``float`` and sent as ``default_ttl``.
    - ``sweep_interval_minutes``: converted with ``int``.

    Args:
        ttl: TTL settings; ``None`` or empty means "no TTL".

    Returns:
        The populated config, or ``None`` when *ttl* is falsy.

    Raises:
        HTTPException: 422 if the strategy is unknown or a numeric field
            cannot be converted to a number.
    """
    if not ttl:
        return None

    def _convert(value: Any, cast: Any, field: str) -> Any:
        # Convert client input up front so malformed values are reported as
        # a 422 instead of escaping as a bare ValueError/TypeError.
        try:
            return cast(value)
        except (TypeError, ValueError, OverflowError) as exc:
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail=f"Invalid thread TTL {field}: {value!r}. Expected a number.",
            ) from exc

    config = pb.ThreadTTLConfig()

    strategy = ttl.get("strategy")
    if strategy:
        mapped_strategy = THREAD_TTL_STRATEGY_MAP.get(str(strategy).lower())
        if mapped_strategy is None:
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail=f"Invalid thread TTL strategy: {strategy}. Expected one of {list(THREAD_TTL_STRATEGY_MAP.keys())}",
            )
        config.strategy = mapped_strategy

    ttl_value = ttl.get("ttl", ttl.get("default_ttl"))
    if ttl_value is not None:
        config.default_ttl = _convert(ttl_value, float, "ttl")

    sweep_interval = ttl.get("sweep_interval_minutes")
    if sweep_interval is not None:
        config.sweep_interval_minutes = _convert(
            sweep_interval, int, "sweep_interval_minutes"
        )

    # Note: sweep_limit is a server-side configuration for the TTL sweep loop,
    # not a per-thread setting, so we don't send it via gRPC

    return config
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def fragment_to_value(fragment: pb.Fragment | None) -> Any:
    """Decode a JSON ``Fragment`` proto into a Python value.

    A missing fragment, an empty payload and the literal ``{}`` all decode
    to ``None``. A payload that is not valid JSON is logged as a warning and
    also yields ``None``.
    """
    raw = None if fragment is None else fragment.value
    if not raw or raw == b"{}":
        return None
    try:
        return json_loads(raw)
    except orjson.JSONDecodeError:
        logger.warning("Failed to decode fragment", fragment=raw)
        return None
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _proto_interrupts_to_dict(
    interrupts_map: dict[str, pb.Interrupts],
) -> dict[str, list[dict[str, Any]]]:
    """Convert a map of ``Interrupts`` protos into plain dicts.

    Each interrupt becomes ``{"id": ..., "value": ...}``, where ``value`` is
    JSON-decoded and an empty id becomes ``None``. The optional ``when``,
    ``resumable`` and ``ns`` keys are added only when the proto field is
    truthy. Key order and list order follow the input.
    """

    def _entry(interrupt: Any) -> dict[str, Any]:
        entry: dict[str, Any] = {
            "id": interrupt.id or None,
            "value": json_loads(interrupt.value),
        }
        if interrupt.when:
            entry["when"] = interrupt.when
        if interrupt.resumable:
            entry["resumable"] = interrupt.resumable
        if interrupt.ns:
            entry["ns"] = list(interrupt.ns)
        return entry

    return {
        key: [_entry(item) for item in group.interrupts]
        for key, group in interrupts_map.items()
    }
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def proto_to_thread(proto_thread: pb.Thread) -> Thread:
    """Convert a protobuf ``Thread`` into the API's thread dict.

    Timestamps become UTC-aware datetimes, or ``None`` when unset. JSON
    fragments are decoded with ``fragment_to_value``; ``metadata`` and
    ``config`` default to ``{}``. An unknown status enum maps to ``"idle"``.

    Raises:
        HTTPException: 500 if the proto carries no ``thread_id``.
    """
    if not proto_thread.HasField("thread_id"):
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="Thread response missing thread_id",
        )

    def _timestamp(field_name: str):
        if not proto_thread.HasField(field_name):
            return None
        return getattr(proto_thread, field_name).ToDatetime(tzinfo=UTC)

    return {
        "thread_id": UUID(proto_thread.thread_id.value),
        "created_at": _timestamp("created_at"),
        "updated_at": _timestamp("updated_at"),
        # Unlike other fields, metadata should never be `None`.
        "metadata": fragment_to_value(proto_thread.metadata) or {},
        "config": fragment_to_value(proto_thread.config) or {},
        "error": fragment_to_value(proto_thread.error),
        "status": THREAD_STATUS_FROM_PB.get(proto_thread.status, "idle"),  # type: ignore[typeddict-item]
        "values": fragment_to_value(proto_thread.values),
        "interrupts": _proto_interrupts_to_dict(dict(proto_thread.interrupts)),
    }
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _filter_thread_fields(
|
|
202
|
+
thread: Thread, select: list[ThreadSelectField] | None
|
|
203
|
+
) -> dict[str, Any]:
|
|
204
|
+
if not select:
|
|
205
|
+
return dict(thread)
|
|
206
|
+
return {field: thread[field] for field in select if field in thread}
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _normalize_uuid(value: UUID | str) -> str:
|
|
210
|
+
return str(value) if isinstance(value, UUID) else str(UUID(str(value)))
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _thread_status_checkpoint_to_proto(
    checkpoint: dict[str, Any] | None,
) -> pb.ThreadStatusCheckpoint | None:
    """Convert checkpoint dict to ThreadStatusCheckpoint proto.

    ``values`` and the derived interrupts map are JSON-encoded; ``next`` is
    copied as a list. Returns ``None`` when *checkpoint* is ``None``.
    """
    if checkpoint is None:
        return None

    # Interrupts keyed by task id; tasks without interrupts are omitted
    # (same logic as storage_postgres/ops.py).
    interrupts: dict[str, list[Any]] = {}
    for task in checkpoint.get("tasks", []):
        pending = task.get("interrupts")
        if not pending:
            continue
        task_id = task["id"]
        interrupts[task_id] = [patch_interrupt(item) for item in pending]

    values_json = json_dumpb(checkpoint.get("values", {}))
    next_nodes = list(checkpoint.get("next", []))
    interrupts_json = json_dumpb(interrupts)
    return pb.ThreadStatusCheckpoint(
        values_json=values_json,
        next=next_nodes,
        interrupts_json=interrupts_json,
    )
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _json_contains(container: Any, subset: dict[str, Any]) -> bool:
|
|
235
|
+
if not subset:
|
|
236
|
+
return True
|
|
237
|
+
if not isinstance(container, dict):
|
|
238
|
+
return False
|
|
239
|
+
for key, value in subset.items():
|
|
240
|
+
if key not in container:
|
|
241
|
+
return False
|
|
242
|
+
candidate = container[key]
|
|
243
|
+
if isinstance(value, dict):
|
|
244
|
+
if not _json_contains(candidate, value):
|
|
245
|
+
return False
|
|
246
|
+
else:
|
|
247
|
+
if candidate != value:
|
|
248
|
+
return False
|
|
249
|
+
return True
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@grpc_error_guard
|
|
253
|
+
class Threads(Authenticated):
|
|
254
|
+
"""gRPC-based threads operations."""
|
|
255
|
+
|
|
256
|
+
resource = "threads"
|
|
257
|
+
|
|
258
|
+
@staticmethod
|
|
259
|
+
async def search(
|
|
260
|
+
conn, # Not used in gRPC implementation
|
|
261
|
+
*,
|
|
262
|
+
ids: list[str] | list[UUID] | None = None,
|
|
263
|
+
metadata: MetadataInput,
|
|
264
|
+
values: MetadataInput,
|
|
265
|
+
status: ThreadStatus | None,
|
|
266
|
+
limit: int,
|
|
267
|
+
offset: int,
|
|
268
|
+
sort_by: str | None = None,
|
|
269
|
+
sort_order: str | None = None,
|
|
270
|
+
select: list[ThreadSelectField] | None = None,
|
|
271
|
+
ctx: Any = None,
|
|
272
|
+
) -> tuple[AsyncIterator[Thread], int | None]: # type: ignore[return-value]
|
|
273
|
+
metadata = metadata or {}
|
|
274
|
+
values = values or {}
|
|
275
|
+
|
|
276
|
+
auth_filters = await Threads.handle_event(
|
|
277
|
+
ctx,
|
|
278
|
+
"search",
|
|
279
|
+
{
|
|
280
|
+
"metadata": metadata,
|
|
281
|
+
"values": values,
|
|
282
|
+
"status": status,
|
|
283
|
+
"limit": limit,
|
|
284
|
+
"offset": offset,
|
|
285
|
+
},
|
|
286
|
+
)
|
|
287
|
+
|
|
288
|
+
if ids:
|
|
289
|
+
normalized_ids = [_normalize_uuid(thread_id) for thread_id in ids]
|
|
290
|
+
threads: list[Thread] = []
|
|
291
|
+
client = await get_shared_client()
|
|
292
|
+
for thread_id in normalized_ids:
|
|
293
|
+
request = pb.GetThreadRequest(
|
|
294
|
+
thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
|
|
295
|
+
filters=auth_filters,
|
|
296
|
+
)
|
|
297
|
+
response = await client.threads.Get(request)
|
|
298
|
+
thread = proto_to_thread(response)
|
|
299
|
+
|
|
300
|
+
if status and thread["status"] != status:
|
|
301
|
+
continue
|
|
302
|
+
if metadata and not _json_contains(thread["metadata"], metadata):
|
|
303
|
+
continue
|
|
304
|
+
if values and not _json_contains(thread.get("values") or {}, values):
|
|
305
|
+
continue
|
|
306
|
+
threads.append(thread)
|
|
307
|
+
|
|
308
|
+
total = len(threads)
|
|
309
|
+
paginated = threads[offset : offset + limit]
|
|
310
|
+
cursor = offset + limit if total > offset + limit else None
|
|
311
|
+
|
|
312
|
+
async def generate_results():
|
|
313
|
+
for thread in paginated:
|
|
314
|
+
yield _filter_thread_fields(thread, select)
|
|
315
|
+
|
|
316
|
+
return generate_results(), cursor
|
|
317
|
+
|
|
318
|
+
request_kwargs: dict[str, Any] = {
|
|
319
|
+
"filters": auth_filters,
|
|
320
|
+
"metadata_json": json_dumpb_optional(metadata),
|
|
321
|
+
"values_json": json_dumpb_optional(values),
|
|
322
|
+
"limit": limit,
|
|
323
|
+
"offset": offset,
|
|
324
|
+
"sort_by": _map_threads_sort_by(sort_by),
|
|
325
|
+
"sort_order": _map_sort_order(sort_order),
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
if status:
|
|
329
|
+
mapped_status = _map_thread_status(status)
|
|
330
|
+
if mapped_status is None:
|
|
331
|
+
raise HTTPException(
|
|
332
|
+
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
|
|
333
|
+
detail=f"Invalid thread status: {status}",
|
|
334
|
+
)
|
|
335
|
+
request_kwargs["status"] = mapped_status
|
|
336
|
+
|
|
337
|
+
if select:
|
|
338
|
+
request_kwargs["select"] = select
|
|
339
|
+
|
|
340
|
+
client = await get_shared_client()
|
|
341
|
+
response = await client.threads.Search(
|
|
342
|
+
pb.SearchThreadsRequest(**request_kwargs)
|
|
343
|
+
)
|
|
344
|
+
|
|
345
|
+
threads = [proto_to_thread(thread) for thread in response.threads]
|
|
346
|
+
cursor = offset + limit if len(threads) == limit else None
|
|
347
|
+
|
|
348
|
+
async def generate_results():
|
|
349
|
+
for thread in threads:
|
|
350
|
+
yield _filter_thread_fields(thread, select)
|
|
351
|
+
|
|
352
|
+
return generate_results(), cursor
|
|
353
|
+
|
|
354
|
+
@staticmethod
|
|
355
|
+
async def count(
|
|
356
|
+
conn, # Not used
|
|
357
|
+
*,
|
|
358
|
+
metadata: MetadataInput,
|
|
359
|
+
values: MetadataInput,
|
|
360
|
+
status: ThreadStatus | None,
|
|
361
|
+
ctx: Any = None,
|
|
362
|
+
) -> int: # type: ignore[override]
|
|
363
|
+
metadata = metadata or {}
|
|
364
|
+
values = values or {}
|
|
365
|
+
|
|
366
|
+
auth_filters = await Threads.handle_event(
|
|
367
|
+
ctx,
|
|
368
|
+
"search",
|
|
369
|
+
{
|
|
370
|
+
"metadata": metadata,
|
|
371
|
+
"values": values,
|
|
372
|
+
"status": status,
|
|
373
|
+
},
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
request_kwargs: dict[str, Any] = {
|
|
377
|
+
"filters": auth_filters,
|
|
378
|
+
"metadata_json": json_dumpb_optional(metadata),
|
|
379
|
+
"values_json": json_dumpb_optional(values),
|
|
380
|
+
}
|
|
381
|
+
if status:
|
|
382
|
+
mapped_status = _map_thread_status(status)
|
|
383
|
+
if mapped_status is None:
|
|
384
|
+
raise HTTPException(
|
|
385
|
+
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
|
|
386
|
+
detail=f"Invalid thread status: {status}",
|
|
387
|
+
)
|
|
388
|
+
request_kwargs["status"] = mapped_status
|
|
389
|
+
|
|
390
|
+
client = await get_shared_client()
|
|
391
|
+
response = await client.threads.Count(pb.CountThreadsRequest(**request_kwargs))
|
|
392
|
+
|
|
393
|
+
return int(response.count)
|
|
394
|
+
|
|
395
|
+
@staticmethod
|
|
396
|
+
async def get(
|
|
397
|
+
conn, # Not used
|
|
398
|
+
thread_id: UUID | str,
|
|
399
|
+
ctx: Auth.types.BaseAuthContext | None = None,
|
|
400
|
+
filters: Auth.types.FilterType | None = None,
|
|
401
|
+
include_ttl: bool = False,
|
|
402
|
+
) -> AsyncIterator[Thread]: # type: ignore[return-value]
|
|
403
|
+
"""Get a thread by ID.
|
|
404
|
+
|
|
405
|
+
Args:
|
|
406
|
+
conn: Not used (required for interface compatibility)
|
|
407
|
+
thread_id: Thread ID
|
|
408
|
+
ctx: Auth context
|
|
409
|
+
filters: Additional auth filters to merge with auth context filters
|
|
410
|
+
include_ttl: Not yet supported in gRPC - parameter ignored.
|
|
411
|
+
"""
|
|
412
|
+
auth_filters = await Threads.handle_event(
|
|
413
|
+
ctx, "read", {"thread_id": str(thread_id)}
|
|
414
|
+
)
|
|
415
|
+
# Merge auth filters with any additional parent filters provided.
|
|
416
|
+
# (Parent filters take precedence.)
|
|
417
|
+
if filters:
|
|
418
|
+
auth_filters = {**(auth_filters or {}), **(filters or {})}
|
|
419
|
+
|
|
420
|
+
request = pb.GetThreadRequest(
|
|
421
|
+
thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
|
|
422
|
+
filters=auth_filters,
|
|
423
|
+
)
|
|
424
|
+
client = await get_shared_client()
|
|
425
|
+
response = await client.threads.Get(request)
|
|
426
|
+
|
|
427
|
+
thread = proto_to_thread(response)
|
|
428
|
+
|
|
429
|
+
async def generate_result():
|
|
430
|
+
yield thread
|
|
431
|
+
|
|
432
|
+
return generate_result()
|
|
433
|
+
|
|
434
|
+
@staticmethod
|
|
435
|
+
async def put(
|
|
436
|
+
conn, # Not used
|
|
437
|
+
thread_id: UUID | str,
|
|
438
|
+
*,
|
|
439
|
+
metadata: MetadataInput,
|
|
440
|
+
if_exists: OnConflictBehavior,
|
|
441
|
+
ttl: dict[str, Any] | None = None,
|
|
442
|
+
ctx: Any = None,
|
|
443
|
+
) -> AsyncIterator[Thread]: # type: ignore[return-value]
|
|
444
|
+
metadata = metadata or {}
|
|
445
|
+
|
|
446
|
+
auth_filters = await Threads.handle_event(
|
|
447
|
+
ctx,
|
|
448
|
+
"create",
|
|
449
|
+
{
|
|
450
|
+
"thread_id": str(thread_id),
|
|
451
|
+
"metadata": metadata,
|
|
452
|
+
"if_exists": if_exists,
|
|
453
|
+
},
|
|
454
|
+
)
|
|
455
|
+
|
|
456
|
+
request = pb.CreateThreadRequest(
|
|
457
|
+
thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
|
|
458
|
+
filters=auth_filters,
|
|
459
|
+
if_exists=map_if_exists(if_exists),
|
|
460
|
+
metadata_json=json_dumpb_optional(metadata),
|
|
461
|
+
)
|
|
462
|
+
ttl_config = _map_thread_ttl(ttl)
|
|
463
|
+
if ttl_config is not None:
|
|
464
|
+
request.ttl.CopyFrom(ttl_config)
|
|
465
|
+
|
|
466
|
+
client = await get_shared_client()
|
|
467
|
+
response = await client.threads.Create(request)
|
|
468
|
+
thread = proto_to_thread(response)
|
|
469
|
+
|
|
470
|
+
async def generate_result():
|
|
471
|
+
yield thread
|
|
472
|
+
|
|
473
|
+
return generate_result()
|
|
474
|
+
|
|
475
|
+
@staticmethod
|
|
476
|
+
async def patch(
|
|
477
|
+
conn, # Not used
|
|
478
|
+
thread_id: UUID | str,
|
|
479
|
+
*,
|
|
480
|
+
metadata: MetadataInput,
|
|
481
|
+
ttl: dict[str, Any] | None = None,
|
|
482
|
+
ctx: Any = None,
|
|
483
|
+
) -> AsyncIterator[Thread]: # type: ignore[return-value]
|
|
484
|
+
metadata = metadata or {}
|
|
485
|
+
|
|
486
|
+
auth_filters = await Threads.handle_event(
|
|
487
|
+
ctx,
|
|
488
|
+
"update",
|
|
489
|
+
{
|
|
490
|
+
"thread_id": str(thread_id),
|
|
491
|
+
"metadata": metadata,
|
|
492
|
+
},
|
|
493
|
+
)
|
|
494
|
+
|
|
495
|
+
request = pb.PatchThreadRequest(
|
|
496
|
+
thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
|
|
497
|
+
filters=auth_filters,
|
|
498
|
+
metadata_json=json_dumpb_optional(metadata),
|
|
499
|
+
)
|
|
500
|
+
|
|
501
|
+
ttl_config = _map_thread_ttl(ttl)
|
|
502
|
+
if ttl_config is not None:
|
|
503
|
+
request.ttl.CopyFrom(ttl_config)
|
|
504
|
+
|
|
505
|
+
client = await get_shared_client()
|
|
506
|
+
response = await client.threads.Patch(request)
|
|
507
|
+
|
|
508
|
+
thread = proto_to_thread(response)
|
|
509
|
+
|
|
510
|
+
async def generate_result():
|
|
511
|
+
yield thread
|
|
512
|
+
|
|
513
|
+
return generate_result()
|
|
514
|
+
|
|
515
|
+
@staticmethod
|
|
516
|
+
async def delete(
|
|
517
|
+
conn, # Not used
|
|
518
|
+
thread_id: UUID | str,
|
|
519
|
+
ctx: Any = None,
|
|
520
|
+
) -> AsyncIterator[UUID]: # type: ignore[return-value]
|
|
521
|
+
auth_filters = await Threads.handle_event(
|
|
522
|
+
ctx,
|
|
523
|
+
"delete",
|
|
524
|
+
{
|
|
525
|
+
"thread_id": str(thread_id),
|
|
526
|
+
},
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
request = pb.DeleteThreadRequest(
|
|
530
|
+
thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
|
|
531
|
+
filters=auth_filters,
|
|
532
|
+
)
|
|
533
|
+
|
|
534
|
+
client = await get_shared_client()
|
|
535
|
+
response = await client.threads.Delete(request)
|
|
536
|
+
|
|
537
|
+
deleted_id = UUID(response.value)
|
|
538
|
+
|
|
539
|
+
async def generate_result():
|
|
540
|
+
yield deleted_id
|
|
541
|
+
|
|
542
|
+
return generate_result()
|
|
543
|
+
|
|
544
|
+
@staticmethod
|
|
545
|
+
async def prune(
|
|
546
|
+
thread_ids: Sequence[str] | Sequence[UUID],
|
|
547
|
+
strategy: Literal["delete", "keep_latest"] = "delete",
|
|
548
|
+
batch_size: int = 100,
|
|
549
|
+
ctx: Any = None,
|
|
550
|
+
) -> int:
|
|
551
|
+
"""Prune threads via gRPC.
|
|
552
|
+
|
|
553
|
+
Args:
|
|
554
|
+
thread_ids: List of thread IDs to prune
|
|
555
|
+
strategy: "delete" to remove entirely, "keep_latest" to prune checkpoints
|
|
556
|
+
batch_size: Batch size for operations
|
|
557
|
+
ctx: Auth context for permission checks
|
|
558
|
+
|
|
559
|
+
Returns:
|
|
560
|
+
Number of threads successfully pruned
|
|
561
|
+
"""
|
|
562
|
+
|
|
563
|
+
if not thread_ids:
|
|
564
|
+
return 0
|
|
565
|
+
|
|
566
|
+
str_ids = [str(tid) for tid in thread_ids]
|
|
567
|
+
client = await get_shared_client()
|
|
568
|
+
|
|
569
|
+
# Validate delete authorization for all threads before pruning.
|
|
570
|
+
# Auth filters are based on user/action, so we only need to get them once.
|
|
571
|
+
auth_filters = await Threads.handle_event(
|
|
572
|
+
ctx,
|
|
573
|
+
"delete",
|
|
574
|
+
{"thread_ids": str_ids},
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
# Only validate access if auth filters are present
|
|
578
|
+
if auth_filters:
|
|
579
|
+
|
|
580
|
+
async def validate_thread_access(thread_id: str) -> None:
|
|
581
|
+
request = pb.GetThreadRequest(
|
|
582
|
+
thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
|
|
583
|
+
filters=auth_filters,
|
|
584
|
+
)
|
|
585
|
+
await client.threads.Get(request)
|
|
586
|
+
|
|
587
|
+
await asyncio.gather(*[validate_thread_access(tid) for tid in str_ids])
|
|
588
|
+
|
|
589
|
+
if strategy == "delete":
|
|
590
|
+
strategy_proto = checkpointer_pb2.PruneRequest.PruneStrategy.DELETE_ALL
|
|
591
|
+
else:
|
|
592
|
+
strategy_proto = checkpointer_pb2.PruneRequest.PruneStrategy.KEEP_LATEST
|
|
593
|
+
stub = client.checkpointer
|
|
594
|
+
|
|
595
|
+
processed = 0
|
|
596
|
+
for i in range(0, len(str_ids), batch_size):
|
|
597
|
+
batch = str_ids[i : i + batch_size]
|
|
598
|
+
try:
|
|
599
|
+
request = checkpointer_pb2.PruneRequest(
|
|
600
|
+
thread_ids=batch,
|
|
601
|
+
strategy=strategy_proto,
|
|
602
|
+
)
|
|
603
|
+
await stub.Prune(request)
|
|
604
|
+
processed += len(batch)
|
|
605
|
+
except Exception:
|
|
606
|
+
await logger.aexception("Failed to prune thread. Skipping batch.")
|
|
607
|
+
pass
|
|
608
|
+
|
|
609
|
+
return processed
|
|
610
|
+
|
|
611
|
+
@staticmethod
|
|
612
|
+
async def copy(
|
|
613
|
+
conn, # Not used
|
|
614
|
+
thread_id: UUID | str,
|
|
615
|
+
ctx: Any = None,
|
|
616
|
+
) -> AsyncIterator[Thread]: # type: ignore[return-value]
|
|
617
|
+
auth_filters = await Threads.handle_event(
|
|
618
|
+
ctx,
|
|
619
|
+
"read",
|
|
620
|
+
{
|
|
621
|
+
"thread_id": str(thread_id),
|
|
622
|
+
},
|
|
623
|
+
)
|
|
624
|
+
# Validate that the user also has create permissions
|
|
625
|
+
# Filters will be the same as the read filters, so we can toss these
|
|
626
|
+
await Threads.handle_event(
|
|
627
|
+
ctx,
|
|
628
|
+
"create",
|
|
629
|
+
{
|
|
630
|
+
"thread_id": str(thread_id),
|
|
631
|
+
},
|
|
632
|
+
)
|
|
633
|
+
|
|
634
|
+
request = pb.CopyThreadRequest(
|
|
635
|
+
thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
|
|
636
|
+
filters=auth_filters,
|
|
637
|
+
)
|
|
638
|
+
|
|
639
|
+
client = await get_shared_client()
|
|
640
|
+
response = await client.threads.Copy(request)
|
|
641
|
+
|
|
642
|
+
thread = proto_to_thread(response)
|
|
643
|
+
|
|
644
|
+
async def generate_result():
|
|
645
|
+
yield thread
|
|
646
|
+
|
|
647
|
+
return generate_result()
|
|
648
|
+
|
|
649
|
+
@staticmethod
|
|
650
|
+
async def set_status(
|
|
651
|
+
conn, # Not used in gRPC implementation
|
|
652
|
+
thread_id: UUID | str,
|
|
653
|
+
checkpoint: dict[str, Any] | None,
|
|
654
|
+
exception: BaseException | dict[str, Any] | None,
|
|
655
|
+
expected_status: ThreadStatus | Sequence[ThreadStatus] | None = None,
|
|
656
|
+
) -> None:
|
|
657
|
+
"""Set thread status via gRPC.
|
|
658
|
+
|
|
659
|
+
Args:
|
|
660
|
+
conn: Not used (required for interface compatibility)
|
|
661
|
+
thread_id: Thread ID
|
|
662
|
+
checkpoint: Checkpoint payload containing values, next, tasks, etc.
|
|
663
|
+
exception: Exception to store on thread (BaseException or serialized dict)
|
|
664
|
+
expected_status: Expected current status(es) for optimistic locking
|
|
665
|
+
"""
|
|
666
|
+
request_kwargs: dict[str, Any] = {
|
|
667
|
+
"thread_id": pb.UUID(value=_normalize_uuid(thread_id)),
|
|
668
|
+
}
|
|
669
|
+
|
|
670
|
+
# Map checkpoint to proto
|
|
671
|
+
checkpoint_proto = _thread_status_checkpoint_to_proto(checkpoint)
|
|
672
|
+
if checkpoint_proto is not None:
|
|
673
|
+
request_kwargs["checkpoint"] = checkpoint_proto
|
|
674
|
+
|
|
675
|
+
# Map exception to JSON bytes
|
|
676
|
+
if exception is not None:
|
|
677
|
+
if isinstance(exception, BaseException):
|
|
678
|
+
exception_dict = {
|
|
679
|
+
"type": type(exception).__name__,
|
|
680
|
+
"message": str(exception),
|
|
681
|
+
}
|
|
682
|
+
else:
|
|
683
|
+
exception_dict = exception
|
|
684
|
+
request_kwargs["exception_json"] = json_dumpb(exception_dict)
|
|
685
|
+
|
|
686
|
+
# Map expected_status to enum values
|
|
687
|
+
if expected_status:
|
|
688
|
+
if isinstance(expected_status, str):
|
|
689
|
+
expected_status = [expected_status]
|
|
690
|
+
status_enums = []
|
|
691
|
+
for status in expected_status:
|
|
692
|
+
mapped = THREAD_STATUS_TO_PB.get(status)
|
|
693
|
+
if mapped is not None:
|
|
694
|
+
status_enums.append(mapped)
|
|
695
|
+
if status_enums:
|
|
696
|
+
request_kwargs["expected_status"] = status_enums
|
|
697
|
+
|
|
698
|
+
client = await get_shared_client()
|
|
699
|
+
await client.threads.SetStatus(pb.SetThreadStatusRequest(**request_kwargs))
|
|
700
|
+
|
|
701
|
+
@staticmethod
async def get_graph_id(
    thread_id: UUID | str,
) -> str | None:
    """Look up the graph ID recorded for the most recent run on a thread.

    Args:
        thread_id: Thread identifier, either a ``UUID`` or its string form;
            normalized via ``_normalize_uuid`` before being sent.

    Returns:
        The ``graph_id`` from the gRPC ``GetGraphID`` reply, or ``None`` when
        the reply carries an empty/falsy value.
    """
    # Build the request first so a malformed ID fails before we touch the client.
    lookup = pb.GetGraphIDRequest(
        thread_id=pb.UUID(value=_normalize_uuid(thread_id)),
    )
    grpc_client = await get_shared_client()
    reply = await grpc_client.threads.GetGraphID(lookup)
    # Proto string fields default to "", which we surface as "no graph".
    return reply.graph_id or None
|
|
714
|
+
|
|
715
|
+
class State(Authenticated):
|
|
716
|
+
# treat this like threads resource
|
|
717
|
+
resource = "threads"
|
|
718
|
+
|
|
719
|
+
@staticmethod
async def get(
    conn,  # Still needed for checkpointer
    config: dict[str, Any],
    subgraphs: bool,
    ctx: Any = None,
) -> StateSnapshot:
    """Return the state snapshot for a thread (*internal only*, no auth event).

    No auth handler event is emitted here; ``ctx`` is only forwarded to
    ``Threads.get``.

    Args:
        conn: Database connection backing the checkpointer and thread lookup.
        config: Run config; ``config["configurable"]["thread_id"]`` selects the
            thread and its ``configurable`` entries override the stored ones.
        subgraphs: Forwarded to ``graph.aget_state``.
        ctx: Optional context forwarded to ``Threads.get``.

    Returns:
        The graph's state snapshot when a graph ID can be resolved (from the
        latest run, else from thread metadata); otherwise an empty snapshot.
    """
    checkpointer = Checkpointer(conn, unpack_hook=_msgpack_ext_hook_to_json)
    thread_id = config["configurable"]["thread_id"]

    # Issue the three lookups concurrently inside one pipeline.
    async with conn.pipeline():
        thread_iter, latest_checkpoint, run_graph_id = await asyncio.gather(
            Threads.get(conn, thread_id, ctx=ctx),
            checkpointer.aget_iter(config),
            Threads.get_graph_id(thread_id),
        )

    thread_row = await fetchone(thread_iter)
    thread_metadata = json_loads(thread_row["metadata"])
    stored_config = json_loads(thread_row["config"])
    # Caller-supplied configurable keys win over those stored on the thread.
    merged_configurable = {
        **stored_config.get("configurable", {}),
        **config.get("configurable", {}),
    }
    merged_config = {**stored_config, "configurable": merged_configurable}
    resolved_graph_id = run_graph_id or thread_metadata.get("graph_id")

    if not resolved_graph_id:
        # Without a graph there is nothing to load; hand back an empty snapshot.
        empty_snapshot: dict[str, Any] = {
            "values": {},
            "next": (),
            "config": None,
            "metadata": None,
            "created_at": None,
            "parent_config": None,
            "tasks": (),
        }
        empty_snapshot.update(_snapshot_defaults())
        return StateSnapshot(**empty_snapshot)  # type: ignore[missing-argument]

    # Reuse the already-fetched checkpoint instead of reading it again.
    checkpointer.latest_iter = latest_checkpoint
    async with get_graph(
        resolved_graph_id,
        merged_config,
        checkpointer=checkpointer,
        store=(await api_store.get_store()),
    ) as graph:
        return await graph.aget_state(config, subgraphs=subgraphs)
|
|
771
|
+
|
|
772
|
+
@staticmethod
async def post(
    conn,  # Still needed for checkpointer and run count check
    config: dict[str, Any],
    values: Any,
    as_node: str | None = None,
    ctx: Any = None,
) -> ThreadUpdateResponse:
    """Update thread state with a single update and return the new checkpoint.

    Args:
        conn: Database connection used for the checkpointer and for the
            transaction that wraps the update.
        config: Run config; ``config["configurable"]["thread_id"]`` selects the
            thread. Mutated in place: ``graph_id`` is set if absent.
        values: Update payload passed to ``graph.aupdate_state``.
        as_node: Node to attribute the update to, if any.
        ctx: Auth context for the ``"update"`` event and the thread lookups.

    Returns:
        A dict with ``checkpoint`` (the new checkpoint's ``configurable``), plus
        the deprecated top-level keys of the new config and ``checkpoint_id``.

    Raises:
        HTTPException: 409 if the thread has pending/running runs; 400 if no
            graph ID can be resolved for the thread.
    """
    thread_id = UUID(config["configurable"]["thread_id"])
    filters = await Threads.State.handle_event(
        ctx,
        "update",
        Auth.types.ThreadsUpdate(thread_id=thread_id),
    )

    checkpointer = Checkpointer(conn, use_direct_connection=True)
    async with conn.pipeline():
        thread, checkpoint, graph_id, run_count = await asyncio.gather(
            Threads.get(conn, thread_id, ctx=ctx, filters=filters),
            checkpointer.aget_iter(config),
            Threads.get_graph_id(thread_id),
            Runs.count(thread_id=thread_id, statuses=["pending", "running"]),
        )

    thread = await fetchone(thread)
    metadata = json_loads(thread["metadata"])
    thread_config = json_loads(thread["config"])
    graph_id = graph_id or metadata.get("graph_id")

    # Check if thread is busy before allowing state update
    if run_count > 0:
        raise HTTPException(
            status_code=409,
            detail="Thread is busy with a running job. Cannot update state.",
        )

    if not graph_id:
        raise HTTPException(
            status_code=400,
            detail=f"Thread '{thread['thread_id']}' has no assigned graph ID. This usually occurs when no runs have been made on this particular thread."
            " This operation requires a graph ID. Please ensure a run has been made for the thread or manually update the thread metadata (by setting the 'graph_id' field) before running this operation.",
        )

    config["configurable"].setdefault("graph_id", graph_id)
    # Reuse the already-fetched checkpoint instead of reading it again.
    checkpointer.latest_iter = checkpoint
    async with AsyncExitStack() as stack:
        graph = await stack.enter_async_context(
            get_graph(
                graph_id,
                thread_config,
                checkpointer=checkpointer,
                store=(await api_store.get_store()),
                is_for_execution=False,
            )
        )
        await stack.enter_async_context(conn.transaction())
        next_config = await graph.aupdate_state(config, values, as_node=as_node)

        # Refresh the thread's stored values from the post-update state.
        state = await Threads.State.get(conn, config, subgraphs=False, ctx=ctx)
        # Convert once: the same payload is persisted and published below.
        thread_state = state_snapshot_to_thread_state(state)
        await Threads.set_status(
            conn,
            thread_id,
            thread_state,
            None,
            # Accept if NOT busy (optimistic check against a run starting meanwhile)
            expected_status=("interrupted", "idle", "error"),
        )

        # Publish state update event
        event_data = {
            "state": thread_state,
            "thread_id": str(thread_id),
        }
        await Runs.Stream.publish(
            "*",
            "state_update",
            json_dumpb(event_data),
            thread_id=thread_id,
            resumable=True,
        )

        return {
            "checkpoint": next_config["configurable"],
            # below are deprecated
            **next_config,
            "checkpoint_id": next_config["configurable"]["checkpoint_id"],
        }
|
|
865
|
+
|
|
866
|
+
@staticmethod
async def bulk(
    conn,  # Still needed for checkpointer
    config: dict[str, Any],
    supersteps: Any,
    ctx: Any = None,
) -> ThreadUpdateResponse:
    """Update a thread with a batch of state updates grouped into supersteps.

    Args:
        conn: Database connection used for the checkpointer and for the
            transaction that wraps the update.
        config: Run config; ``config["configurable"]["thread_id"]`` selects the
            thread. Mutated in place: ``graph_id`` and ``checkpoint_ns`` (``""``)
            are set if absent.
        supersteps: Iterable of mappings, each optionally holding an
            ``"updates"`` list whose items provide ``"command"`` (mapped via
            ``map_cmd``) or ``"values"``, plus an optional ``"as_node"``.
        ctx: Auth context for the ``"update"`` event and the thread lookups.

    Returns:
        ``ThreadUpdateResponse`` whose ``checkpoint`` is the resulting config's
        ``configurable``.

    Raises:
        HTTPException: 400 if no graph ID can be resolved for the thread.
    """
    thread_id = UUID(config["configurable"]["thread_id"])
    filters = await Threads.State.handle_event(
        ctx,
        "update",
        Auth.types.ThreadsUpdate(thread_id=thread_id),
    )

    checkpointer = Checkpointer(conn)

    async with conn.pipeline():
        thread, graph_id = await asyncio.gather(
            Threads.get(conn, thread_id, ctx=ctx, filters=filters),
            # Pass the parsed UUID, consistent with State.post.
            Threads.get_graph_id(thread_id),
        )
    thread = await fetchone(thread)
    thread_config = json_loads(thread["config"])
    metadata = json_loads(thread["metadata"])
    graph_id = graph_id or metadata.get("graph_id")

    if not graph_id:
        raise HTTPException(
            status_code=400,
            detail=f"Thread '{thread['thread_id']}' has no assigned graph ID. This usually occurs when no runs have been made on this particular thread."
            " This operation requires a graph ID. Please ensure a run has been made for the thread or manually update the thread metadata (by setting the 'graph_id' field) before running this operation.",
        )

    config["configurable"].setdefault("graph_id", graph_id)
    config["configurable"].setdefault("checkpoint_ns", "")

    async with AsyncExitStack() as stack:
        graph = await stack.enter_async_context(
            get_graph(
                graph_id,
                thread_config,
                checkpointer=checkpointer,
                store=(await api_store.get_store()),
                is_for_execution=False,
            )
        )

        await stack.enter_async_context(conn.transaction())
        next_config = await graph.abulk_update_state(
            config,
            [
                [
                    StateUpdate(
                        (
                            map_cmd(update.get("command"))
                            if update.get("command")
                            else update.get("values")
                        ),
                        update.get("as_node"),
                    )
                    for update in superstep.get("updates", [])
                ]
                for superstep in supersteps
            ],
        )

        # Refresh the thread's stored values from the post-update state.
        state = await Threads.State.get(conn, config, subgraphs=False, ctx=ctx)
        # Convert once: the same payload is persisted and published below.
        thread_state = state_snapshot_to_thread_state(state)

        # NOTE(review): unlike State.post, there is no pending/running-run
        # check and no expected_status here — confirm this is intentional.
        await Threads.set_status(
            conn,
            thread_id,
            thread_state,
            None,
        )

        # Publish state update event
        event_data = {
            "state": thread_state,
            "thread_id": str(thread_id),
        }
        await Runs.Stream.publish(
            "*",
            "state_update",
            json_dumpb(event_data),
            thread_id=thread_id,
            resumable=True,
        )

        return ThreadUpdateResponse(checkpoint=next_config["configurable"])
|
|
960
|
+
|
|
961
|
+
@staticmethod
async def list(
    conn,  # Still needed for checkpointer
    *,
    config: dict[str, Any],
    limit: int = 1,
    before: Any = None,
    metadata: Any = None,
    ctx: Any = None,
) -> list[StateSnapshot]:
    """Return a thread's checkpoint history as a list of state snapshots.

    Args:
        conn: Database connection backing the checkpointer and thread lookup.
        config: Run config; ``config["configurable"]["thread_id"]`` selects the
            thread and its ``configurable`` entries override the stored ones.
        limit: Maximum number of snapshots, forwarded to ``aget_state_history``.
        before: A checkpoint ID string (wrapped into a config) or a config
            mapping, forwarded as ``before``.
        metadata: Forwarded as the history ``filter``.
        ctx: Optional context forwarded to ``Threads.get``.

    Returns:
        The snapshots yielded by ``graph.aget_state_history``, or ``[]`` when no
        graph ID can be resolved for the thread.
    """
    thread_id = config["configurable"]["thread_id"]
    async with conn.pipeline():
        thread_iter, run_graph_id = await asyncio.gather(
            Threads.get(conn, thread_id, ctx=ctx),
            Threads.get_graph_id(thread_id),
        )
    thread_row = await fetchone(thread_iter)
    thread_metadata = json_loads(thread_row["metadata"])
    stored_config = json_loads(thread_row["config"])
    # Caller-supplied configurable keys win over those stored on the thread.
    merged_config = {
        **stored_config,
        "configurable": {
            **stored_config.get("configurable", {}),
            **config.get("configurable", {}),
        },
    }
    resolved_graph_id = run_graph_id or thread_metadata.get("graph_id")

    if not resolved_graph_id:
        return []

    # A bare checkpoint ID is accepted as shorthand for a config.
    before_config = (
        {"configurable": {"checkpoint_id": before}}
        if isinstance(before, str)
        else before
    )
    async with get_graph(
        resolved_graph_id,
        merged_config,
        checkpointer=Checkpointer(conn, unpack_hook=_msgpack_ext_hook_to_json),
        store=(await api_store.get_store()),
        is_for_execution=False,
    ) as graph:
        history = graph.aget_state_history(
            config,
            limit=limit,
            filter=metadata,
            before=before_config,
        )
        return [snapshot async for snapshot in history]
|