langgraph-api 0.0.27__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langgraph-api might be problematic. Click here for more details.
- langgraph_api/api/__init__.py +2 -0
- langgraph_api/api/assistants.py +43 -13
- langgraph_api/api/meta.py +1 -1
- langgraph_api/api/runs.py +14 -1
- langgraph_api/api/ui.py +68 -0
- langgraph_api/asyncio.py +43 -4
- langgraph_api/auth/middleware.py +2 -2
- langgraph_api/config.py +14 -1
- langgraph_api/cron_scheduler.py +1 -1
- langgraph_api/graph.py +5 -0
- langgraph_api/http.py +24 -7
- langgraph_api/js/.gitignore +2 -0
- langgraph_api/js/build.mts +44 -1
- langgraph_api/js/client.mts +67 -31
- langgraph_api/js/global.d.ts +1 -0
- langgraph_api/js/package.json +11 -5
- langgraph_api/js/remote.py +662 -16
- langgraph_api/js/sse.py +138 -0
- langgraph_api/js/tests/api.test.mts +28 -0
- langgraph_api/js/tests/compose-postgres.yml +2 -2
- langgraph_api/js/tests/graphs/agent.css +1 -0
- langgraph_api/js/tests/graphs/agent.ui.tsx +10 -0
- langgraph_api/js/tests/graphs/package.json +2 -2
- langgraph_api/js/tests/graphs/yarn.lock +13 -13
- langgraph_api/js/yarn.lock +706 -1188
- langgraph_api/lifespan.py +15 -5
- langgraph_api/logging.py +9 -0
- langgraph_api/metadata.py +5 -1
- langgraph_api/middleware/http_logger.py +1 -1
- langgraph_api/patch.py +2 -0
- langgraph_api/queue_entrypoint.py +63 -0
- langgraph_api/schema.py +2 -0
- langgraph_api/stream.py +1 -0
- langgraph_api/webhook.py +42 -0
- langgraph_api/{queue.py → worker.py} +52 -166
- {langgraph_api-0.0.27.dist-info → langgraph_api-0.0.28.dist-info}/METADATA +2 -2
- {langgraph_api-0.0.27.dist-info → langgraph_api-0.0.28.dist-info}/RECORD +47 -44
- langgraph_storage/database.py +8 -22
- langgraph_storage/inmem_stream.py +108 -0
- langgraph_storage/ops.py +80 -57
- langgraph_storage/queue.py +126 -103
- langgraph_storage/retry.py +5 -1
- langgraph_storage/store.py +5 -1
- openapi.json +3 -3
- langgraph_api/js/client.new.mts +0 -875
- langgraph_api/js/remote_new.py +0 -694
- langgraph_api/js/remote_old.py +0 -670
- langgraph_api/js/server_sent_events.py +0 -126
- {langgraph_api-0.0.27.dist-info → langgraph_api-0.0.28.dist-info}/LICENSE +0 -0
- {langgraph_api-0.0.27.dist-info → langgraph_api-0.0.28.dist-info}/WHEEL +0 -0
- {langgraph_api-0.0.27.dist-info → langgraph_api-0.0.28.dist-info}/entry_points.txt +0 -0
langgraph_api/js/remote_new.py
DELETED
|
@@ -1,694 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import os
|
|
3
|
-
import shutil
|
|
4
|
-
from collections.abc import AsyncIterator, Callable
|
|
5
|
-
from contextlib import AbstractContextManager
|
|
6
|
-
from typing import Any, Literal
|
|
7
|
-
|
|
8
|
-
import orjson
|
|
9
|
-
import structlog
|
|
10
|
-
import zmq
|
|
11
|
-
import zmq.asyncio
|
|
12
|
-
from langchain_core.runnables.config import RunnableConfig
|
|
13
|
-
from langchain_core.runnables.graph import Edge, Node
|
|
14
|
-
from langchain_core.runnables.graph import Graph as DrawableGraph
|
|
15
|
-
from langchain_core.runnables.schema import (
|
|
16
|
-
CustomStreamEvent,
|
|
17
|
-
StandardStreamEvent,
|
|
18
|
-
StreamEvent,
|
|
19
|
-
)
|
|
20
|
-
from langgraph.checkpoint.base.id import uuid6
|
|
21
|
-
from langgraph.checkpoint.serde.base import SerializerProtocol
|
|
22
|
-
from langgraph.pregel.types import PregelTask, StateSnapshot
|
|
23
|
-
from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp
|
|
24
|
-
from langgraph.types import Command, Interrupt
|
|
25
|
-
from pydantic import BaseModel
|
|
26
|
-
from starlette.exceptions import HTTPException
|
|
27
|
-
from zmq.utils.monitor import recv_monitor_message
|
|
28
|
-
|
|
29
|
-
from langgraph_api.js.base import BaseRemotePregel
|
|
30
|
-
from langgraph_api.js.errors import RemoteException
|
|
31
|
-
from langgraph_api.js.schema import (
|
|
32
|
-
ErrorData,
|
|
33
|
-
RequestPayload,
|
|
34
|
-
ResponsePayload,
|
|
35
|
-
StreamData,
|
|
36
|
-
)
|
|
37
|
-
from langgraph_api.serde import json_dumpb, json_loads
|
|
38
|
-
from langgraph_api.utils import AsyncConnectionProto
|
|
39
|
-
|
|
40
|
-
logger = structlog.stdlib.get_logger(__name__)

