langgraph-api 0.4.48__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic.
- langgraph_api/__init__.py +1 -1
- langgraph_api/api/assistants.py +65 -61
- langgraph_api/api/meta.py +6 -0
- langgraph_api/api/threads.py +1 -1
- langgraph_api/auth/custom.py +29 -24
- langgraph_api/config.py +56 -1
- langgraph_api/graph.py +1 -1
- langgraph_api/{grpc_ops → grpc}/client.py +91 -0
- langgraph_api/grpc/config_conversion.py +225 -0
- langgraph_api/grpc/generated/core_api_pb2.py +275 -0
- langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2.pyi +20 -31
- langgraph_api/{grpc_ops → grpc}/generated/core_api_pb2_grpc.py +2 -2
- langgraph_api/grpc/generated/engine_common_pb2.py +190 -0
- langgraph_api/grpc/generated/engine_common_pb2.pyi +634 -0
- langgraph_api/grpc/generated/engine_common_pb2_grpc.py +24 -0
- langgraph_api/{grpc_ops → grpc}/ops.py +75 -217
- langgraph_api/js/package.json +5 -5
- langgraph_api/js/src/graph.mts +20 -0
- langgraph_api/js/yarn.lock +137 -187
- langgraph_api/queue_entrypoint.py +2 -2
- langgraph_api/route.py +14 -4
- langgraph_api/schema.py +2 -2
- langgraph_api/self_hosted_metrics.py +48 -2
- langgraph_api/serde.py +58 -14
- langgraph_api/worker.py +1 -1
- {langgraph_api-0.4.48.dist-info → langgraph_api-0.5.6.dist-info}/METADATA +5 -5
- {langgraph_api-0.4.48.dist-info → langgraph_api-0.5.6.dist-info}/RECORD +32 -28
- langgraph_api/grpc_ops/generated/core_api_pb2.py +0 -276
- /langgraph_api/{grpc_ops → grpc}/__init__.py +0 -0
- /langgraph_api/{grpc_ops → grpc}/generated/__init__.py +0 -0
- {langgraph_api-0.4.48.dist-info → langgraph_api-0.5.6.dist-info}/WHEEL +0 -0
- {langgraph_api-0.4.48.dist-info → langgraph_api-0.5.6.dist-info}/entry_points.txt +0 -0
- {langgraph_api-0.4.48.dist-info → langgraph_api-0.5.6.dist-info}/licenses/LICENSE +0 -0
langgraph_api/__init__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.48"
+__version__ = "0.5.6"
langgraph_api/api/assistants.py
CHANGED
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Any
 from uuid import uuid4
 
@@ -15,7 +16,7 @@ from starlette.routing import BaseRoute
 from langgraph_api import store as api_store
 from langgraph_api.feature_flags import FF_USE_CORE_API, USE_RUNTIME_CONTEXT_API
 from langgraph_api.graph import get_assistant_id, get_graph
-from langgraph_api.grpc_ops.ops import Assistants as GrpcAssistants
+from langgraph_api.grpc.ops import Assistants as GrpcAssistants
 from langgraph_api.js.base import BaseRemotePregel
 from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
 from langgraph_api.schema import ASSISTANT_FIELDS
@@ -37,7 +38,7 @@ from langgraph_api.validation import (
     ConfigValidator,
 )
 from langgraph_runtime.checkpoint import Checkpointer
-from langgraph_runtime.database import connect
+from langgraph_runtime.database import connect as base_connect
 from langgraph_runtime.ops import Assistants
 from langgraph_runtime.retry import retry_db
 
@@ -45,6 +46,8 @@ logger = structlog.stdlib.get_logger(__name__)
 
 CrudAssistants = GrpcAssistants if FF_USE_CORE_API else Assistants
 
+connect = partial(base_connect, supports_core_api=FF_USE_CORE_API)
+
 EXCLUDED_CONFIG_SCHEMA = (
     "__pregel_checkpointer",
     "__pregel_store",
@@ -255,7 +258,7 @@ async def get_assistant_graph(
     assistant_id = get_assistant_id(str(request.path_params["assistant_id"]))
     validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
     async with connect() as conn:
-        assistant_ = await
+        assistant_ = await CrudAssistants.get(conn, assistant_id)
         assistant = await fetchone(assistant_)
         config = json_loads(assistant["config"])
         configurable = config.setdefault("configurable", {})
@@ -312,43 +315,44 @@ async def get_assistant_subgraphs(
     assistant_id = request.path_params["assistant_id"]
     validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
     async with connect() as conn:
-        assistant_ = await
+        assistant_ = await CrudAssistants.get(conn, assistant_id)
         assistant = await fetchone(assistant_)
-        … (8 removed lines not captured in this view)
-        )
-        … (5 removed lines not captured in this view)
+
+    config = json_loads(assistant["config"])
+    configurable = config.setdefault("configurable", {})
+    configurable.update(get_configurable_headers(request.headers))
+    async with get_graph(
+        assistant["graph_id"],
+        config,
+        checkpointer=Checkpointer(),
+        store=(await api_store.get_store()),
+    ) as graph:
+        namespace = request.path_params.get("namespace")
+
+        if isinstance(graph, BaseRemotePregel):
+            return ApiResponse(
+                await graph.fetch_subgraphs(
+                    namespace=namespace,
+                    recurse=request.query_params.get("recurse", "False")
+                    in ("true", "True"),
+                )
+            )
+
+        try:
+            return ApiResponse(
+                {
+                    ns: _graph_schemas(subgraph)
+                    async for ns, subgraph in graph.aget_subgraphs(
                         namespace=namespace,
                         recurse=request.query_params.get("recurse", "False")
                         in ("true", "True"),
                     )
-        … (6 removed lines not captured in this view)
-            async for ns, subgraph in graph.aget_subgraphs(
-                namespace=namespace,
-                recurse=request.query_params.get("recurse", "False")
-                in ("true", "True"),
-            )
-        }
-    )
-    except NotImplementedError:
-        raise HTTPException(
-            422, detail="The graph does not support visualization"
-        ) from None
+                }
+            )
+        except NotImplementedError:
+            raise HTTPException(
+                422, detail="The graph does not support visualization"
+            ) from None
 
 
 @retry_db
@@ -359,40 +363,40 @@ async def get_assistant_schemas(
     assistant_id = request.path_params["assistant_id"]
     validate_uuid(assistant_id, "Invalid assistant ID: must be a UUID")
     async with connect() as conn:
-        assistant_ = await
-        # TODO Implementa cache so we can de-dent and release this connection.
+        assistant_ = await CrudAssistants.get(conn, assistant_id)
         assistant = await fetchone(assistant_)
-        config = json_loads(assistant["config"])
-        configurable = config.setdefault("configurable", {})
-        configurable.update(get_configurable_headers(request.headers))
-        async with get_graph(
-            assistant["graph_id"],
-            config,
-            checkpointer=Checkpointer(),
-            store=(await api_store.get_store()),
-        ) as graph:
-            if isinstance(graph, BaseRemotePregel):
-                schemas = await graph.fetch_state_schema()
-                return ApiResponse(
-                    {
-                        "graph_id": assistant["graph_id"],
-                        "input_schema": schemas.get("input"),
-                        "output_schema": schemas.get("output"),
-                        "state_schema": schemas.get("state"),
-                        "config_schema": schemas.get("config"),
-                        "context_schema": schemas.get("context"),
-                    }
-                )
-
-            schemas = _graph_schemas(graph)
 
+    config = json_loads(assistant["config"])
+    configurable = config.setdefault("configurable", {})
+    configurable.update(get_configurable_headers(request.headers))
+    async with get_graph(
+        assistant["graph_id"],
+        config,
+        checkpointer=Checkpointer(),
+        store=(await api_store.get_store()),
+    ) as graph:
+        if isinstance(graph, BaseRemotePregel):
+            schemas = await graph.fetch_state_schema()
             return ApiResponse(
                 {
                     "graph_id": assistant["graph_id"],
-                    … (1 removed line not captured in this view)
+                    "input_schema": schemas.get("input"),
+                    "output_schema": schemas.get("output"),
+                    "state_schema": schemas.get("state"),
+                    "config_schema": schemas.get("config"),
+                    "context_schema": schemas.get("context"),
                 }
             )
 
+        schemas = _graph_schemas(graph)
+
+        return ApiResponse(
+            {
+                "graph_id": assistant["graph_id"],
+                **schemas,
+            }
+        )
+
 
 @retry_db
 async def patch_assistant(
langgraph_api/api/meta.py
CHANGED
@@ -86,6 +86,12 @@ async def meta_metrics(request: ApiRequest):
                 "# HELP lg_api_num_running_runs The number of runs currently running.",
                 "# TYPE lg_api_num_running_runs gauge",
                 f'lg_api_num_running_runs{{project_id="{metadata.PROJECT_ID}", revision_id="{metadata.HOST_REVISION_ID}"}} {queue_stats["n_running"]}',
+                "# HELP lg_api_pending_runs_wait_time_max The maximum time a run has been pending, in seconds.",
+                "# TYPE lg_api_pending_runs_wait_time_max gauge",
+                f'lg_api_pending_runs_wait_time_max{{project_id="{metadata.PROJECT_ID}", revision_id="{metadata.HOST_REVISION_ID}"}} {queue_stats.get("pending_runs_wait_time_max_secs") or 0}',
+                "# HELP lg_api_pending_runs_wait_time_med The median pending wait time across runs, in seconds.",
+                "# TYPE lg_api_pending_runs_wait_time_med gauge",
+                f'lg_api_pending_runs_wait_time_med{{project_id="{metadata.PROJECT_ID}", revision_id="{metadata.HOST_REVISION_ID}"}} {queue_stats.get("pending_runs_wait_time_med_secs") or 0}',
             ]
         )
     except Exception as e:
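
For reference, Prometheus exposition format pairs each sample with HELP/TYPE comment lines, so the two new gauges surface in the metrics output roughly as below; the project ID, revision ID, and sample values here are invented for illustration:

# HELP lg_api_pending_runs_wait_time_max The maximum time a run has been pending, in seconds.
# TYPE lg_api_pending_runs_wait_time_max gauge
lg_api_pending_runs_wait_time_max{project_id="proj-abc", revision_id="rev-123"} 12.5
# HELP lg_api_pending_runs_wait_time_med The median pending wait time across runs, in seconds.
# TYPE lg_api_pending_runs_wait_time_med gauge
lg_api_pending_runs_wait_time_med{project_id="proj-abc", revision_id="rev-123"} 0.8

Note the `or 0` guards in the f-strings: when the queue stats omit the wait-time keys, the gauges report 0 rather than rendering None into the exposition text.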
langgraph_api/api/threads.py
CHANGED
@@ -6,7 +6,7 @@ from starlette.responses import Response
 from starlette.routing import BaseRoute
 
 from langgraph_api.feature_flags import FF_USE_CORE_API
-from langgraph_api.grpc_ops.ops import Threads as GrpcThreads
+from langgraph_api.grpc.ops import Threads as GrpcThreads
 from langgraph_api.route import ApiRequest, ApiResponse, ApiRoute
 from langgraph_api.schema import THREAD_FIELDS, ThreadStreamMode
 from langgraph_api.sse import EventSourceResponse
langgraph_api/auth/custom.py
CHANGED
@@ -355,34 +355,39 @@ def _solve_fastapi_dependencies(
     }
 
     async def decorator(scope: dict, request: Request):
-        async with AsyncExitStack() as
-        … (7 removed lines not captured in this view)
+        async with AsyncExitStack() as request_stack:
+            scope["fastapi_inner_astack"] = request_stack
+            async with AsyncExitStack() as stack:
+                scope["fastapi_function_astack"] = stack
+                all_solved = await asyncio.gather(
+                    *(
+                        solve_dependencies(
+                            request=request,
+                            dependant=dependent,
+                            async_exit_stack=stack,
+                            embed_body_fields=False,
+                        )
+                        for dependent in dependents.values()
                     )
-                for dependent in dependents.values()
                 )
-        … (6 removed lines not captured in this view)
+                all_injected = await asyncio.gather(
+                    *(
+                        _run_async(dependent.call, solved.values, is_async)
+                        for dependent, solved in zip(
+                            dependents.values(), all_solved, strict=False
+                        )
                     )
                 )
-        … (9 removed lines not captured in this view)
+                kwargs = {
+                    name: value
+                    for name, value in zip(
+                        dependents.keys(), all_injected, strict=False
+                    )
+                }
+                other_params = _extract_arguments_from_scope(
+                    scope, _param_names, request=request
+                )
+                return await fn(**(kwargs | other_params))
 
     return decorator
 
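
The core move of this rewrite is to resolve every FastAPI dependency concurrently with asyncio.gather against a shared exit stack (now split into a request-scoped stack wrapping a function-scoped one), then zip the results back to parameter names. A standalone sketch of that shape, with a stand-in solve() in place of FastAPI's solve_dependencies:

import asyncio
from contextlib import AsyncExitStack


async def solve(name, stack):
    # Stand-in for FastAPI's solve_dependencies: resolve one dependency and
    # register its cleanup on the shared, function-scoped exit stack.
    stack.push_async_callback(asyncio.sleep, 0)
    return f"value-for-{name}"


async def endpoint(**kwargs):
    return kwargs


async def main():
    dependents = {"db": None, "user": None}  # name -> dependant, schematically
    async with AsyncExitStack() as request_stack:  # request-scoped cleanups
        async with AsyncExitStack() as stack:  # function-scoped cleanups
            all_solved = await asyncio.gather(
                *(solve(name, stack) for name in dependents)
            )
            kwargs = dict(zip(dependents.keys(), all_solved, strict=False))
            print(await endpoint(**kwargs))  # {'db': 'value-for-db', ...}


asyncio.run(main())

Solving concurrently matters because each dependency may await I/O; gathering keeps total latency near the slowest dependency rather than the sum of all of them.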
langgraph_api/config.py
CHANGED
@@ -128,6 +128,45 @@ class StoreConfig(TypedDict, total=False):
     ttl: TTLConfig
 
 
+class SerdeConfig(TypedDict, total=False):
+    """Configuration for the built-in serde, which handles checkpointing of state.
+
+    If omitted, no serde is set up (the object store will still be present, however)."""
+
+    allowed_json_modules: list[list[str]] | Literal[True] | None
+    """Optional. List of allowed python modules to de-serialize custom objects from.
+
+    If provided, only the specified modules will be allowed to be deserialized.
+    If omitted, no modules are allowed, and the object returned will simply be a json object OR
+    a deserialized langchain object.
+
+    Example:
+        {...
+            "serde": {
+                "allowed_json_modules": [
+                    ["my_agent", "my_file", "SomeType"],
+                ]
+            }
+        }
+
+    If you set this to True, any module will be allowed to be deserialized.
+
+    Example:
+        {...
+            "serde": {
+                "allowed_json_modules": true
+            }
+        }
+
+    """
+    pickle_fallback: bool
+    """Optional. Whether to allow pickling as a fallback for deserialization.
+
+    If True, pickling will be allowed as a fallback for deserialization.
+    If False, pickling will not be allowed as a fallback for deserialization.
+    Defaults to True if not configured."""
+
+
 class CheckpointerConfig(TypedDict, total=False):
     """Configuration for the built-in checkpointer, which handles checkpointing of state.
 
@@ -140,6 +179,8 @@ class CheckpointerConfig(TypedDict, total=False):
     If provided, the checkpointer will apply TTL settings according to the configuration.
     If omitted, no TTL behavior is configured.
     """
+    serde: SerdeConfig | None
+    """Optional. Defines the configuration for how checkpoints are serialized."""
 
 
 class SecurityConfig(TypedDict, total=False):
@@ -240,6 +281,9 @@ REDIS_URI = env("REDIS_URI", cast=str)
 REDIS_CLUSTER = env("REDIS_CLUSTER", cast=bool, default=False)
 REDIS_MAX_CONNECTIONS = env("REDIS_MAX_CONNECTIONS", cast=int, default=2000)
 REDIS_CONNECT_TIMEOUT = env("REDIS_CONNECT_TIMEOUT", cast=float, default=10.0)
+REDIS_HEALTH_CHECK_INTERVAL = env(
+    "REDIS_HEALTH_CHECK_INTERVAL", cast=float, default=10.0
+)
 REDIS_KEY_PREFIX = env("REDIS_KEY_PREFIX", cast=str, default="")
 RUN_STATS_CACHE_SECONDS = env("RUN_STATS_CACHE_SECONDS", cast=int, default=60)
 
@@ -250,6 +294,13 @@ ALLOW_PRIVATE_NETWORK = env("ALLOW_PRIVATE_NETWORK", cast=bool, default=False)
 See https://developer.chrome.com/blog/private-network-access-update-2024-03
 """
 
+# gRPC client pool size for persistence server.
+GRPC_CLIENT_POOL_SIZE = env("GRPC_CLIENT_POOL_SIZE", cast=int, default=5)
+
+# Minimum payload size to use the dedicated thread pool for JSON parsing.
+# (Otherwise, the payload is parsed directly in the event loop.)
+JSON_THREAD_POOL_MINIMUM_SIZE_BYTES = 100 * 1024  # 100 KB
+
 HTTP_CONFIG = env("LANGGRAPH_HTTP", cast=_parse_schema(HttpConfig), default=None)
 STORE_CONFIG = env("LANGGRAPH_STORE", cast=_parse_schema(StoreConfig), default=None)
 
@@ -339,6 +390,11 @@ def _parse_thread_ttl(value: str | None) -> ThreadTTLConfig | None:
 CHECKPOINTER_CONFIG = env(
     "LANGGRAPH_CHECKPOINTER", cast=_parse_schema(CheckpointerConfig), default=None
 )
+SERDE: SerdeConfig | None = (
+    CHECKPOINTER_CONFIG["serde"]
+    if CHECKPOINTER_CONFIG and "serde" in CHECKPOINTER_CONFIG
+    else None
+)
 THREAD_TTL: ThreadTTLConfig | None = env(
     "LANGGRAPH_THREAD_TTL", cast=_parse_thread_ttl, default=None
 )
@@ -349,7 +405,6 @@ N_JOBS_PER_WORKER = env("N_JOBS_PER_WORKER", cast=int, default=10)
 BG_JOB_TIMEOUT_SECS = env("BG_JOB_TIMEOUT_SECS", cast=float, default=3600)
 
 FF_CRONS_ENABLED = env("FF_CRONS_ENABLED", cast=bool, default=True)
-FF_RICH_THREADS = env("FF_RICH_THREADS", cast=bool, default=True)
 FF_LOG_DROPPED_EVENTS = env("FF_LOG_DROPPED_EVENTS", cast=bool, default=False)
 FF_LOG_QUERY_AND_PARAMS = env("FF_LOG_QUERY_AND_PARAMS", cast=bool, default=False)
 
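
Putting the new config knobs together: the serde block rides inside the existing LANGGRAPH_CHECKPOINTER JSON, while the pool size and Redis health-check interval are plain environment variables. A hypothetical deployment that pins JSON deserialization to a single type and disables the pickle fallback could be configured like this (all values illustrative):

import json
import os

os.environ["LANGGRAPH_CHECKPOINTER"] = json.dumps(
    {
        "serde": {
            # Only my_agent.my_file.SomeType may be revived from JSON.
            "allowed_json_modules": [["my_agent", "my_file", "SomeType"]],
            # Never fall back to pickle when JSON deserialization fails.
            "pickle_fallback": False,
        }
    }
)
os.environ["GRPC_CLIENT_POOL_SIZE"] = "10"  # channels in the shared gRPC pool
os.environ["REDIS_HEALTH_CHECK_INTERVAL"] = "5.0"  # seconds between health checks

At import time, config.py parses LANGGRAPH_CHECKPOINTER through _parse_schema(CheckpointerConfig) and lifts the serde block into the module-level SERDE constant shown above.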
langgraph_api/graph.py
CHANGED
@@ -51,7 +51,7 @@ async def register_graph(
     description: str | None = None,
 ) -> None:
     """Register a graph."""
-    from langgraph_api.grpc_ops.ops import Assistants as AssistantsGrpc
+    from langgraph_api.grpc.ops import Assistants as AssistantsGrpc
     from langgraph_runtime.database import connect
     from langgraph_runtime.ops import Assistants as AssistantsRuntime
 
langgraph_api/{grpc_ops → grpc}/client.py
CHANGED
@@ -1,5 +1,6 @@
 """gRPC client wrapper for LangGraph persistence services."""
 
+import asyncio
 import os
 
 import structlog
@@ -10,6 +11,10 @@ from .generated.core_api_pb2_grpc import AdminStub, AssistantsStub, ThreadsStub
 logger = structlog.stdlib.get_logger(__name__)
 
 
+# Shared global client pool
+_client_pool: "GrpcClientPool | None" = None
+
+
 class GrpcClient:
     """gRPC client for LangGraph persistence services."""
 
@@ -90,3 +95,89 @@
                 "Client not connected. Use async context manager or call connect() first."
             )
         return self._admin_stub
+
+
+class GrpcClientPool:
+    """Pool of gRPC clients for load distribution."""
+
+    def __init__(self, pool_size: int = 5, server_address: str | None = None):
+        self.pool_size = pool_size
+        self.server_address = server_address
+        self.clients: list[GrpcClient] = []
+        self._current_index = 0
+        self._init_lock = asyncio.Lock()
+        self._initialized = False
+
+    async def _initialize(self):
+        """Initialize the pool of clients."""
+        async with self._init_lock:
+            if self._initialized:
+                return
+
+            await logger.ainfo(
+                "Initializing gRPC client pool",
+                pool_size=self.pool_size,
+                server_address=self.server_address,
+            )
+
+            for _ in range(self.pool_size):
+                client = GrpcClient(server_address=self.server_address)
+                await client.connect()
+                self.clients.append(client)
+
+            self._initialized = True
+            await logger.ainfo(
+                f"gRPC client pool initialized with {self.pool_size} clients"
+            )
+
+    async def get_client(self) -> GrpcClient:
+        """Get next client using round-robin selection.
+
+        Round-robin without strict locking - slight races are acceptable
+        and result in good enough distribution under high load.
+        """
+        if not self._initialized:
+            await self._initialize()
+
+        idx = self._current_index % self.pool_size
+        self._current_index = idx + 1
+        return self.clients[idx]
+
+    async def close(self):
+        """Close all clients in the pool."""
+        if self._initialized:
+            await logger.ainfo(f"Closing gRPC client pool ({self.pool_size} clients)")
+            for client in self.clients:
+                await client.close()
+            self.clients.clear()
+            self._initialized = False
+
+
+async def get_shared_client() -> GrpcClient:
+    """Get a gRPC client from the shared pool.
+
+    Uses a pool of channels for better performance under high concurrency.
+    Each channel is a separate TCP connection that can handle ~100-200
+    concurrent streams effectively.
+
+    Returns:
+        A GrpcClient instance from the pool
+    """
+    global _client_pool
+    if _client_pool is None:
+        from langgraph_api import config
+
+        _client_pool = GrpcClientPool(
+            pool_size=config.GRPC_CLIENT_POOL_SIZE,
+            server_address=os.getenv("GRPC_SERVER_ADDRESS"),
+        )
+
+    return await _client_pool.get_client()
+
+
+async def close_shared_client():
+    """Close the shared gRPC client pool."""
+    global _client_pool
+    if _client_pool is not None:
+        await _client_pool.close()
+        _client_pool = None