langgraph-api 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- langgraph_api/api/__init__.py +2 -1
- langgraph_api/api/assistants.py +4 -4
- langgraph_api/api/store.py +67 -15
- langgraph_api/asyncio.py +5 -0
- langgraph_api/auth/custom.py +20 -5
- langgraph_api/config.py +1 -0
- langgraph_api/graph.py +6 -13
- langgraph_api/js/base.py +9 -0
- langgraph_api/js/build.mts +2 -0
- langgraph_api/js/client.mts +383 -409
- langgraph_api/js/client.new.mts +856 -0
- langgraph_api/js/errors.py +11 -0
- langgraph_api/js/package.json +3 -1
- langgraph_api/js/remote.py +16 -673
- langgraph_api/js/remote_new.py +693 -0
- langgraph_api/js/remote_old.py +665 -0
- langgraph_api/js/schema.py +29 -0
- langgraph_api/js/src/utils/serde.mts +7 -0
- langgraph_api/js/tests/api.test.mts +125 -8
- langgraph_api/js/tests/compose-postgres.yml +2 -1
- langgraph_api/js/tests/graphs/agent.mts +2 -0
- langgraph_api/js/tests/graphs/delay.mts +30 -0
- langgraph_api/js/tests/graphs/langgraph.json +2 -1
- langgraph_api/js/yarn.lock +870 -18
- langgraph_api/models/run.py +1 -0
- langgraph_api/queue.py +129 -31
- langgraph_api/route.py +8 -3
- langgraph_api/schema.py +1 -1
- langgraph_api/stream.py +12 -5
- langgraph_api/utils.py +11 -5
- {langgraph_api-0.0.14.dist-info → langgraph_api-0.0.16.dist-info}/METADATA +3 -3
- {langgraph_api-0.0.14.dist-info → langgraph_api-0.0.16.dist-info}/RECORD +37 -30
- langgraph_storage/ops.py +9 -2
- openapi.json +5 -5
- {langgraph_api-0.0.14.dist-info → langgraph_api-0.0.16.dist-info}/LICENSE +0 -0
- {langgraph_api-0.0.14.dist-info → langgraph_api-0.0.16.dist-info}/WHEEL +0 -0
- {langgraph_api-0.0.14.dist-info → langgraph_api-0.0.16.dist-info}/entry_points.txt +0 -0
langgraph_api/js/remote.py
CHANGED
|
@@ -1,675 +1,18 @@
|
|
|
1
|
-
import
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
from
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
import uvicorn
|
|
11
|
-
from langchain_core.runnables import Runnable
|
|
12
|
-
from langchain_core.runnables.config import RunnableConfig
|
|
13
|
-
from langchain_core.runnables.graph import Edge, Node
|
|
14
|
-
from langchain_core.runnables.graph import Graph as DrawableGraph
|
|
15
|
-
from langchain_core.runnables.schema import (
|
|
16
|
-
CustomStreamEvent,
|
|
17
|
-
StandardStreamEvent,
|
|
18
|
-
StreamEvent,
|
|
19
|
-
)
|
|
20
|
-
from langgraph.checkpoint.serde.base import SerializerProtocol
|
|
21
|
-
from langgraph.pregel.types import PregelTask, StateSnapshot
|
|
22
|
-
from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp
|
|
23
|
-
from langgraph.types import Command, Interrupt
|
|
24
|
-
from pydantic import BaseModel
|
|
25
|
-
from starlette.applications import Starlette
|
|
26
|
-
from starlette.requests import Request
|
|
27
|
-
from starlette.routing import Route
|
|
28
|
-
|
|
29
|
-
from langgraph_api.js.server_sent_events import aconnect_sse
|
|
30
|
-
from langgraph_api.route import ApiResponse
|
|
31
|
-
from langgraph_api.serde import json_dumpb
|
|
32
|
-
from langgraph_api.utils import AsyncConnectionProto
|
|
33
|
-
|
|
34
|
-
logger = structlog.stdlib.get_logger(__name__)
|
|
35
|
-
|
|
36
|
-
GRAPH_SOCKET = "./graph.sock"
|
|
37
|
-
CHECKPOINTER_SOCKET = "./checkpointer.sock"
|
|
38
|
-
STORE_SOCKET = "./store.sock"
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
class NoopModel(BaseModel):
|
|
42
|
-
pass
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
class RemoteException(Exception):
|
|
46
|
-
error: str
|
|
47
|
-
|
|
48
|
-
def __init__(self, error: str, *args: object) -> None:
|
|
49
|
-
super().__init__(*args)
|
|
50
|
-
self.error = error
|
|
51
|
-
|
|
52
|
-
# Used to nudge the serde to encode like BaseException
|
|
53
|
-
# @see /api/langgraph_api/shared/serde.py:default
|
|
54
|
-
def dict(self):
|
|
55
|
-
return {"error": self.error, "message": str(self)}
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
# Shim for the Pregel API. Will connect to GRAPH_SOCKET
|
|
59
|
-
# UNIX socket to communicate with the JS process.
|
|
60
|
-
class RemotePregel(Runnable):
|
|
61
|
-
# TODO: implement name overriding
|
|
62
|
-
name: str = "LangGraph"
|
|
63
|
-
|
|
64
|
-
# TODO: implement graph_id overriding
|
|
65
|
-
graph_id: str
|
|
66
|
-
|
|
67
|
-
_async_client: httpx.AsyncClient
|
|
68
|
-
|
|
69
|
-
@staticmethod
|
|
70
|
-
async def load(graph_id: str):
|
|
71
|
-
model = RemotePregel()
|
|
72
|
-
|
|
73
|
-
model.graph_id = graph_id
|
|
74
|
-
model._async_client = httpx.AsyncClient(
|
|
75
|
-
base_url="http://graph",
|
|
76
|
-
timeout=httpx.Timeout(None),
|
|
77
|
-
limits=httpx.Limits(),
|
|
78
|
-
transport=httpx.AsyncHTTPTransport(uds=GRAPH_SOCKET),
|
|
79
|
-
)
|
|
80
|
-
|
|
81
|
-
return model
|
|
82
|
-
|
|
83
|
-
async def astream_events(
|
|
84
|
-
self,
|
|
85
|
-
input: Any,
|
|
86
|
-
config: RunnableConfig | None = None,
|
|
87
|
-
*,
|
|
88
|
-
version: Literal["v1", "v2"],
|
|
89
|
-
**kwargs: Any,
|
|
90
|
-
) -> AsyncIterator[StreamEvent]:
|
|
91
|
-
if version != "v2":
|
|
92
|
-
raise ValueError("Only v2 of astream_events is supported")
|
|
93
|
-
|
|
94
|
-
data = {
|
|
95
|
-
"command" if isinstance(input, Command) else "input": input,
|
|
96
|
-
"config": config,
|
|
97
|
-
**kwargs,
|
|
98
|
-
}
|
|
99
|
-
|
|
100
|
-
async with aconnect_sse(
|
|
101
|
-
self._async_client,
|
|
102
|
-
"POST",
|
|
103
|
-
f"/{self.graph_id}/streamEvents",
|
|
104
|
-
headers={"Content-Type": "application/json"},
|
|
105
|
-
data=orjson.dumps(data),
|
|
106
|
-
) as event_source:
|
|
107
|
-
async for sse in event_source.aiter_sse():
|
|
108
|
-
event = orjson.loads(sse["data"])
|
|
109
|
-
if sse["event"] == "error":
|
|
110
|
-
raise RemoteException(event["error"], event["message"])
|
|
111
|
-
elif event["event"] == "on_custom_event":
|
|
112
|
-
yield CustomStreamEvent(**event)
|
|
113
|
-
else:
|
|
114
|
-
yield StandardStreamEvent(**event)
|
|
115
|
-
|
|
116
|
-
async def fetch_state_schema(self):
|
|
117
|
-
schema = await self._async_client.post(f"/{self.graph_id}/getSchema")
|
|
118
|
-
return orjson.loads(schema.content)
|
|
119
|
-
|
|
120
|
-
async def fetch_graph(
|
|
121
|
-
self,
|
|
122
|
-
config: RunnableConfig | None = None,
|
|
123
|
-
*,
|
|
124
|
-
xray: int | bool = False,
|
|
125
|
-
) -> DrawableGraph:
|
|
126
|
-
response = (
|
|
127
|
-
await self._async_client.post(
|
|
128
|
-
f"/{self.graph_id}/getGraph",
|
|
129
|
-
headers={"Content-Type": "application/json"},
|
|
130
|
-
data=orjson.dumps({"config": config, "xray": xray}),
|
|
131
|
-
)
|
|
132
|
-
).json()
|
|
133
|
-
|
|
134
|
-
nodes: list[Any] = response.pop("nodes")
|
|
135
|
-
edges: list[Any] = response.pop("edges")
|
|
136
|
-
|
|
137
|
-
return DrawableGraph(
|
|
138
|
-
{
|
|
139
|
-
data["id"]: Node(
|
|
140
|
-
data["id"], data["id"], NoopModel(), data.get("metadata")
|
|
141
|
-
)
|
|
142
|
-
for data in nodes
|
|
143
|
-
},
|
|
144
|
-
{
|
|
145
|
-
Edge(
|
|
146
|
-
data["source"],
|
|
147
|
-
data["target"],
|
|
148
|
-
data.get("data"),
|
|
149
|
-
data.get("conditional", False),
|
|
150
|
-
)
|
|
151
|
-
for data in edges
|
|
152
|
-
},
|
|
153
|
-
)
|
|
154
|
-
|
|
155
|
-
async def fetch_subgraphs(
|
|
156
|
-
self, *, namespace: str | None = None, recurse: bool = False
|
|
157
|
-
) -> dict[str, dict]:
|
|
158
|
-
return (
|
|
159
|
-
await self._async_client.post(
|
|
160
|
-
f"/{self.graph_id}/getSubgraphs",
|
|
161
|
-
headers={"Content-Type": "application/json"},
|
|
162
|
-
data=orjson.dumps({"namespace": namespace, "recurse": recurse}),
|
|
163
|
-
)
|
|
164
|
-
).json()
|
|
165
|
-
|
|
166
|
-
def _convert_state_snapshot(self, item: dict) -> StateSnapshot:
|
|
167
|
-
def _convert_tasks(tasks: list[dict]) -> tuple[PregelTask, ...]:
|
|
168
|
-
result: list[PregelTask] = []
|
|
169
|
-
for task in tasks:
|
|
170
|
-
state = task.get("state")
|
|
171
|
-
|
|
172
|
-
if state and isinstance(state, dict) and "config" in state:
|
|
173
|
-
state = self._convert_state_snapshot(state)
|
|
174
|
-
|
|
175
|
-
result.append(
|
|
176
|
-
PregelTask(
|
|
177
|
-
task["id"],
|
|
178
|
-
task["name"],
|
|
179
|
-
tuple(task["path"]) if task.get("path") else tuple(),
|
|
180
|
-
# TODO: figure out how to properly deserialise errors
|
|
181
|
-
task.get("error"),
|
|
182
|
-
(
|
|
183
|
-
tuple(
|
|
184
|
-
Interrupt(
|
|
185
|
-
value=interrupt["value"],
|
|
186
|
-
when=interrupt["when"],
|
|
187
|
-
resumable=interrupt.get("resumable", True),
|
|
188
|
-
ns=interrupt.get("ns"),
|
|
189
|
-
)
|
|
190
|
-
for interrupt in task.get("interrupts")
|
|
191
|
-
)
|
|
192
|
-
if task.get("interrupts")
|
|
193
|
-
else []
|
|
194
|
-
),
|
|
195
|
-
state,
|
|
196
|
-
)
|
|
197
|
-
)
|
|
198
|
-
return tuple(result)
|
|
199
|
-
|
|
200
|
-
return StateSnapshot(
|
|
201
|
-
item.get("values"),
|
|
202
|
-
item.get("next"),
|
|
203
|
-
item.get("config"),
|
|
204
|
-
item.get("metadata"),
|
|
205
|
-
item.get("createdAt"),
|
|
206
|
-
item.get("parentConfig"),
|
|
207
|
-
_convert_tasks(item.get("tasks", [])),
|
|
208
|
-
)
|
|
209
|
-
|
|
210
|
-
async def aget_state(
|
|
211
|
-
self, config: RunnableConfig, *, subgraphs: bool = False
|
|
212
|
-
) -> StateSnapshot:
|
|
213
|
-
response = await self._async_client.post(
|
|
214
|
-
f"/{self.graph_id}/getState",
|
|
215
|
-
headers={"Content-Type": "application/json"},
|
|
216
|
-
data=orjson.dumps({"config": config, "subgraphs": subgraphs}),
|
|
217
|
-
)
|
|
218
|
-
return self._convert_state_snapshot(response.json())
|
|
219
|
-
|
|
220
|
-
async def aupdate_state(
|
|
221
|
-
self,
|
|
222
|
-
config: RunnableConfig,
|
|
223
|
-
values: dict[str, Any] | Any,
|
|
224
|
-
as_node: str | None = None,
|
|
225
|
-
) -> RunnableConfig:
|
|
226
|
-
response = await self._async_client.post(
|
|
227
|
-
f"/{self.graph_id}/updateState",
|
|
228
|
-
headers={"Content-Type": "application/json"},
|
|
229
|
-
data=orjson.dumps({"config": config, "values": values, "as_node": as_node}),
|
|
230
|
-
)
|
|
231
|
-
return RunnableConfig(**response.json())
|
|
232
|
-
|
|
233
|
-
async def aget_state_history(
|
|
234
|
-
self,
|
|
235
|
-
config: RunnableConfig,
|
|
236
|
-
*,
|
|
237
|
-
filter: dict[str, Any] | None = None,
|
|
238
|
-
before: RunnableConfig | None = None,
|
|
239
|
-
limit: int | None = None,
|
|
240
|
-
) -> AsyncIterator[StateSnapshot]:
|
|
241
|
-
async with aconnect_sse(
|
|
242
|
-
self._async_client,
|
|
243
|
-
"POST",
|
|
244
|
-
f"/{self.graph_id}/getStateHistory",
|
|
245
|
-
headers={"Content-Type": "application/json"},
|
|
246
|
-
data=orjson.dumps(
|
|
247
|
-
{"config": config, "limit": limit, "filter": filter, "before": before}
|
|
248
|
-
),
|
|
249
|
-
) as event_source:
|
|
250
|
-
async for sse in event_source.aiter_sse():
|
|
251
|
-
yield self._convert_state_snapshot(orjson.loads(sse["data"]))
|
|
252
|
-
|
|
253
|
-
def get_graph(
|
|
254
|
-
self,
|
|
255
|
-
config: RunnableConfig | None = None,
|
|
256
|
-
*,
|
|
257
|
-
xray: int | bool = False,
|
|
258
|
-
) -> dict[str, Any]:
|
|
259
|
-
raise Exception("Not implemented")
|
|
260
|
-
|
|
261
|
-
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
|
262
|
-
raise Exception("Not implemented")
|
|
263
|
-
|
|
264
|
-
def get_output_schema(
|
|
265
|
-
self, config: RunnableConfig | None = None
|
|
266
|
-
) -> type[BaseModel]:
|
|
267
|
-
raise Exception("Not implemented")
|
|
268
|
-
|
|
269
|
-
def config_schema(self) -> type[BaseModel]:
|
|
270
|
-
raise Exception("Not implemented")
|
|
271
|
-
|
|
272
|
-
async def invoke(self, input: Any, config: RunnableConfig | None = None):
|
|
273
|
-
raise Exception("Not implemented")
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
async def run_js_process(paths_str: str, watch: bool = False):
|
|
277
|
-
# check if tsx is available
|
|
278
|
-
tsx_path = shutil.which("tsx")
|
|
279
|
-
if tsx_path is None:
|
|
280
|
-
raise FileNotFoundError("tsx not found in PATH")
|
|
281
|
-
attempt = 0
|
|
282
|
-
while True:
|
|
283
|
-
client_file = os.path.join(os.path.dirname(__file__), "client.mts")
|
|
284
|
-
args = ("tsx", client_file)
|
|
285
|
-
if watch:
|
|
286
|
-
args = ("tsx", "watch", client_file, "--skip-schema-cache")
|
|
287
|
-
try:
|
|
288
|
-
process = await asyncio.create_subprocess_exec(
|
|
289
|
-
*args,
|
|
290
|
-
env={
|
|
291
|
-
"LANGSERVE_GRAPHS": paths_str,
|
|
292
|
-
"LANGCHAIN_CALLBACKS_BACKGROUND": "true",
|
|
293
|
-
"CHOKIDAR_USEPOLLING": "true",
|
|
294
|
-
**os.environ,
|
|
295
|
-
},
|
|
296
|
-
)
|
|
297
|
-
code = await process.wait()
|
|
298
|
-
raise Exception(f"JS process exited with code {code}")
|
|
299
|
-
except asyncio.CancelledError:
|
|
300
|
-
logger.info("Terminating JS graphs process")
|
|
301
|
-
try:
|
|
302
|
-
process.terminate()
|
|
303
|
-
await process.wait()
|
|
304
|
-
except (UnboundLocalError, ProcessLookupError):
|
|
305
|
-
pass
|
|
306
|
-
raise
|
|
307
|
-
except Exception:
|
|
308
|
-
if attempt >= 3:
|
|
309
|
-
raise
|
|
310
|
-
else:
|
|
311
|
-
logger.warning(f"Retrying JS process {3 - attempt} more times...")
|
|
312
|
-
attempt += 1
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
def _get_passthrough_checkpointer(conn: AsyncConnectionProto):
|
|
316
|
-
from langgraph_storage.checkpoint import Checkpointer
|
|
317
|
-
|
|
318
|
-
class PassthroughSerialiser(SerializerProtocol):
|
|
319
|
-
def dumps(self, obj: Any) -> bytes:
|
|
320
|
-
return json_dumpb(obj)
|
|
321
|
-
|
|
322
|
-
def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
|
|
323
|
-
return "json", json_dumpb(obj)
|
|
324
|
-
|
|
325
|
-
def loads(self, data: bytes) -> Any:
|
|
326
|
-
return orjson.loads(data)
|
|
327
|
-
|
|
328
|
-
def loads_typed(self, data: tuple[str, bytes]) -> Any:
|
|
329
|
-
type, payload = data
|
|
330
|
-
if type != "json":
|
|
331
|
-
raise ValueError(f"Unsupported type {type}")
|
|
332
|
-
return orjson.loads(payload)
|
|
333
|
-
|
|
334
|
-
checkpointer = Checkpointer(conn)
|
|
335
|
-
|
|
336
|
-
# This checkpointer does not attempt to revive LC-objects.
|
|
337
|
-
# Instead, it will pass through the JSON values as-is.
|
|
338
|
-
checkpointer.serde = PassthroughSerialiser()
|
|
339
|
-
|
|
340
|
-
return checkpointer
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
# Setup a HTTP server on top of CHECKPOINTER_SOCKET unix socket
|
|
344
|
-
# used by `client.mts` to communicate with the Python checkpointer
|
|
345
|
-
async def run_remote_checkpointer():
|
|
346
|
-
from langgraph_storage.database import connect
|
|
347
|
-
|
|
348
|
-
# Search checkpoints
|
|
349
|
-
async def list(request: Request):
|
|
350
|
-
payload = orjson.loads(await request.body())
|
|
351
|
-
result = []
|
|
352
|
-
async with connect() as conn:
|
|
353
|
-
checkpointer = _get_passthrough_checkpointer(conn)
|
|
354
|
-
async for item in checkpointer.alist(
|
|
355
|
-
config=payload.get("config"),
|
|
356
|
-
limit=payload.get("limit"),
|
|
357
|
-
before=payload.get("before"),
|
|
358
|
-
filter=payload.get("filter"),
|
|
359
|
-
):
|
|
360
|
-
result.append(item)
|
|
361
|
-
|
|
362
|
-
return ApiResponse(result)
|
|
363
|
-
|
|
364
|
-
# Put the new checkpoint metadata
|
|
365
|
-
async def put(request: Request):
|
|
366
|
-
payload = orjson.loads(await request.body())
|
|
367
|
-
async with connect() as conn:
|
|
368
|
-
checkpointer = _get_passthrough_checkpointer(conn)
|
|
369
|
-
result = await checkpointer.aput(
|
|
370
|
-
payload["config"],
|
|
371
|
-
payload["checkpoint"],
|
|
372
|
-
payload["metadata"],
|
|
373
|
-
payload.get("new_versions", {}),
|
|
374
|
-
)
|
|
375
|
-
return ApiResponse(result)
|
|
376
|
-
|
|
377
|
-
# Get actual checkpoint values (reads)
|
|
378
|
-
async def get_tuple(request: Request):
|
|
379
|
-
payload = orjson.loads(await request.body())
|
|
380
|
-
|
|
381
|
-
async with connect() as conn:
|
|
382
|
-
checkpointer = _get_passthrough_checkpointer(conn)
|
|
383
|
-
result = await checkpointer.aget_tuple(config=payload["config"])
|
|
384
|
-
return ApiResponse(result)
|
|
385
|
-
|
|
386
|
-
# Put actual checkpoint values (writes)
|
|
387
|
-
async def put_writes(request: Request):
|
|
388
|
-
payload = orjson.loads(await request.body())
|
|
389
|
-
|
|
390
|
-
async with connect() as conn:
|
|
391
|
-
checkpointer = _get_passthrough_checkpointer(conn)
|
|
392
|
-
result = await checkpointer.aput_writes(
|
|
393
|
-
payload["config"],
|
|
394
|
-
payload["writes"],
|
|
395
|
-
payload["taskId"],
|
|
396
|
-
)
|
|
397
|
-
|
|
398
|
-
return ApiResponse(result)
|
|
399
|
-
|
|
400
|
-
remote = Starlette(
|
|
401
|
-
routes=[
|
|
402
|
-
Route("/get_tuple", get_tuple, methods=["POST"]),
|
|
403
|
-
Route("/list", list, methods=["POST"]),
|
|
404
|
-
Route("/put", put, methods=["POST"]),
|
|
405
|
-
Route("/put_writes", put_writes, methods=["POST"]),
|
|
406
|
-
Route("/ok", lambda _: ApiResponse({"ok": True}), methods=["GET"]),
|
|
407
|
-
]
|
|
1
|
+
from langgraph_api.config import FF_JS_ZEROMQ_ENABLED
|
|
2
|
+
|
|
3
|
+
if FF_JS_ZEROMQ_ENABLED:
|
|
4
|
+
from langgraph_api.js.remote_new import ( # noqa: I001
|
|
5
|
+
run_js_process, # noqa: F401
|
|
6
|
+
RemotePregel, # noqa: F401
|
|
7
|
+
run_remote_checkpointer, # noqa: F401
|
|
8
|
+
wait_until_js_ready, # noqa: F401
|
|
9
|
+
js_healthcheck, # noqa: F401
|
|
408
10
|
)
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
# the structlog logger setup before.
|
|
417
|
-
# See: https://github.com/encode/uvicorn/blob/8f4c8a7f34914c16650ebd026127b96560425fde/uvicorn/config.py#L357-L393
|
|
418
|
-
log_config=None,
|
|
419
|
-
log_level=None,
|
|
420
|
-
access_log=True,
|
|
421
|
-
)
|
|
11
|
+
else:
|
|
12
|
+
from langgraph_api.js.remote_old import ( # noqa: I001
|
|
13
|
+
run_js_process, # noqa: F401
|
|
14
|
+
RemotePregel, # noqa: F401
|
|
15
|
+
run_remote_checkpointer, # noqa: F401
|
|
16
|
+
wait_until_js_ready, # noqa: F401
|
|
17
|
+
js_healthcheck, # noqa: F401
|
|
422
18
|
)
|
|
423
|
-
await server.serve()
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
def _get_passthrough_store():
|
|
427
|
-
from langgraph_storage.store import Store
|
|
428
|
-
|
|
429
|
-
return Store()
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
async def run_remote_store():
|
|
433
|
-
async def abatch(request: Request):
|
|
434
|
-
payload = orjson.loads(await request.body())
|
|
435
|
-
operations = payload.get("operations", [])
|
|
436
|
-
|
|
437
|
-
if not operations:
|
|
438
|
-
return ApiResponse({"error": "No operations provided"}, status_code=400)
|
|
439
|
-
|
|
440
|
-
# Convert raw operations to proper objects
|
|
441
|
-
processed_operations = []
|
|
442
|
-
for op in operations:
|
|
443
|
-
if "value" in op:
|
|
444
|
-
processed_operations.append(
|
|
445
|
-
PutOp(
|
|
446
|
-
namespace=tuple(op["namespace"]),
|
|
447
|
-
key=op["key"],
|
|
448
|
-
value=op["value"],
|
|
449
|
-
)
|
|
450
|
-
)
|
|
451
|
-
elif "namespace_prefix" in op:
|
|
452
|
-
processed_operations.append(
|
|
453
|
-
SearchOp(
|
|
454
|
-
namespace_prefix=tuple(op["namespace_prefix"]),
|
|
455
|
-
filter=op.get("filter"),
|
|
456
|
-
limit=op.get("limit", 10),
|
|
457
|
-
offset=op.get("offset", 0),
|
|
458
|
-
)
|
|
459
|
-
)
|
|
460
|
-
|
|
461
|
-
elif "namespace" in op and "key" in op:
|
|
462
|
-
processed_operations.append(
|
|
463
|
-
GetOp(namespace=tuple(op["namespace"]), key=op["key"])
|
|
464
|
-
)
|
|
465
|
-
elif "match_conditions" in op:
|
|
466
|
-
processed_operations.append(
|
|
467
|
-
ListNamespacesOp(
|
|
468
|
-
match_conditions=tuple(op["match_conditions"]),
|
|
469
|
-
max_depth=op.get("max_depth"),
|
|
470
|
-
limit=op.get("limit", 100),
|
|
471
|
-
offset=op.get("offset", 0),
|
|
472
|
-
)
|
|
473
|
-
)
|
|
474
|
-
else:
|
|
475
|
-
return ApiResponse(
|
|
476
|
-
{"error": f"Unknown operation type: {op}"}, status_code=400
|
|
477
|
-
)
|
|
478
|
-
|
|
479
|
-
store = _get_passthrough_store()
|
|
480
|
-
results = await store.abatch(processed_operations)
|
|
481
|
-
|
|
482
|
-
# Handle potentially undefined or non-dict results
|
|
483
|
-
processed_results = []
|
|
484
|
-
# Result is of type: Union[Item, list[Item], list[tuple[str, ...]], None]
|
|
485
|
-
for result in results:
|
|
486
|
-
if isinstance(result, Item):
|
|
487
|
-
processed_results.append(result.dict())
|
|
488
|
-
elif isinstance(result, dict):
|
|
489
|
-
processed_results.append(result)
|
|
490
|
-
elif isinstance(result, list):
|
|
491
|
-
coerced = []
|
|
492
|
-
for res in result:
|
|
493
|
-
if isinstance(res, Item):
|
|
494
|
-
coerced.append(res.dict())
|
|
495
|
-
elif isinstance(res, tuple):
|
|
496
|
-
coerced.append(list(res))
|
|
497
|
-
elif res is None:
|
|
498
|
-
coerced.append(res)
|
|
499
|
-
else:
|
|
500
|
-
coerced.append(str(res))
|
|
501
|
-
processed_results.append(coerced)
|
|
502
|
-
elif result is None:
|
|
503
|
-
processed_results.append(None)
|
|
504
|
-
else:
|
|
505
|
-
processed_results.append(str(result))
|
|
506
|
-
return ApiResponse(processed_results)
|
|
507
|
-
|
|
508
|
-
# List all stores
|
|
509
|
-
async def aget(request: Request):
|
|
510
|
-
namespaces_str = request.query_params.get("namespaces")
|
|
511
|
-
key = request.query_params.get("key")
|
|
512
|
-
|
|
513
|
-
if not namespaces_str or not key:
|
|
514
|
-
return ApiResponse(
|
|
515
|
-
{"error": "Both namespaces and key are required"}, status_code=400
|
|
516
|
-
)
|
|
517
|
-
|
|
518
|
-
namespaces = namespaces_str.split(".")
|
|
519
|
-
|
|
520
|
-
store = _get_passthrough_store()
|
|
521
|
-
result = await store.aget(namespaces, key)
|
|
522
|
-
|
|
523
|
-
return ApiResponse(result)
|
|
524
|
-
|
|
525
|
-
# Put the new store data
|
|
526
|
-
async def aput(request: Request):
|
|
527
|
-
payload = orjson.loads(await request.body())
|
|
528
|
-
namespace = tuple(payload["namespace"].split("."))
|
|
529
|
-
key = payload["key"]
|
|
530
|
-
value = payload["value"]
|
|
531
|
-
index = payload.get("index")
|
|
532
|
-
|
|
533
|
-
store = _get_passthrough_store()
|
|
534
|
-
await store.aput(namespace, key, value, index=index)
|
|
535
|
-
|
|
536
|
-
return ApiResponse({"success": True})
|
|
537
|
-
|
|
538
|
-
# Search stores
|
|
539
|
-
async def asearch(request: Request):
|
|
540
|
-
payload = orjson.loads(await request.body())
|
|
541
|
-
namespace_prefix = tuple(payload["namespace_prefix"])
|
|
542
|
-
filter = payload.get("filter")
|
|
543
|
-
limit = payload.get("limit", 10)
|
|
544
|
-
offset = payload.get("offset", 0)
|
|
545
|
-
query = payload.get("query")
|
|
546
|
-
|
|
547
|
-
store = _get_passthrough_store()
|
|
548
|
-
result = await store.asearch(
|
|
549
|
-
namespace_prefix, filter=filter, limit=limit, offset=offset, query=query
|
|
550
|
-
)
|
|
551
|
-
|
|
552
|
-
return ApiResponse([item.dict() for item in result])
|
|
553
|
-
|
|
554
|
-
# Delete store data
|
|
555
|
-
async def adelete(request: Request):
|
|
556
|
-
payload = orjson.loads(await request.body())
|
|
557
|
-
namespace = tuple(payload["namespace"])
|
|
558
|
-
key = payload["key"]
|
|
559
|
-
|
|
560
|
-
store = _get_passthrough_store()
|
|
561
|
-
await store.adelete(namespace, key)
|
|
562
|
-
|
|
563
|
-
return ApiResponse({"success": True})
|
|
564
|
-
|
|
565
|
-
# List all namespaces
|
|
566
|
-
async def alist_namespaces(request: Request):
|
|
567
|
-
payload = orjson.loads(await request.body())
|
|
568
|
-
prefix = tuple(payload.get("prefix", [])) or None
|
|
569
|
-
suffix = tuple(payload.get("suffix", [])) or None
|
|
570
|
-
max_depth = payload.get("max_depth")
|
|
571
|
-
limit = payload.get("limit", 100)
|
|
572
|
-
offset = payload.get("offset", 0)
|
|
573
|
-
|
|
574
|
-
store = _get_passthrough_store()
|
|
575
|
-
result = await store.alist_namespaces(
|
|
576
|
-
prefix=prefix,
|
|
577
|
-
suffix=suffix,
|
|
578
|
-
max_depth=max_depth,
|
|
579
|
-
limit=limit,
|
|
580
|
-
offset=offset,
|
|
581
|
-
)
|
|
582
|
-
|
|
583
|
-
return ApiResponse([list(ns) for ns in result])
|
|
584
|
-
|
|
585
|
-
remote = Starlette(
|
|
586
|
-
routes=[
|
|
587
|
-
Route("/items", aget, methods=["GET"]),
|
|
588
|
-
Route("/items", aput, methods=["PUT"]),
|
|
589
|
-
Route("/items", adelete, methods=["DELETE"]),
|
|
590
|
-
Route("/items/search", asearch, methods=["POST"]),
|
|
591
|
-
Route("/list/namespaces", alist_namespaces, methods=["POST"]),
|
|
592
|
-
Route("/items/batch", abatch, methods=["POST"]),
|
|
593
|
-
Route("/ok", lambda _: ApiResponse({"ok": True}), methods=["GET"]),
|
|
594
|
-
]
|
|
595
|
-
)
|
|
596
|
-
server = uvicorn.Server(
|
|
597
|
-
uvicorn.Config(
|
|
598
|
-
remote,
|
|
599
|
-
uds=STORE_SOCKET,
|
|
600
|
-
# We need to _explicitly_ set these values in order
|
|
601
|
-
# to avoid reinitialising the logger, which removes
|
|
602
|
-
# the structlog logger setup before.
|
|
603
|
-
# See: https://github.com/encode/uvicorn/blob/8f4c8a7f34914c16650ebd026127b96560425fde/uvicorn/config.py#L357-L393
|
|
604
|
-
log_config=None,
|
|
605
|
-
log_level=None,
|
|
606
|
-
access_log=True,
|
|
607
|
-
)
|
|
608
|
-
)
|
|
609
|
-
await server.serve()
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
async def wait_until_js_ready():
|
|
613
|
-
async with (
|
|
614
|
-
httpx.AsyncClient(
|
|
615
|
-
base_url="http://graph",
|
|
616
|
-
transport=httpx.AsyncHTTPTransport(uds=GRAPH_SOCKET),
|
|
617
|
-
limits=httpx.Limits(),
|
|
618
|
-
) as graph_client,
|
|
619
|
-
httpx.AsyncClient(
|
|
620
|
-
base_url="http://checkpointer",
|
|
621
|
-
transport=httpx.AsyncHTTPTransport(uds=CHECKPOINTER_SOCKET),
|
|
622
|
-
limits=httpx.Limits(),
|
|
623
|
-
) as checkpointer_client,
|
|
624
|
-
httpx.AsyncClient(
|
|
625
|
-
base_url="http://store",
|
|
626
|
-
transport=httpx.AsyncHTTPTransport(uds=STORE_SOCKET),
|
|
627
|
-
limits=httpx.Limits(),
|
|
628
|
-
) as store_client,
|
|
629
|
-
):
|
|
630
|
-
attempt = 0
|
|
631
|
-
while True:
|
|
632
|
-
try:
|
|
633
|
-
res = await graph_client.get("/ok")
|
|
634
|
-
res.raise_for_status()
|
|
635
|
-
res = await checkpointer_client.get("/ok")
|
|
636
|
-
res.raise_for_status()
|
|
637
|
-
res = await store_client.get("/ok")
|
|
638
|
-
res.raise_for_status()
|
|
639
|
-
return
|
|
640
|
-
except httpx.HTTPError:
|
|
641
|
-
if attempt > 240:
|
|
642
|
-
raise
|
|
643
|
-
else:
|
|
644
|
-
attempt += 1
|
|
645
|
-
await asyncio.sleep(0.5)
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
async def js_healthcheck():
|
|
649
|
-
async with (
|
|
650
|
-
httpx.AsyncClient(
|
|
651
|
-
base_url="http://graph",
|
|
652
|
-
transport=httpx.AsyncHTTPTransport(uds=GRAPH_SOCKET),
|
|
653
|
-
limits=httpx.Limits(),
|
|
654
|
-
) as graph_client,
|
|
655
|
-
httpx.AsyncClient(
|
|
656
|
-
base_url="http://checkpointer",
|
|
657
|
-
transport=httpx.AsyncHTTPTransport(uds=CHECKPOINTER_SOCKET),
|
|
658
|
-
limits=httpx.Limits(),
|
|
659
|
-
) as checkpointer_client,
|
|
660
|
-
httpx.AsyncClient(
|
|
661
|
-
base_url="http://store",
|
|
662
|
-
transport=httpx.AsyncHTTPTransport(uds=STORE_SOCKET),
|
|
663
|
-
limits=httpx.Limits(),
|
|
664
|
-
) as store_client,
|
|
665
|
-
):
|
|
666
|
-
try:
|
|
667
|
-
res = await graph_client.get("/ok")
|
|
668
|
-
res.raise_for_status()
|
|
669
|
-
res = await checkpointer_client.get("/ok")
|
|
670
|
-
res.raise_for_status()
|
|
671
|
-
res = await store_client.get("/ok")
|
|
672
|
-
res.raise_for_status()
|
|
673
|
-
return True
|
|
674
|
-
except httpx.HTTPError:
|
|
675
|
-
return False
|