langgraph-api 0.4.1__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_api/__init__.py +1 -1
- langgraph_api/api/__init__.py +111 -51
- langgraph_api/api/a2a.py +1610 -0
- langgraph_api/api/assistants.py +212 -89
- langgraph_api/api/mcp.py +3 -3
- langgraph_api/api/meta.py +52 -28
- langgraph_api/api/openapi.py +27 -17
- langgraph_api/api/profile.py +108 -0
- langgraph_api/api/runs.py +342 -195
- langgraph_api/api/store.py +19 -2
- langgraph_api/api/threads.py +209 -27
- langgraph_api/asgi_transport.py +14 -9
- langgraph_api/asyncio.py +14 -4
- langgraph_api/auth/custom.py +52 -37
- langgraph_api/auth/langsmith/backend.py +4 -3
- langgraph_api/auth/langsmith/client.py +13 -8
- langgraph_api/cli.py +230 -133
- langgraph_api/command.py +5 -3
- langgraph_api/config/__init__.py +532 -0
- langgraph_api/config/_parse.py +58 -0
- langgraph_api/config/schemas.py +431 -0
- langgraph_api/cron_scheduler.py +17 -1
- langgraph_api/encryption/__init__.py +15 -0
- langgraph_api/encryption/aes_json.py +158 -0
- langgraph_api/encryption/context.py +35 -0
- langgraph_api/encryption/custom.py +280 -0
- langgraph_api/encryption/middleware.py +632 -0
- langgraph_api/encryption/shared.py +63 -0
- langgraph_api/errors.py +12 -1
- langgraph_api/executor_entrypoint.py +11 -6
- langgraph_api/feature_flags.py +29 -0
- langgraph_api/graph.py +176 -76
- langgraph_api/grpc/client.py +313 -0
- langgraph_api/grpc/config_conversion.py +231 -0
- langgraph_api/grpc/generated/__init__.py +29 -0
- langgraph_api/grpc/generated/checkpointer_pb2.py +63 -0
- langgraph_api/grpc/generated/checkpointer_pb2.pyi +99 -0
- langgraph_api/grpc/generated/checkpointer_pb2_grpc.py +329 -0
- langgraph_api/grpc/generated/core_api_pb2.py +216 -0
- langgraph_api/grpc/generated/core_api_pb2.pyi +905 -0
- langgraph_api/grpc/generated/core_api_pb2_grpc.py +1621 -0
- langgraph_api/grpc/generated/engine_common_pb2.py +219 -0
- langgraph_api/grpc/generated/engine_common_pb2.pyi +722 -0
- langgraph_api/grpc/generated/engine_common_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_cancel_run_action_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_cancel_run_action_pb2.pyi +12 -0
- langgraph_api/grpc/generated/enum_cancel_run_action_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_control_signal_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_control_signal_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_control_signal_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_durability_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_durability_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_durability_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_multitask_strategy_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_multitask_strategy_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_multitask_strategy_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_run_status_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_run_status_pb2.pyi +22 -0
- langgraph_api/grpc/generated/enum_run_status_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_stream_mode_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_stream_mode_pb2.pyi +28 -0
- langgraph_api/grpc/generated/enum_stream_mode_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_thread_status_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_thread_status_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_thread_status_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.py +37 -0
- langgraph_api/grpc/generated/enum_thread_stream_mode_pb2.pyi +16 -0
- langgraph_api/grpc/generated/enum_thread_stream_mode_pb2_grpc.py +24 -0
- langgraph_api/grpc/generated/errors_pb2.py +39 -0
- langgraph_api/grpc/generated/errors_pb2.pyi +21 -0
- langgraph_api/grpc/generated/errors_pb2_grpc.py +24 -0
- langgraph_api/grpc/ops/__init__.py +370 -0
- langgraph_api/grpc/ops/assistants.py +424 -0
- langgraph_api/grpc/ops/runs.py +792 -0
- langgraph_api/grpc/ops/threads.py +1013 -0
- langgraph_api/http.py +16 -5
- langgraph_api/http_metrics.py +15 -35
- langgraph_api/http_metrics_utils.py +38 -0
- langgraph_api/js/build.mts +1 -1
- langgraph_api/js/client.http.mts +13 -7
- langgraph_api/js/client.mts +2 -5
- langgraph_api/js/package.json +29 -28
- langgraph_api/js/remote.py +56 -30
- langgraph_api/js/src/graph.mts +20 -0
- langgraph_api/js/sse.py +2 -2
- langgraph_api/js/ui.py +1 -1
- langgraph_api/js/yarn.lock +1204 -1006
- langgraph_api/logging.py +29 -2
- langgraph_api/metadata.py +99 -28
- langgraph_api/middleware/http_logger.py +7 -2
- langgraph_api/middleware/private_network.py +7 -7
- langgraph_api/models/run.py +54 -93
- langgraph_api/otel_context.py +205 -0
- langgraph_api/patch.py +5 -3
- langgraph_api/queue_entrypoint.py +154 -65
- langgraph_api/route.py +47 -5
- langgraph_api/schema.py +88 -10
- langgraph_api/self_hosted_logs.py +124 -0
- langgraph_api/self_hosted_metrics.py +450 -0
- langgraph_api/serde.py +79 -37
- langgraph_api/server.py +138 -60
- langgraph_api/state.py +4 -3
- langgraph_api/store.py +25 -16
- langgraph_api/stream.py +80 -29
- langgraph_api/thread_ttl.py +31 -13
- langgraph_api/timing/__init__.py +25 -0
- langgraph_api/timing/profiler.py +200 -0
- langgraph_api/timing/timer.py +318 -0
- langgraph_api/utils/__init__.py +53 -8
- langgraph_api/utils/cache.py +47 -10
- langgraph_api/utils/config.py +2 -1
- langgraph_api/utils/errors.py +77 -0
- langgraph_api/utils/future.py +10 -6
- langgraph_api/utils/headers.py +76 -2
- langgraph_api/utils/retriable_client.py +74 -0
- langgraph_api/utils/stream_codec.py +315 -0
- langgraph_api/utils/uuids.py +29 -62
- langgraph_api/validation.py +9 -0
- langgraph_api/webhook.py +120 -6
- langgraph_api/worker.py +55 -24
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/METADATA +16 -8
- langgraph_api-0.7.3.dist-info/RECORD +168 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/WHEEL +1 -1
- langgraph_runtime/__init__.py +1 -0
- langgraph_runtime/routes.py +11 -0
- logging.json +1 -3
- openapi.json +839 -478
- langgraph_api/config.py +0 -387
- langgraph_api/js/isolate-0x130008000-46649-46649-v8.log +0 -4430
- langgraph_api/js/isolate-0x138008000-44681-44681-v8.log +0 -4430
- langgraph_api/js/package-lock.json +0 -3308
- langgraph_api-0.4.1.dist-info/RECORD +0 -107
- /langgraph_api/{utils.py → grpc/__init__.py} +0 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/entry_points.txt +0 -0
- {langgraph_api-0.4.1.dist-info → langgraph_api-0.7.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
"""gRPC client wrapper for LangGraph persistence services."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import threading
|
|
5
|
+
import time
|
|
6
|
+
|
|
7
|
+
import structlog
|
|
8
|
+
from grpc import aio # type: ignore[import]
|
|
9
|
+
from grpc_health.v1 import health_pb2, health_pb2_grpc # type: ignore[import]
|
|
10
|
+
|
|
11
|
+
from langgraph_api import config
|
|
12
|
+
|
|
13
|
+
from .generated.checkpointer_pb2_grpc import CheckpointerStub
|
|
14
|
+
from .generated.core_api_pb2_grpc import (
|
|
15
|
+
AdminStub,
|
|
16
|
+
AssistantsStub,
|
|
17
|
+
RunsStub,
|
|
18
|
+
ThreadsStub,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
logger = structlog.stdlib.get_logger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Shared gRPC client pools (main thread + thread-local for isolated loops).
# `_client_pool` serves the main thread; worker threads each get their own
# single-client pool via `_thread_local` so gRPC channels are never used
# across event loops.
_client_pool: "GrpcClientPool | None" = None
_thread_local = threading.local()


# Per-RPC deadline (seconds) for a single health Check call.
GRPC_HEALTHCHECK_TIMEOUT = 5.0
# Total startup budget (seconds) for wait_until_grpc_ready.
GRPC_INIT_TIMEOUT = 10.0
# Delay (seconds) between successive readiness probes.
GRPC_INIT_PROBE_INTERVAL = 0.5
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class GrpcClient:
    """gRPC client for LangGraph persistence services.

    Holds one ``grpc.aio`` channel and one stub per service. The client
    must be connected (``connect()`` or ``async with``) before any stub
    property or ``healthcheck()`` is used; otherwise a ``RuntimeError``
    is raised.
    """

    def __init__(
        self,
        server_address: str | None = None,
    ):
        """Initialize the gRPC client.

        Args:
            server_address: The gRPC server address (default: localhost:50051)
        """
        self.server_address = server_address or config.GRPC_SERVER_ADDRESS
        self._channel: aio.Channel | None = None
        self._assistants_stub: AssistantsStub | None = None
        self._runs_stub: RunsStub | None = None
        self._threads_stub: ThreadsStub | None = None
        self._admin_stub: AdminStub | None = None
        self._checkpointer_stub: CheckpointerStub | None = None
        self._health_stub: health_pb2_grpc.HealthStub | None = None

    async def __aenter__(self):
        """Async context manager entry."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.close()

    @staticmethod
    def _require(stub):
        """Return *stub*, raising if the client has not connected yet.

        Consolidates the connection guard previously duplicated across
        healthcheck() and every stub property.
        """
        if stub is None:
            raise RuntimeError(
                "Client not connected. Use async context manager or call connect() first."
            )
        return stub

    async def connect(self):
        """Connect to the gRPC server.

        Idempotent: calling again while already connected is a no-op.
        """
        if self._channel is not None:
            return

        # Message-size limits come from config so large checkpoints can pass.
        options = [
            ("grpc.max_receive_message_length", config.GRPC_CLIENT_MAX_RECV_MSG_BYTES),
            ("grpc.max_send_message_length", config.GRPC_CLIENT_MAX_SEND_MSG_BYTES),
        ]

        self._channel = aio.insecure_channel(self.server_address, options=options)

        self._assistants_stub = AssistantsStub(self._channel)
        self._runs_stub = RunsStub(self._channel)
        self._threads_stub = ThreadsStub(self._channel)
        self._admin_stub = AdminStub(self._channel)
        self._checkpointer_stub = CheckpointerStub(self._channel)
        self._health_stub = health_pb2_grpc.HealthStub(self._channel)

        await logger.adebug(
            "Connected to gRPC server", server_address=self.server_address
        )

    async def close(self):
        """Close the gRPC connection and drop all stubs."""
        if self._channel is not None:
            await self._channel.close()
            self._channel = None
            self._assistants_stub = None
            self._runs_stub = None
            self._threads_stub = None
            self._admin_stub = None
            self._checkpointer_stub = None
            self._health_stub = None
            await logger.adebug("Closed gRPC connection")

    async def healthcheck(self) -> bool:
        """Check if the gRPC server is healthy.

        Returns:
            True if the server is healthy and serving.

        Raises:
            RuntimeError: If the client is not connected or the server is unhealthy.
        """
        health_stub = self._require(self._health_stub)

        # Empty service name checks overall server health.
        request = health_pb2.HealthCheckRequest(service="")
        response = await health_stub.Check(
            request, timeout=GRPC_HEALTHCHECK_TIMEOUT
        )

        if response.status != health_pb2.HealthCheckResponse.SERVING:
            raise RuntimeError(f"gRPC server is not healthy. Status: {response.status}")

        return True

    @property
    def assistants(self) -> AssistantsStub:
        """Get the assistants service stub."""
        return self._require(self._assistants_stub)

    @property
    def threads(self) -> ThreadsStub:
        """Get the threads service stub."""
        return self._require(self._threads_stub)

    @property
    def runs(self) -> RunsStub:
        """Get the runs service stub."""
        return self._require(self._runs_stub)

    @property
    def admin(self) -> AdminStub:
        """Get the admin service stub."""
        return self._require(self._admin_stub)

    @property
    def checkpointer(self) -> CheckpointerStub:
        """Get the checkpointer service stub."""
        return self._require(self._checkpointer_stub)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class GrpcClientPool:
    """Pool of gRPC clients for load distribution."""

    def __init__(self, pool_size: int = 5, server_address: str | None = None):
        # Clients are created lazily on first get_client() call.
        self.pool_size = pool_size
        self.server_address = server_address
        self.clients: list[GrpcClient] = []
        self._current_index = 0
        self._init_lock = asyncio.Lock()
        self._initialized = False

    async def _initialize(self):
        """Connect `pool_size` clients exactly once, guarded by a lock."""
        async with self._init_lock:
            if self._initialized:
                return

            await logger.ainfo(
                "Initializing gRPC client pool",
                pool_size=self.pool_size,
                server_address=self.server_address,
            )

            while len(self.clients) < self.pool_size:
                member = GrpcClient(server_address=self.server_address)
                await member.connect()
                self.clients.append(member)

            self._initialized = True
            await logger.ainfo(
                f"gRPC client pool initialized with {self.pool_size} clients"
            )

    async def get_client(self) -> GrpcClient:
        """Get next client using round-robin selection.

        Round-robin without strict locking - slight races are acceptable
        and result in good enough distribution under high load.
        """
        if not self._initialized:
            await self._initialize()

        position = self._current_index % self.pool_size
        self._current_index = position + 1
        return self.clients[position]

    async def close(self):
        """Close all clients in the pool."""
        if not self._initialized:
            return
        await logger.ainfo(f"Closing gRPC client pool ({self.pool_size} clients)")
        for member in self.clients:
            await member.close()
        self.clients.clear()
        self._initialized = False
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
async def get_shared_client() -> GrpcClient:
    """Get a gRPC client from the shared pool.

    Uses a pool of channels for better performance under high concurrency.
    Each channel is a separate TCP connection that can handle ~100-200
    concurrent streams effectively. Pools are scoped per thread/loop to
    avoid cross-loop gRPC channel usage.

    Returns:
        A GrpcClient instance from the pool
    """
    global _client_pool

    on_main_thread = threading.current_thread() is threading.main_thread()
    if not on_main_thread:
        # Worker threads run their own event loop; give each a 1-client pool.
        thread_pool = getattr(_thread_local, "grpc_pool", None)
        if thread_pool is None:
            thread_pool = GrpcClientPool(
                pool_size=1,
                server_address=config.GRPC_SERVER_ADDRESS,
            )
            _thread_local.grpc_pool = thread_pool
        return await thread_pool.get_client()

    # Main thread: lazily create the process-wide shared pool.
    if _client_pool is None:
        _client_pool = GrpcClientPool(
            pool_size=config.GRPC_CLIENT_POOL_SIZE,
            server_address=config.GRPC_SERVER_ADDRESS,
        )
    return await _client_pool.get_client()
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
async def wait_until_grpc_ready(
    timeout_seconds: float = GRPC_INIT_TIMEOUT,
    interval_seconds: float = GRPC_INIT_PROBE_INTERVAL,
):
    """Wait for the gRPC server to be ready with retries during startup.

    Args:
        timeout_seconds: Maximum time to wait for the server to be ready.
        interval_seconds: Time to wait between health check attempts.
    Raises:
        RuntimeError: If the server is not ready within the timeout period.
    """
    client = await get_shared_client()
    # Always probe at least once: int(timeout / interval) truncates to 0 when
    # interval_seconds > timeout_seconds, which would otherwise skip the loop
    # entirely and report "ready" without ever contacting the server.
    max_attempts = max(1, int(timeout_seconds / interval_seconds))

    await logger.ainfo(
        "Waiting for gRPC server to be ready",
        timeout_seconds=timeout_seconds,
        interval_seconds=interval_seconds,
        max_attempts=max_attempts,
    )
    start_time = time.time()
    for attempt in range(max_attempts):
        try:
            await client.healthcheck()
        except Exception as exc:
            if attempt >= max_attempts - 1:
                raise RuntimeError(
                    f"gRPC server not ready after {timeout_seconds}s (reached max attempts: {max_attempts})"
                ) from exc
            await logger.adebug(
                "Waiting for gRPC server to be ready",
                attempt=attempt + 1,
                max_attempts=max_attempts,
            )
            await asyncio.sleep(interval_seconds)
        else:
            # Success path kept out of the `try` so a logging failure is not
            # mistaken for an unhealthy server.
            await logger.ainfo(
                "gRPC server is ready",
                attempt=attempt + 1,
                elapsed_seconds=round(time.time() - start_time, 3),
            )
            return
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
async def close_shared_client():
    """Close the shared gRPC client pool."""
    global _client_pool

    if threading.current_thread() is threading.main_thread():
        # Main thread owns the process-wide pool.
        if _client_pool is not None:
            await _client_pool.close()
            _client_pool = None
        return

    # Worker thread: tear down its thread-local pool, if one was ever made.
    thread_pool = getattr(_thread_local, "grpc_pool", None)
    if thread_pool is not None:
        await thread_pool.close()
        delattr(_thread_local, "grpc_pool")
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""Conversion utils for the RunnableConfig."""
|
|
2
|
+
|
|
3
|
+
# THIS IS DUPLICATED
|
|
4
|
+
# TODO: WFH - Deduplicate with the executor logic by moving into a separate package
|
|
5
|
+
# Sequencing in the next PR.
|
|
6
|
+
from typing import Any, cast
|
|
7
|
+
|
|
8
|
+
import orjson
|
|
9
|
+
from langchain_core.runnables.config import RunnableConfig
|
|
10
|
+
|
|
11
|
+
from langgraph_api.grpc.generated import (
|
|
12
|
+
engine_common_pb2,
|
|
13
|
+
enum_durability_pb2,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Reserved configurable keys used by the proto conversion below. The
# "__pregel_*" names are injected by the LangGraph runtime (presumably by the
# Pregel executor — confirm against langgraph core); the bare names are the
# thread/checkpoint addressing keys that map to dedicated proto fields.
CONFIG_KEY_SEND = "__pregel_send"
CONFIG_KEY_READ = "__pregel_read"
CONFIG_KEY_RESUMING = "__pregel_resuming"
CONFIG_KEY_TASK_ID = "__pregel_task_id"
CONFIG_KEY_THREAD_ID = "thread_id"
CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map"
CONFIG_KEY_CHECKPOINT_ID = "checkpoint_id"
CONFIG_KEY_CHECKPOINT_NS = "checkpoint_ns"
CONFIG_KEY_SCRATCHPAD = "__pregel_scratchpad"
CONFIG_KEY_DURABILITY = "__pregel_durability"
CONFIG_KEY_GRAPH_ID = "graph_id"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _durability_to_proto(
    durability: str,
) -> enum_durability_pb2.Durability:
    """Map a durability string ("async"/"sync"/"exit") to its proto enum value."""
    if durability == "async":
        return enum_durability_pb2.Durability.ASYNC
    if durability == "sync":
        return enum_durability_pb2.Durability.SYNC
    if durability == "exit":
        return enum_durability_pb2.Durability.EXIT
    raise ValueError(f"invalid durability: {durability}")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _durability_from_proto(
    durability: enum_durability_pb2.Durability,
) -> str:
    """Map a proto durability enum back to its string form ("async"/"sync"/"exit")."""
    if durability == enum_durability_pb2.Durability.ASYNC:
        return "async"
    if durability == enum_durability_pb2.Durability.SYNC:
        return "sync"
    if durability == enum_durability_pb2.Durability.EXIT:
        return "exit"
    raise ValueError(f"invalid durability: {durability}")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def config_to_proto(
    config: RunnableConfig,
) -> engine_common_pb2.EngineRunnableConfig | None:
    """Convert a RunnableConfig mapping into an EngineRunnableConfig proto.

    Returns None for an empty/falsy config. Keys with dedicated proto
    fields are popped and mapped directly; whatever remains is preserved
    as JSON under ``extra_json``.
    """
    if not config:
        return None

    remaining = dict(config)
    proto = engine_common_pb2.EngineRunnableConfig()

    # Metadata: "run_attempt" and "run_id" get dedicated fields; every other
    # entry is stored JSON-encoded in metadata_json.
    metadata = remaining.pop("metadata", None) or {}
    for key, value in metadata.items():
        if key == "run_attempt":
            proto.run_attempt = value
        elif key == "run_id":
            proto.server_run_id = str(value)
        else:
            proto.metadata_json[key] = orjson.dumps(value)

    run_name = remaining.pop("run_name", None)
    if run_name:
        proto.run_name = run_name

    run_id = remaining.pop("run_id", None)
    if run_id:
        proto.run_id = str(run_id)

    max_concurrency = remaining.pop("max_concurrency", None)
    if max_concurrency and isinstance(max_concurrency, int):
        proto.max_concurrency = max_concurrency

    recursion_limit = remaining.pop("recursion_limit", None)
    if recursion_limit and isinstance(recursion_limit, int):
        proto.recursion_limit = recursion_limit

    # Collections are filled in after construction.
    tags = remaining.pop("tags", None)
    if tags and isinstance(tags, list):
        proto.tags.extend(tags)

    configurable = remaining.pop("configurable", None)
    if configurable and isinstance(configurable, dict):
        _inject_configurable_into_proto(configurable, proto)

    # Anything not mapped above survives round-trips via extra_json.
    if remaining:
        proto.extra_json.update({k: orjson.dumps(v) for k, v in remaining.items()})

    return proto
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Configurable keys that _inject_configurable_into_proto drops outright:
# they hold in-process runtime callables/state and are never serialized
# into extra_configurable_json.
RESTRICTED_RESERVED_CONFIGURABLE_KEYS = {
    CONFIG_KEY_SEND,
    CONFIG_KEY_READ,
    CONFIG_KEY_SCRATCHPAD,
}
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _inject_configurable_into_proto(
    configurable: dict[str, Any], proto: engine_common_pb2.EngineRunnableConfig
) -> None:
    """Copy *configurable* entries onto their dedicated proto fields.

    Keys without a dedicated field are JSON-encoded into
    ``extra_configurable_json``; keys in RESTRICTED_RESERVED_CONFIGURABLE_KEYS
    are dropped entirely.
    """
    leftover: dict[str, Any] = {}
    for key, value in configurable.items():
        if key in RESTRICTED_RESERVED_CONFIGURABLE_KEYS:
            # In-process runtime plumbing: never serialized.
            continue
        if key == CONFIG_KEY_RESUMING:
            proto.resuming = bool(value)
        elif key == CONFIG_KEY_TASK_ID:
            proto.task_id = str(value)
        elif key == CONFIG_KEY_THREAD_ID:
            proto.thread_id = str(value)
        elif key == CONFIG_KEY_CHECKPOINT_MAP:
            proto.checkpoint_map.update(cast("dict[str, str]", value))
        elif key == CONFIG_KEY_CHECKPOINT_ID:
            proto.checkpoint_id = str(value)
        elif key == CONFIG_KEY_CHECKPOINT_NS:
            proto.checkpoint_ns = str(value)
        elif key == CONFIG_KEY_DURABILITY and value:
            proto.durability = _durability_to_proto(value)
        else:
            # Includes a falsy durability value, matching the original
            # fall-through behavior.
            leftover[key] = value
    if leftover:
        proto.extra_configurable_json.update(
            {k: orjson.dumps(v) for k, v in leftover.items()}
        )
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def context_to_json_bytes(context: dict[str, Any] | Any) -> bytes | None:
    """Convert context to JSON bytes for proto serialization.

    Accepts a mapping, a dataclass-like object (read via ``__dict__``),
    or anything else (serialized as an empty object). None passes through
    as None.
    """
    if context is None:
        return None

    if hasattr(context, "items"):
        # Mapping-like: copy into a plain dict.
        payload = dict(context)
    elif hasattr(context, "__dict__"):
        # Dataclass / plain object: use its attribute dict.
        payload = context.__dict__
    else:
        # No usable representation: serialize an empty object.
        payload = {}

    return orjson.dumps(payload)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def config_from_proto(
    config_proto: engine_common_pb2.EngineRunnableConfig | None,
) -> RunnableConfig:
    """Rebuild a RunnableConfig from its proto representation.

    A missing/empty proto yields a config with empty tags/metadata/
    configurable. Scalar fields are only set when present on the proto.
    """
    if not config_proto:
        return RunnableConfig(tags=[], metadata={}, configurable={})

    configurable = _configurable_from_proto(config_proto)

    # metadata_json values are JSON-encoded; the two dedicated fields are
    # layered on top afterwards.
    metadata = {k: orjson.loads(v) for k, v in config_proto.metadata_json.items()}
    if config_proto.HasField("run_attempt"):
        metadata["run_attempt"] = config_proto.run_attempt
    if config_proto.HasField("server_run_id"):
        metadata["run_id"] = config_proto.server_run_id

    result = RunnableConfig()
    for k, v in config_proto.extra_json.items():
        result[k] = orjson.loads(v)  # type: ignore[invalid-key]
    if config_proto.tags:
        result["tags"] = list(config_proto.tags)
    if metadata:
        result["metadata"] = metadata
    if configurable:
        result["configurable"] = configurable
    if config_proto.HasField("run_name"):
        result["run_name"] = config_proto.run_name

    if config_proto.HasField("max_concurrency"):
        result["max_concurrency"] = config_proto.max_concurrency

    if config_proto.HasField("recursion_limit"):
        result["recursion_limit"] = config_proto.recursion_limit

    return result
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _configurable_from_proto(
    config_proto: engine_common_pb2.EngineRunnableConfig,
) -> dict[str, Any]:
    """Extract the 'configurable' dict from an EngineRunnableConfig proto.

    Optional scalar fields are copied only when present; durability is
    translated back to its string form; remaining entries come from the
    JSON-encoded extra_configurable_json map.
    """
    out: dict[str, Any] = {}

    # Optional scalars that map 1:1 onto configurable keys (insertion order
    # mirrors the original field-by-field checks).
    for field, key in (
        ("resuming", CONFIG_KEY_RESUMING),
        ("task_id", CONFIG_KEY_TASK_ID),
        ("thread_id", CONFIG_KEY_THREAD_ID),
        ("checkpoint_id", CONFIG_KEY_CHECKPOINT_ID),
        ("checkpoint_ns", CONFIG_KEY_CHECKPOINT_NS),
    ):
        if config_proto.HasField(field):
            out[key] = getattr(config_proto, field)

    if config_proto.HasField("durability"):
        durability = _durability_from_proto(config_proto.durability)
        if durability:
            out[CONFIG_KEY_DURABILITY] = durability

    for field, key in (
        ("graph_id", CONFIG_KEY_GRAPH_ID),
        ("run_id", "run_id"),
    ):
        if config_proto.HasField(field):
            out[key] = getattr(config_proto, field)

    if len(config_proto.checkpoint_map) > 0:
        out[CONFIG_KEY_CHECKPOINT_MAP] = dict(config_proto.checkpoint_map)

    for k, v in config_proto.extra_configurable_json.items():
        out[k] = orjson.loads(v)

    return out
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Generated protobuf files
|
|
2
|
+
# Import enum files first to avoid circular imports
|
|
3
|
+
from . import enum_cancel_run_action_pb2
|
|
4
|
+
from . import enum_control_signal_pb2
|
|
5
|
+
from . import enum_durability_pb2
|
|
6
|
+
from . import enum_multitask_strategy_pb2
|
|
7
|
+
from . import enum_run_status_pb2
|
|
8
|
+
from . import enum_stream_mode_pb2
|
|
9
|
+
from . import enum_thread_status_pb2
|
|
10
|
+
from . import enum_thread_stream_mode_pb2
|
|
11
|
+
from . import core_api_pb2
|
|
12
|
+
from . import core_api_pb2_grpc
|
|
13
|
+
from . import checkpointer_pb2
|
|
14
|
+
from . import checkpointer_pb2_grpc
|
|
15
|
+
|
|
16
|
+
# Public API of the generated package: the two service/message module pairs
# plus every standalone enum module imported above.
__all__ = [
    "core_api_pb2",
    "core_api_pb2_grpc",
    "checkpointer_pb2",
    "checkpointer_pb2_grpc",
    "enum_cancel_run_action_pb2",
    "enum_control_signal_pb2",
    "enum_durability_pb2",
    "enum_multitask_strategy_pb2",
    "enum_run_status_pb2",
    "enum_stream_mode_pb2",
    "enum_thread_status_pb2",
    "enum_thread_stream_mode_pb2",
]
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
4
|
+
# source: checkpointer.proto
|
|
5
|
+
# Protobuf Python Version: 6.31.1
|
|
6
|
+
"""Generated protocol buffer code."""
|
|
7
|
+
from google.protobuf import descriptor as _descriptor
|
|
8
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
10
|
+
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
|
14
|
+
6,
|
|
15
|
+
31,
|
|
16
|
+
1,
|
|
17
|
+
'',
|
|
18
|
+
'checkpointer.proto'
|
|
19
|
+
)
|
|
20
|
+
# @@protoc_insertion_point(imports)
|
|
21
|
+
|
|
22
|
+
_sym_db = _symbol_database.Default()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2
|
|
26
|
+
from . import engine_common_pb2 as engine__common__pb2
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x63heckpointer.proto\x12\x0c\x63heckpointer\x1a\x1bgoogle/protobuf/empty.proto\x1a\x13\x65ngine-common.proto\"\x97\x02\n\nPutRequest\x12\x32\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\".engineCommon.EngineRunnableConfig\x12,\n\ncheckpoint\x18\x02 \x01(\x0b\x32\x18.engineCommon.Checkpoint\x12\x32\n\x08metadata\x18\x03 \x01(\x0b\x32 .engineCommon.CheckpointMetadata\x12?\n\x0cnew_versions\x18\x04 \x03(\x0b\x32).checkpointer.PutRequest.NewVersionsEntry\x1a\x32\n\x10NewVersionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x8f\x01\n\x10PutWritesRequest\x12\x32\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\".engineCommon.EngineRunnableConfig\x12#\n\x06writes\x18\x02 \x03(\x0b\x32\x13.engineCommon.Write\x12\x0f\n\x07task_id\x18\x03 \x01(\t\x12\x11\n\ttask_path\x18\x04 \x01(\t\"\xa8\x01\n\x0bListRequest\x12\x32\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\".engineCommon.EngineRunnableConfig\x12\x13\n\x0b\x66ilter_json\x18\x02 \x01(\x0c\x12\x32\n\x06\x62\x65\x66ore\x18\x03 \x01(\x0b\x32\".engineCommon.EngineRunnableConfig\x12\x12\n\x05limit\x18\x04 \x01(\x03H\x00\x88\x01\x01\x42\x08\n\x06_limit\"(\n\x13\x44\x65leteThreadRequest\x12\x11\n\tthread_id\x18\x01 \x01(\t\"\xa1\x01\n\x0cPruneRequest\x12\x12\n\nthread_ids\x18\x01 \x03(\t\x12:\n\x08strategy\x18\x02 \x01(\x0e\x32(.checkpointer.PruneRequest.PruneStrategy\"A\n\rPruneStrategy\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x0f\n\x0bKEEP_LATEST\x10\x01\x12\x0e\n\nDELETE_ALL\x10\x02\"E\n\x0fGetTupleRequest\x12\x32\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\".engineCommon.EngineRunnableConfig\"F\n\x0bPutResponse\x12\x37\n\x0bnext_config\x18\x01 \x01(\x0b\x32\".engineCommon.EngineRunnableConfig\"H\n\x0cListResponse\x12\x38\n\x11\x63heckpoint_tuples\x18\x01 \x03(\x0b\x32\x1d.engineCommon.CheckpointTuple\"e\n\x10GetTupleResponse\x12<\n\x10\x63heckpoint_tuple\x18\x01 
\x01(\x0b\x32\x1d.engineCommon.CheckpointTupleH\x00\x88\x01\x01\x42\x13\n\x11_checkpoint_tuple2\xa1\x03\n\x0c\x43heckpointer\x12:\n\x03Put\x12\x18.checkpointer.PutRequest\x1a\x19.checkpointer.PutResponse\x12\x43\n\tPutWrites\x12\x1e.checkpointer.PutWritesRequest\x1a\x16.google.protobuf.Empty\x12=\n\x04List\x12\x19.checkpointer.ListRequest\x1a\x1a.checkpointer.ListResponse\x12I\n\x08GetTuple\x12\x1d.checkpointer.GetTupleRequest\x1a\x1e.checkpointer.GetTupleResponse\x12I\n\x0c\x44\x65leteThread\x12!.checkpointer.DeleteThreadRequest\x1a\x16.google.protobuf.Empty\x12;\n\x05Prune\x12\x1a.checkpointer.PruneRequest\x1a\x16.google.protobuf.EmptyB?Z=github.com/langchain-ai/langgraph-api/core/internal/engine/pbb\x06proto3')
|
|
30
|
+
|
|
31
|
+
_globals = globals()
|
|
32
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
33
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'checkpointer_pb2', _globals)
|
|
34
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
|
35
|
+
_globals['DESCRIPTOR']._loaded_options = None
|
|
36
|
+
_globals['DESCRIPTOR']._serialized_options = b'Z=github.com/langchain-ai/langgraph-api/core/internal/engine/pb'
|
|
37
|
+
_globals['_PUTREQUEST_NEWVERSIONSENTRY']._loaded_options = None
|
|
38
|
+
_globals['_PUTREQUEST_NEWVERSIONSENTRY']._serialized_options = b'8\001'
|
|
39
|
+
_globals['_PUTREQUEST']._serialized_start=87
|
|
40
|
+
_globals['_PUTREQUEST']._serialized_end=366
|
|
41
|
+
_globals['_PUTREQUEST_NEWVERSIONSENTRY']._serialized_start=316
|
|
42
|
+
_globals['_PUTREQUEST_NEWVERSIONSENTRY']._serialized_end=366
|
|
43
|
+
_globals['_PUTWRITESREQUEST']._serialized_start=369
|
|
44
|
+
_globals['_PUTWRITESREQUEST']._serialized_end=512
|
|
45
|
+
_globals['_LISTREQUEST']._serialized_start=515
|
|
46
|
+
_globals['_LISTREQUEST']._serialized_end=683
|
|
47
|
+
_globals['_DELETETHREADREQUEST']._serialized_start=685
|
|
48
|
+
_globals['_DELETETHREADREQUEST']._serialized_end=725
|
|
49
|
+
_globals['_PRUNEREQUEST']._serialized_start=728
|
|
50
|
+
_globals['_PRUNEREQUEST']._serialized_end=889
|
|
51
|
+
_globals['_PRUNEREQUEST_PRUNESTRATEGY']._serialized_start=824
|
|
52
|
+
_globals['_PRUNEREQUEST_PRUNESTRATEGY']._serialized_end=889
|
|
53
|
+
_globals['_GETTUPLEREQUEST']._serialized_start=891
|
|
54
|
+
_globals['_GETTUPLEREQUEST']._serialized_end=960
|
|
55
|
+
_globals['_PUTRESPONSE']._serialized_start=962
|
|
56
|
+
_globals['_PUTRESPONSE']._serialized_end=1032
|
|
57
|
+
_globals['_LISTRESPONSE']._serialized_start=1034
|
|
58
|
+
_globals['_LISTRESPONSE']._serialized_end=1106
|
|
59
|
+
_globals['_GETTUPLERESPONSE']._serialized_start=1108
|
|
60
|
+
_globals['_GETTUPLERESPONSE']._serialized_end=1209
|
|
61
|
+
_globals['_CHECKPOINTER']._serialized_start=1212
|
|
62
|
+
_globals['_CHECKPOINTER']._serialized_end=1629
|
|
63
|
+
# @@protoc_insertion_point(module_scope)
|