# ``clientDealer`` connects here (see ``run_remote_checkpointer``) to send
# graph requests; the peer is presumably the JS worker -- verify against the
# JS client.
CLIENT_ADDR = "tcp://0.0.0.0:5556"
# ``remoteRouter`` binds here and serves checkpointer/store requests.
# NOTE(review): binding on 0.0.0.0 listens on every network interface and no
# authentication is visible in this file; confirm whether 127.0.0.1 would be
# sufficient for a locally spawned worker.
REMOTE_ADDR = "tcp://0.0.0.0:5555"
# In-process endpoint on which ``remoteRouter`` publishes socket events.
MONITOR_ADDR = "inproc://router.remote"

# Seconds ``RequestContext.get`` waits for each message on a request's queue;
# every heartbeat restarts the wait.
WAIT_FOR_REQUEST_TTL_SECONDS = 15.0

# The context and sockets are created at import time and shared module-wide.
context = zmq.asyncio.Context()
clientDealer = context.socket(zmq.DEALER)
remoteRouter = context.socket(zmq.ROUTER)

# Only handshake-succeeded and disconnected events are reported; they are
# read from ``remoteMonitor`` and logged in ``run_remote_checkpointer``.
remoteRouter.monitor(
    MONITOR_ADDR,
    zmq.EVENT_HANDSHAKE_SUCCEEDED | zmq.EVENT_DISCONNECTED,
)
remoteMonitor = context.socket(zmq.PAIR)
# Drop unsent monitor messages immediately when the socket is closed.
remoteMonitor.setsockopt(zmq.LINGER, 0)

# Outstanding requests to the JS worker: request id -> queue of its replies.
REMOTE_MGS_MAP: dict[str, asyncio.Queue] = {}
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
class HeartbeatPing:
    """Marker queued for a request when a reply arrives without a ``success`` field.

    ``RequestContext.get`` discards it and keeps waiting for a real reply.
    """
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
class RequestTimeout(Exception):
    """Raised when a request to the JS worker receives no message in time.

    Attributes:
        method: name of the remote method whose request timed out.
    """

    method: str

    def __init__(self, method: str):
        # Forward ``method`` to the base class so ``args`` is populated; without
        # it ``repr`` is empty and unpickling fails (the exception is rebuilt as
        # ``RequestTimeout(*args)``, i.e. with no ``method``).
        super().__init__(method)
        self.method = method

    def __str__(self):
        return f'Request to "{self.method}" timed out'
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
class RequestContext(AbstractContextManager["RequestContext"]):
    """Tracks one outstanding request sent to the JS worker.

    Entering registers a fresh queue in ``REMOTE_MGS_MAP`` under a new request
    id. The reply dispatcher in ``run_remote_checkpointer`` pushes replies for
    that id onto the queue. Exiting unregisters the queue.

    Attributes:
        id: hex request id used to correlate replies.
        method: remote method name, used in timeout errors.
    """

    id: str
    method: str

    def __init__(self, method: str):
        self.id = uuid6().hex
        self.method = method

    def __enter__(self) -> "RequestContext":
        REMOTE_MGS_MAP[self.id] = asyncio.Queue()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        REMOTE_MGS_MAP.pop(self.id, None)

    async def get(self) -> Any:
        """Return the next non-heartbeat message queued for this request.

        Each wait is bounded by ``WAIT_FOR_REQUEST_TTL_SECONDS``, and a heartbeat
        restarts it. An ``Exception`` taken from the queue is raised instead of
        returned.

        Raises:
            RequestTimeout: if no message at all arrives within the TTL.
        """
        queue = REMOTE_MGS_MAP[self.id]
        while True:
            try:
                value = await asyncio.wait_for(
                    queue.get(), timeout=WAIT_FOR_REQUEST_TTL_SECONDS
                )
            except asyncio.TimeoutError as exc:
                # ``wait_for`` raises ``asyncio.TimeoutError``. Before Python 3.11
                # that is a different class from the builtin ``TimeoutError``, so
                # catching the builtin would let timeouts escape unconverted.
                raise RequestTimeout(self.method) from exc
            # Heartbeats only show the worker is alive. Loop (instead of
            # recursing) so long-running requests do not grow the call stack.
            if not isinstance(value, HeartbeatPing):
                break

        if isinstance(value, Exception):
            raise value
        return value
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
async def _client_stream(method: str, data: Any):
    """Send a streaming request to the JS worker and yield each chunk's value.

    Every reply is a ``StreamData`` dict. The stream ends at the first reply
    whose ``"done"`` flag is truthy. That final reply contributes a value only
    if it carries a ``"value"`` key.
    """
    with RequestContext(method) as ctx:
        request = {"method": method, "id": ctx.id, "data": data}
        await clientDealer.send(orjson.dumps(request))

        finished = False
        while not finished:
            chunk: StreamData = await ctx.get()
            finished = chunk["done"]
            if not finished or "value" in chunk:
                yield chunk["value"]
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
async def _client_invoke(method: str, data: Any):
    """Send one request to the JS worker and return its single reply.

    Errors come from ``RequestContext.get``: a timeout, or an exception value
    that was queued for this request.
    """
    with RequestContext(method) as ctx:
        message = {"method": method, "id": ctx.id, "data": data}
        await clientDealer.send(orjson.dumps(message))
        reply = await ctx.get()
    return reply
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
class RemotePregel(BaseRemotePregel):
    """Client-side proxy for a graph that runs in the JS worker process.

    Each async method forwards a request over ZeroMQ: ``_client_invoke`` for a
    single reply, ``_client_stream`` for a stream of replies. The JSON response
    is converted back into LangChain/LangGraph types. The synchronous
    ``Runnable``-style accessors are not supported for remote graphs.
    """

    @staticmethod
    def load(graph_id: str):
        """Return a proxy bound to the remote graph named ``graph_id``."""
        model = RemotePregel()
        model.graph_id = graph_id
        return model

    async def astream_events(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        *,
        version: Literal["v1", "v2"],
        **kwargs: Any,
    ) -> AsyncIterator[StreamEvent]:
        """Stream ``astream_events`` output from the remote graph (v2 only).

        ``input`` is sent under ``"command"`` when it is a ``Command`` and under
        ``"input"`` otherwise. Extra ``kwargs`` are forwarded unchanged.

        Raises:
            ValueError: if ``version`` is not ``"v2"``.
        """
        if version != "v2":
            raise ValueError("Only v2 of astream_events is supported")

        data = {
            "graph_id": self.graph_id,
            "command" if isinstance(input, Command) else "input": input,
            "config": config,
            **kwargs,
        }

        async for event in _client_stream("streamEvents", data):
            if event["event"] == "on_custom_event":
                yield CustomStreamEvent(**event)
            else:
                yield StandardStreamEvent(**event)

    async def fetch_state_schema(self):
        """Return the schema payload the JS worker reports for this graph."""
        return await _client_invoke("getSchema", {"graph_id": self.graph_id})

    async def fetch_graph(
        self,
        config: RunnableConfig | None = None,
        *,
        xray: int | bool = False,
    ) -> DrawableGraph:
        """Fetch the graph's nodes and edges and build a drawable graph from them."""
        response = await _client_invoke(
            "getGraph", {"graph_id": self.graph_id, "config": config, "xray": xray}
        )

        nodes: list[Any] = response.pop("nodes")
        edges: list[Any] = response.pop("edges")

        # Placeholder node payload: remote nodes have no Python-side schema.
        class NoopModel(BaseModel):
            pass

        return DrawableGraph(
            {
                data["id"]: Node(
                    data["id"], data["id"], NoopModel(), data.get("metadata")
                )
                for data in nodes
            },
            {
                Edge(
                    data["source"],
                    data["target"],
                    data.get("data"),
                    data.get("conditional", False),
                )
                for data in edges
            },
        )

    async def fetch_subgraphs(
        self, *, namespace: str | None = None, recurse: bool = False
    ) -> dict[str, dict]:
        """Return the subgraph descriptions the JS worker reports for this graph."""
        return await _client_invoke(
            "getSubgraphs",
            {"graph_id": self.graph_id, "namespace": namespace, "recurse": recurse},
        )

    def _convert_state_snapshot(self, item: dict) -> StateSnapshot:
        """Build a ``StateSnapshot`` from the JSON returned by the JS worker.

        A task ``state`` that is a dict with a ``"config"`` key is converted
        recursively into a nested ``StateSnapshot``.
        """

        def _convert_tasks(tasks: list[dict]) -> tuple[PregelTask, ...]:
            result: list[PregelTask] = []
            for task in tasks:
                state = task.get("state")
                if state and isinstance(state, dict) and "config" in state:
                    state = self._convert_state_snapshot(state)

                raw_interrupts = task.get("interrupts")
                # Use a tuple in both branches so the field type is consistent.
                interrupts = (
                    tuple(
                        Interrupt(
                            value=interrupt["value"],
                            when=interrupt["when"],
                            resumable=interrupt.get("resumable", True),
                            ns=interrupt.get("ns"),
                        )
                        for interrupt in raw_interrupts
                    )
                    if raw_interrupts
                    else ()
                )

                result.append(
                    PregelTask(
                        task["id"],
                        task["name"],
                        tuple(task["path"]) if task.get("path") else (),
                        # TODO: figure out how to properly deserialise errors
                        task.get("error"),
                        interrupts,
                        state,
                    )
                )
            return tuple(result)

        return StateSnapshot(
            item.get("values"),
            item.get("next"),
            item.get("config"),
            item.get("metadata"),
            item.get("createdAt"),
            item.get("parentConfig"),
            # ``or []`` also covers an explicit null ``tasks`` field.
            _convert_tasks(item.get("tasks") or []),
        )

    async def aget_state(
        self, config: RunnableConfig, *, subgraphs: bool = False
    ) -> StateSnapshot:
        """Fetch the current state snapshot for ``config`` from the JS worker."""
        return self._convert_state_snapshot(
            await _client_invoke(
                "getState",
                {"graph_id": self.graph_id, "config": config, "subgraphs": subgraphs},
            )
        )

    async def aupdate_state(
        self,
        config: RunnableConfig,
        values: dict[str, Any] | Any,
        as_node: str | None = None,
    ) -> RunnableConfig:
        """Apply a state update remotely and return the resulting config."""
        response = await _client_invoke(
            "updateState",
            {
                "graph_id": self.graph_id,
                "config": config,
                "values": values,
                "as_node": as_node,
            },
        )
        return RunnableConfig(**response)

    async def aget_state_history(
        self,
        config: RunnableConfig,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> AsyncIterator[StateSnapshot]:
        """Stream past state snapshots for ``config`` from the JS worker."""
        async for event in _client_stream(
            "getStateHistory",
            {
                "graph_id": self.graph_id,
                "config": config,
                "limit": limit,
                "filter": filter,
                "before": before,
            },
        ):
            yield self._convert_state_snapshot(event)

    # The synchronous accessors below are unsupported for remote graphs.
    # ``NotImplementedError`` is an ``Exception`` subclass, so existing
    # ``except Exception`` handlers still catch it.

    def get_graph(
        self,
        config: RunnableConfig | None = None,
        *,
        xray: int | bool = False,
    ) -> dict[str, Any]:
        raise NotImplementedError("Not implemented")

    def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
        raise NotImplementedError("Not implemented")

    def get_output_schema(
        self, config: RunnableConfig | None = None
    ) -> type[BaseModel]:
        raise NotImplementedError("Not implemented")

    def config_schema(self) -> type[BaseModel]:
        raise NotImplementedError("Not implemented")

    async def invoke(self, input: Any, config: RunnableConfig | None = None):
        raise NotImplementedError("Not implemented")
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
async def run_js_process(paths_str: str, watch: bool = False):
    """Run the JS graph worker (``client.new.mts``) under ``tsx`` and keep it alive.

    An unexpected exit, or a failure to launch, is retried up to three times
    (four launches in total). After that the last error is re-raised.

    Args:
        paths_str: value passed to the worker as ``LANGSERVE_GRAPHS``.
        watch: run via ``tsx watch`` with ``NODE_ENV=development``.

    Raises:
        FileNotFoundError: if ``tsx`` is not on ``PATH``.
        Exception: once the retry budget is exhausted.
        asyncio.CancelledError: re-raised after terminating the child process.
    """
    # check if tsx is available
    tsx_path = shutil.which("tsx")
    if tsx_path is None:
        raise FileNotFoundError("tsx not found in PATH")

    # Loop-invariant: the command line is the same for every attempt.
    client_file = os.path.join(os.path.dirname(__file__), "client.new.mts")
    if watch:
        args = ("tsx", "watch", client_file, "--skip-schema-cache")
    else:
        args = ("tsx", client_file)

    attempt = 0
    while not asyncio.current_task().cancelled():
        # Reset each attempt so cancellation never targets a previous,
        # already-exited process (and no UnboundLocalError handling is needed).
        process = None
        try:
            process = await asyncio.create_subprocess_exec(
                *args,
                env={
                    "LANGSERVE_GRAPHS": paths_str,
                    "LANGCHAIN_CALLBACKS_BACKGROUND": "true",
                    "NODE_ENV": "development" if watch else "production",
                    "CHOKIDAR_USEPOLLING": "true",
                    # Spread last: the parent environment overrides the defaults above.
                    **os.environ,
                },
            )
            code = await process.wait()
            raise Exception(f"JS process exited with code {code}")
        except asyncio.CancelledError:
            logger.info("Terminating JS graphs process")
            if process is not None:
                try:
                    process.terminate()
                    await process.wait()
                except ProcessLookupError:
                    # The process already exited on its own.
                    pass
            raise
        except Exception:
            if attempt >= 3:
                raise
            logger.warning(f"Retrying JS process {3 - attempt} more times...")
            attempt += 1
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
class _PassthroughSerialiser(SerializerProtocol):
    """Serializer that stores plain JSON and returns it as-is.

    It does not attempt to revive LangChain objects: values written by the JS
    worker are passed through unchanged. Defined once at module level rather
    than re-created on every checkpointer call.
    """

    def dumps(self, obj: Any) -> bytes:
        return json_dumpb(obj)

    def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
        return "json", json_dumpb(obj)

    def loads(self, data: bytes) -> Any:
        return orjson.loads(data)

    def loads_typed(self, data: tuple[str, bytes]) -> Any:
        # ``type_name`` rather than ``type`` to avoid shadowing the builtin.
        type_name, payload = data
        if type_name != "json":
            raise ValueError(f"Unsupported type {type_name}")
        return orjson.loads(payload)


def _get_passthrough_checkpointer(conn: AsyncConnectionProto):
    """Return a ``Checkpointer`` on ``conn`` that passes JSON values through untouched."""
    # Imported lazily, when the function is called rather than at module import.
    from langgraph_storage.checkpoint import Checkpointer

    checkpointer = Checkpointer(conn)
    # This checkpointer does not attempt to revive LC-objects.
    # Instead, it will pass through the JSON values as-is.
    checkpointer.serde = _PassthroughSerialiser()
    return checkpointer
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
def _get_passthrough_store():
    """Build a new storage ``Store`` for the remote ``store_*`` handlers."""
    # Imported lazily, when the function is called rather than at module import.
    from langgraph_storage.store import Store

    store = Store()
    return store
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
async def run_remote_checkpointer():
    """Serve checkpointer/store requests from the JS worker and route its replies.

    Binds ``remoteRouter``, connects ``clientDealer`` and ``remoteMonitor``, then
    polls all three until the task is cancelled:

    * ``remoteRouter``: a request ``{"method", "id", "data"}`` is dispatched to
      the handler whose function name equals ``method``. The handler's result,
      or the exception it raised, is sent back to the requesting peer.
    * ``clientDealer``: a reply to a request made through ``_client_invoke`` /
      ``_client_stream`` is pushed onto the ``REMOTE_MGS_MAP`` queue for its
      id. A reply without ``success`` becomes a ``HeartbeatPing``.
    * ``remoteMonitor``: handshake and disconnect events are logged.
    """
    from langgraph_storage.database import connect

    def _to_namespace(value: Any) -> tuple[str, ...]:
        """Normalise a namespace given as a dotted string or as a sequence of parts."""
        if isinstance(value, str):
            return tuple(value.split("."))
        return tuple(value)

    async def checkpointer_list(payload: dict):
        """Search checkpoints"""
        result = []
        async with connect() as conn:
            checkpointer = _get_passthrough_checkpointer(conn)
            async for item in checkpointer.alist(
                config=payload.get("config"),
                limit=payload.get("limit"),
                before=payload.get("before"),
                filter=payload.get("filter"),
            ):
                result.append(item)
        return result

    async def checkpointer_put(payload: dict):
        """Put the new checkpoint metadata"""
        async with connect() as conn:
            checkpointer = _get_passthrough_checkpointer(conn)
            return await checkpointer.aput(
                payload["config"],
                payload["checkpoint"],
                payload["metadata"],
                payload.get("new_versions", {}),
            )

    async def checkpointer_get_tuple(payload: dict):
        """Get actual checkpoint values (reads)"""
        async with connect() as conn:
            checkpointer = _get_passthrough_checkpointer(conn)
            return await checkpointer.aget_tuple(config=payload["config"])

    async def checkpointer_put_writes(payload: dict):
        """Put actual checkpoint values (writes)"""
        async with connect() as conn:
            checkpointer = _get_passthrough_checkpointer(conn)
            return await checkpointer.aput_writes(
                payload["config"],
                payload["writes"],
                payload["taskId"],
            )

    async def store_batch(payload: dict):
        """Batch operations on the store"""
        operations = payload.get("operations", [])
        if not operations:
            raise ValueError("No operations provided")

        # Convert raw operations to proper objects. The op kind is inferred
        # from which keys are present, checked in this order.
        processed_operations = []
        for op in operations:
            if "value" in op:
                processed_operations.append(
                    PutOp(
                        namespace=_to_namespace(op["namespace"]),
                        key=op["key"],
                        value=op["value"],
                        # Forwarded like ``store_put`` does; None is the default.
                        index=op.get("index"),
                    )
                )
            elif "namespace_prefix" in op:
                processed_operations.append(
                    SearchOp(
                        namespace_prefix=_to_namespace(op["namespace_prefix"]),
                        filter=op.get("filter"),
                        limit=op.get("limit", 10),
                        offset=op.get("offset", 0),
                        # Forwarded like ``store_search`` does; None is the default.
                        query=op.get("query"),
                    )
                )
            elif "namespace" in op and "key" in op:
                processed_operations.append(
                    GetOp(namespace=_to_namespace(op["namespace"]), key=op["key"])
                )
            elif "match_conditions" in op:
                processed_operations.append(
                    ListNamespacesOp(
                        match_conditions=tuple(op["match_conditions"]),
                        max_depth=op.get("max_depth"),
                        limit=op.get("limit", 100),
                        offset=op.get("offset", 0),
                    )
                )
            else:
                raise ValueError(f"Unknown operation type: {op}")

        store = _get_passthrough_store()
        results = await store.abatch(processed_operations)

        # Result is of type: Union[Item, list[Item], list[tuple[str, ...]], None]
        processed_results = []
        for result in results:
            if isinstance(result, Item):
                processed_results.append(result.dict())
            elif isinstance(result, dict):
                processed_results.append(result)
            elif isinstance(result, list):
                coerced = []
                for res in result:
                    if isinstance(res, Item):
                        coerced.append(res.dict())
                    elif isinstance(res, tuple):
                        coerced.append(list(res))
                    elif res is None:
                        coerced.append(res)
                    else:
                        coerced.append(str(res))
                processed_results.append(coerced)
            elif result is None:
                processed_results.append(None)
            else:
                processed_results.append(str(result))
        return processed_results

    async def store_get(payload: dict):
        """Get store data"""
        namespaces_str = payload.get("namespaces")
        key = payload.get("key")

        if not namespaces_str or not key:
            raise ValueError("Both namespaces and key are required")

        store = _get_passthrough_store()
        # Namespaces are tuples everywhere else in this module (store_put/store_delete).
        result = await store.aget(_to_namespace(namespaces_str), key)

        # Same plain-dict shape that store_search / store_batch return.
        return result.dict() if isinstance(result, Item) else result

    async def store_put(payload: dict):
        """Put the new store data"""
        namespace = _to_namespace(payload["namespace"])
        key = payload["key"]
        value = payload["value"]
        index = payload.get("index")

        store = _get_passthrough_store()
        await store.aput(namespace, key, value, index=index)

        return {"success": True}

    async def store_search(payload: dict):
        """Search stores"""
        namespace_prefix = _to_namespace(payload["namespace_prefix"])
        filter = payload.get("filter")
        limit = payload.get("limit", 10)
        offset = payload.get("offset", 0)
        query = payload.get("query")

        store = _get_passthrough_store()
        result = await store.asearch(
            namespace_prefix, filter=filter, limit=limit, offset=offset, query=query
        )

        return [item.dict() for item in result]

    async def store_delete(payload: dict):
        """Delete store data"""
        # ``_to_namespace`` keeps list payloads unchanged but splits a dotted
        # string instead of breaking it into single characters.
        namespace = _to_namespace(payload["namespace"])
        key = payload["key"]

        store = _get_passthrough_store()
        await store.adelete(namespace, key)

        return {"success": True}

    async def store_list_namespaces(payload: dict):
        """List all namespaces"""
        # ``or ()`` tolerates an explicit null prefix/suffix in the payload.
        prefix = tuple(payload.get("prefix") or ()) or None
        suffix = tuple(payload.get("suffix") or ()) or None
        max_depth = payload.get("max_depth")
        limit = payload.get("limit", 100)
        offset = payload.get("offset", 0)

        store = _get_passthrough_store()
        result = await store.alist_namespaces(
            prefix=prefix,
            suffix=suffix,
            max_depth=max_depth,
            limit=limit,
            offset=offset,
        )

        return [list(ns) for ns in result]

    handlers = (
        checkpointer_get_tuple,
        checkpointer_list,
        checkpointer_put,
        checkpointer_put_writes,
        store_get,
        store_put,
        store_delete,
        store_search,
        store_batch,
        store_list_namespaces,
    )
    # Remote method names are the handler function names.
    method_map: dict[str, Callable[[dict], Any]] = {
        handler.__name__: handler for handler in handlers
    }

    with (
        clientDealer.connect(CLIENT_ADDR),
        remoteMonitor.connect(MONITOR_ADDR),
        remoteRouter.bind(REMOTE_ADDR),
    ):
        poller = zmq.asyncio.Poller()
        poller.register(remoteRouter, zmq.POLLIN)
        poller.register(clientDealer, zmq.POLLIN)
        poller.register(remoteMonitor, zmq.POLLIN)

        while not asyncio.current_task().cancelled():
            events = dict(await poller.poll())

            if remoteRouter in events:
                identity, raw_req = await remoteRouter.recv_multipart()
                method = None
                request_id = None

                try:
                    # Parsed inside the try so one malformed message produces
                    # an error reply instead of killing the server loop.
                    req: RequestPayload = json_loads(raw_req)
                    method = req.get("method")
                    request_id = req.get("id")

                    if fn := method_map.get(method):
                        data = await fn(req.get("data"))
                    else:
                        raise ValueError(f"Unknown method {method}")

                    resp = {
                        "method": method,
                        "id": request_id,
                        "success": True,
                        "data": data,
                    }
                    await remoteRouter.send_multipart([identity, json_dumpb(resp)])

                # ``Exception`` rather than ``BaseException``: task cancellation
                # (``asyncio.CancelledError``) must propagate so this loop can
                # actually be stopped.
                except Exception as exc:
                    await logger.aexception(
                        f"Error in remote method {method}", exc_info=exc
                    )

                    resp = {
                        "method": method,
                        "id": request_id,
                        "success": False,
                        "data": exc,
                    }
                    await remoteRouter.send_multipart([identity, json_dumpb(resp)])

            if clientDealer in events:
                response: ResponsePayload = json_loads(await clientDealer.recv())
                queue = REMOTE_MGS_MAP.get(response["id"])

                # Replies for requests that already finished are dropped.
                if queue:
                    if response.get("success") is None:
                        await queue.put(HeartbeatPing())
                    elif response["success"]:
                        value: Any = response["data"]
                        await queue.put(value)
                    else:
                        error: ErrorData = response["data"]
                        await queue.put(
                            RemoteException(error["error"], error["message"])
                        )

            if remoteMonitor in events:
                msg = await recv_monitor_message(remoteMonitor)
                if msg["event"] == zmq.EVENT_HANDSHAKE_SUCCEEDED:
                    logger.info("JS worker connected")
                elif msg["event"] == zmq.EVENT_DISCONNECTED:
                    logger.info("JS worker disconnected")
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
async def wait_until_js_ready():
    """Wait until the JS worker answers an ``ok`` request.

    Retries every 0.5 s on remote errors, request timeouts, or ZeroMQ errors.
    The last error is re-raised after 241 failed attempts.
    """
    attempt = 0
    while not asyncio.current_task().cancelled():
        try:
            await _client_invoke("ok", {})
            return
        except (
            RemoteException,
            # ``RequestContext.get`` converts timeouts into ``RequestTimeout``,
            # which is not a ``TimeoutError`` subclass. ``asyncio.TimeoutError``
            # differs from the builtin before Python 3.11. Both must be listed.
            RequestTimeout,
            asyncio.TimeoutError,
            TimeoutError,
            zmq.error.ZMQBaseError,
        ):
            if attempt > 240:
                raise
            attempt += 1
            await asyncio.sleep(0.5)
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
async def js_healthcheck():
    """Return True if the JS worker answers an ``ok`` request.

    Raises:
        HTTPException: status 500 when the check fails with a remote error, a
            request timeout, or a ZeroMQ error.
    """
    try:
        await _client_invoke("ok", {})
        return True
    except (
        RemoteException,
        # Timeouts surface as ``RequestTimeout`` (not a ``TimeoutError``
        # subclass); ``asyncio.TimeoutError`` differs from the builtin before
        # Python 3.11.
        RequestTimeout,
        asyncio.TimeoutError,
        TimeoutError,
        zmq.error.ZMQBaseError,
    ) as exc:
        logger.warning(
            "JS healthcheck failed. Either the JS server is not running or the event loop is blocked by a CPU-intensive task.",
            error=exc,
        )
        raise HTTPException(
            status_code=500,
            detail="JS healthcheck failed. Either the JS server is not running or the event loop is blocked by a CPU-intensive task.",
        ) from exc
|