langgraph-api 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langgraph-api might be problematic; consult the registry's advisory page for more details.

@@ -0,0 +1,685 @@
1
+ import asyncio
2
+ import os
3
+ import shutil
4
+ from collections.abc import AsyncIterator, Callable
5
+ from contextlib import AbstractContextManager
6
+ from typing import Any, Literal
7
+
8
+ import orjson
9
+ import structlog
10
+ import zmq
11
+ import zmq.asyncio
12
+ from langchain_core.runnables.config import RunnableConfig
13
+ from langchain_core.runnables.graph import Edge, Node
14
+ from langchain_core.runnables.graph import Graph as DrawableGraph
15
+ from langchain_core.runnables.schema import (
16
+ CustomStreamEvent,
17
+ StandardStreamEvent,
18
+ StreamEvent,
19
+ )
20
+ from langgraph.checkpoint.base.id import uuid6
21
+ from langgraph.checkpoint.serde.base import SerializerProtocol
22
+ from langgraph.pregel.types import PregelTask, StateSnapshot
23
+ from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp
24
+ from langgraph.types import Command, Interrupt
25
+ from pydantic import BaseModel
26
+ from zmq.utils.monitor import recv_monitor_message
27
+
28
+ from langgraph_api.js.base import BaseRemotePregel
29
+ from langgraph_api.js.errors import RemoteException
30
+ from langgraph_api.js.schema import (
31
+ ErrorData,
32
+ RequestPayload,
33
+ ResponsePayload,
34
+ StreamData,
35
+ )
36
+ from langgraph_api.serde import json_dumpb, json_loads
37
+ from langgraph_api.utils import AsyncConnectionProto
38
+
39
logger = structlog.stdlib.get_logger(__name__)

# ZeroMQ endpoints: the Python side DEALER-connects to CLIENT_ADDR to send
# requests to the JS worker, and ROUTER-binds REMOTE_ADDR to serve requests
# coming back from JS (see run_remote_checkpointer below).
CLIENT_ADDR = "tcp://0.0.0.0:5556"
REMOTE_ADDR = "tcp://0.0.0.0:5555"
# In-process endpoint used to observe connect/disconnect events on remoteRouter.
MONITOR_ADDR = "inproc://router.remote"

# How long RequestContext.get() waits for a response (or heartbeat) before
# raising RequestTimeout.
WAIT_FOR_REQUEST_TTL_SECONDS = 15.0

context = zmq.asyncio.Context()
clientDealer = context.socket(zmq.DEALER)  # outgoing requests to the JS worker
remoteRouter = context.socket(zmq.ROUTER)  # incoming requests from the JS worker

# Only surface successful handshakes and disconnects on the monitor socket.
remoteRouter.monitor(
    MONITOR_ADDR,
    zmq.EVENT_HANDSHAKE_SUCCEEDED | zmq.EVENT_DISCONNECTED,
)
remoteMonitor = context.socket(zmq.PAIR)
remoteMonitor.setsockopt(zmq.LINGER, 0)  # discard queued monitor messages on close

# Maps in-flight request ids to the asyncio.Queue their responses are delivered
# on (populated by RequestContext, drained by the run_remote_checkpointer loop).
# NOTE(review): name looks like a typo for REMOTE_MSG_MAP; renaming would touch
# every reference in this module, so it is left as-is here.
REMOTE_MGS_MAP: dict[str, asyncio.Queue] = {}
59
+
60
+
61
class HeartbeatPing:
    """Sentinel placed on a request queue to signal a keep-alive ping.

    RequestContext.get() treats it as "still working" and keeps waiting
    instead of returning it to the caller.
    """

    pass
63
+
64
+
65
class RequestTimeout(Exception):
    """Raised when the JS worker does not answer a request within the TTL."""

    method: str

    def __init__(self, method: str):
        self.method = method

    def __str__(self):
        return 'Request to "{}" timed out'.format(self.method)
73
+
74
+
75
class RequestContext(AbstractContextManager["RequestContext"]):
    """Registers a response queue in REMOTE_MGS_MAP for one request's lifetime.

    On enter, a fresh queue is installed under a unique request id; on exit it
    is removed so late responses are dropped. `get()` awaits the next real
    value from the queue.
    """

    id: str
    method: str

    def __init__(self, method: str):
        self.id = uuid6().hex
        self.method = method

    def __enter__(self) -> "RequestContext":
        REMOTE_MGS_MAP[self.id] = asyncio.Queue()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # pop with default: exit must be safe even if the entry is gone.
        REMOTE_MGS_MAP.pop(self.id, None)

    async def get(self) -> Any:
        """Await the next value for this request.

        Heartbeat pings reset the wait window; a queued Exception is re-raised.

        Raises:
            RequestTimeout: no response or heartbeat arrived within
                WAIT_FOR_REQUEST_TTL_SECONDS.
        """
        # Loop instead of recursing: a long-lived stream can deliver many
        # HeartbeatPing values, and the previous recursive form could hit the
        # interpreter recursion limit.
        while True:
            try:
                value = await asyncio.wait_for(
                    REMOTE_MGS_MAP[self.id].get(),
                    timeout=WAIT_FOR_REQUEST_TTL_SECONDS,
                )
            # Catch asyncio.TimeoutError as well: it is only an alias of the
            # builtin TimeoutError from Python 3.11 on; on 3.10 the builtin
            # alone would not catch asyncio.wait_for's timeout.
            except (TimeoutError, asyncio.TimeoutError) as exc:
                raise RequestTimeout(self.method) from exc

            if isinstance(value, HeartbeatPing):
                continue  # keep-alive; wait for the real payload
            if isinstance(value, Exception):
                raise value
            return value
103
+
104
+
105
async def _client_stream(method: str, data: Any):
    """Send one request to the JS worker and yield streamed response chunks.

    Yields every chunk's value until a chunk marked ``done`` arrives; a done
    chunk that carries a ``value`` is yielded before the stream ends.
    """
    # local name `ctx` avoids shadowing the module-level zmq `context`
    with RequestContext(method) as ctx:
        payload = orjson.dumps({"method": method, "id": ctx.id, "data": data})
        await clientDealer.send(payload)

        while True:
            chunk: StreamData = await ctx.get()
            if not chunk["done"]:
                yield chunk["value"]
                continue
            # final chunk: only yield if it actually carries a value
            if "value" in chunk:
                yield chunk["value"]
            return
120
+
121
+
122
async def _client_invoke(method: str, data: Any):
    """Send a single request to the JS worker and await its one response."""
    with RequestContext(method) as ctx:
        payload = orjson.dumps({"method": method, "id": ctx.id, "data": data})
        await clientDealer.send(payload)
        return await ctx.get()
129
+
130
+
131
class RemotePregel(BaseRemotePregel):
    """Proxy that forwards Pregel graph operations to the JS worker over ZeroMQ.

    Each method serialises its arguments into a request sent through
    ``_client_invoke`` / ``_client_stream`` and converts the JSON response
    back into LangGraph types where needed.
    """

    @staticmethod
    def load(graph_id: str):
        """Create a RemotePregel bound to the given graph id."""
        model = RemotePregel()
        model.graph_id = graph_id
        return model

    async def astream_events(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        *,
        version: Literal["v1", "v2"],
        **kwargs: Any,
    ) -> AsyncIterator[StreamEvent]:
        """Stream v2 events from the remote graph.

        Raises:
            ValueError: if a version other than "v2" is requested.
        """
        if version != "v2":
            raise ValueError("Only v2 of astream_events is supported")

        data = {
            "graph_id": self.graph_id,
            # A Command input travels under the "command" key, anything else
            # under "input" — the JS side dispatches on which key is present.
            "command" if isinstance(input, Command) else "input": input,
            "config": config,
            **kwargs,
        }

        async for event in _client_stream("streamEvents", data):
            if event["event"] == "on_custom_event":
                yield CustomStreamEvent(**event)
            else:
                yield StandardStreamEvent(**event)

    async def fetch_state_schema(self):
        """Fetch the graph's state schema from the JS worker."""
        return await _client_invoke("getSchema", {"graph_id": self.graph_id})

    async def fetch_graph(
        self,
        config: RunnableConfig | None = None,
        *,
        xray: int | bool = False,
    ) -> DrawableGraph:
        """Fetch the drawable graph structure (nodes and edges) from JS."""
        response = await _client_invoke(
            "getGraph", {"graph_id": self.graph_id, "config": config, "xray": xray}
        )

        nodes: list[Any] = response.pop("nodes")
        edges: list[Any] = response.pop("edges")

        # Placeholder schema: node payload schemas are not reconstructed here.
        class NoopModel(BaseModel):
            pass

        return DrawableGraph(
            {
                data["id"]: Node(
                    data["id"], data["id"], NoopModel(), data.get("metadata")
                )
                for data in nodes
            },
            {
                Edge(
                    data["source"],
                    data["target"],
                    data.get("data"),
                    data.get("conditional", False),
                )
                for data in edges
            },
        )

    async def fetch_subgraphs(
        self, *, namespace: str | None = None, recurse: bool = False
    ) -> dict[str, dict]:
        """Fetch subgraph schemas, optionally filtered by namespace."""
        return await _client_invoke(
            "getSubgraphs",
            {"graph_id": self.graph_id, "namespace": namespace, "recurse": recurse},
        )

    def _convert_state_snapshot(self, item: dict) -> StateSnapshot:
        """Convert a raw JSON snapshot dict from JS into a StateSnapshot.

        Nested task states that look like snapshots (dicts with a "config"
        key) are converted recursively.
        """

        def _convert_tasks(tasks: list[dict]) -> tuple[PregelTask, ...]:
            result: list[PregelTask] = []
            for task in tasks:
                state = task.get("state")

                # A dict with a "config" key is itself a serialized snapshot.
                if state and isinstance(state, dict) and "config" in state:
                    state = self._convert_state_snapshot(state)

                result.append(
                    PregelTask(
                        task["id"],
                        task["name"],
                        tuple(task["path"]) if task.get("path") else tuple(),
                        # TODO: figure out how to properly deserialise errors
                        task.get("error"),
                        (
                            tuple(
                                Interrupt(
                                    value=interrupt["value"],
                                    when=interrupt["when"],
                                    resumable=interrupt.get("resumable", True),
                                    ns=interrupt.get("ns"),
                                )
                                for interrupt in task.get("interrupts")
                            )
                            if task.get("interrupts")
                            else []
                        ),
                        state,
                    )
                )
            return tuple(result)

        # Positional StateSnapshot construction — field order matters here.
        return StateSnapshot(
            item.get("values"),
            item.get("next"),
            item.get("config"),
            item.get("metadata"),
            item.get("createdAt"),
            item.get("parentConfig"),
            _convert_tasks(item.get("tasks", [])),
        )

    async def aget_state(
        self, config: RunnableConfig, *, subgraphs: bool = False
    ) -> StateSnapshot:
        """Fetch the current state snapshot for the given config."""
        return self._convert_state_snapshot(
            await _client_invoke(
                "getState",
                {"graph_id": self.graph_id, "config": config, "subgraphs": subgraphs},
            )
        )

    async def aupdate_state(
        self,
        config: RunnableConfig,
        values: dict[str, Any] | Any,
        as_node: str | None = None,
    ) -> RunnableConfig:
        """Apply a state update remotely; returns the resulting config."""
        response = await _client_invoke(
            "updateState",
            {
                "graph_id": self.graph_id,
                "config": config,
                "values": values,
                "as_node": as_node,
            },
        )
        return RunnableConfig(**response)

    async def aget_state_history(
        self,
        config: RunnableConfig,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> AsyncIterator[StateSnapshot]:
        """Stream historical state snapshots, most recent first (per JS side)."""
        async for event in _client_stream(
            "getStateHistory",
            {
                "graph_id": self.graph_id,
                "config": config,
                "limit": limit,
                "filter": filter,
                "before": before,
            },
        ):
            yield self._convert_state_snapshot(event)

    # The synchronous / generic Runnable surface below is intentionally
    # unsupported on the remote proxy — use the async fetch_* methods.
    def get_graph(
        self,
        config: RunnableConfig | None = None,
        *,
        xray: int | bool = False,
    ) -> dict[str, Any]:
        raise Exception("Not implemented")

    def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
        raise Exception("Not implemented")

    def get_output_schema(
        self, config: RunnableConfig | None = None
    ) -> type[BaseModel]:
        raise Exception("Not implemented")

    def config_schema(self) -> type[BaseModel]:
        raise Exception("Not implemented")

    async def invoke(self, input: Any, config: RunnableConfig | None = None):
        raise Exception("Not implemented")
319
+
320
+
321
async def run_js_process(paths_str: str, watch: bool = False):
    """Run the JS graphs subprocess via tsx, retrying up to 3 times on failure.

    Args:
        paths_str: serialized graph paths, exported as LANGSERVE_GRAPHS.
        watch: run tsx in watch mode (hot reload) when True.

    Raises:
        FileNotFoundError: if the `tsx` executable is not on PATH.
        Exception: if the subprocess keeps exiting after all retries.
    """
    # check if tsx is available
    tsx_path = shutil.which("tsx")
    if tsx_path is None:
        raise FileNotFoundError("tsx not found in PATH")
    attempt = 0
    while not asyncio.current_task().cancelled():
        client_file = os.path.join(os.path.dirname(__file__), "client.new.mts")
        args = ("tsx", client_file)
        if watch:
            args = ("tsx", "watch", client_file, "--skip-schema-cache")
        try:
            process = await asyncio.create_subprocess_exec(
                *args,
                env={
                    # Inherit the parent environment, but spread it FIRST so
                    # the explicit settings below always win. Previously
                    # **os.environ came last, silently overriding
                    # LANGSERVE_GRAPHS (and friends) whenever they were
                    # already present in the environment.
                    **os.environ,
                    "LANGSERVE_GRAPHS": paths_str,
                    "LANGCHAIN_CALLBACKS_BACKGROUND": "true",
                    "CHOKIDAR_USEPOLLING": "true",
                },
            )
            # The worker is expected to run forever; any exit is a failure.
            code = await process.wait()
            raise Exception(f"JS process exited with code {code}")
        except asyncio.CancelledError:
            logger.info("Terminating JS graphs process")
            try:
                # `process` may be unbound if spawning itself failed.
                process.terminate()
                await process.wait()
            except (UnboundLocalError, ProcessLookupError):
                pass
            raise
        except Exception:
            if attempt >= 3:
                raise
            else:
                logger.warning(f"Retrying JS process {3 - attempt} more times...")
                attempt += 1
358
+
359
+
360
def _get_passthrough_checkpointer(conn: AsyncConnectionProto):
    """Build a Checkpointer whose serde round-trips plain JSON.

    The returned checkpointer does not attempt to revive LangChain objects;
    values are passed through as raw JSON structures.
    """
    from langgraph_storage.checkpoint import Checkpointer

    class PassthroughSerialiser(SerializerProtocol):
        """Serializer that keeps payloads as plain JSON, never LC objects."""

        def dumps(self, obj: Any) -> bytes:
            return json_dumpb(obj)

        def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
            return "json", json_dumpb(obj)

        def loads(self, data: bytes) -> Any:
            return orjson.loads(data)

        def loads_typed(self, data: tuple[str, bytes]) -> Any:
            # renamed from `type` to avoid shadowing the builtin
            kind, payload = data
            if kind != "json":
                raise ValueError(f"Unsupported type {kind}")
            return orjson.loads(payload)

    checkpointer = Checkpointer(conn)
    checkpointer.serde = PassthroughSerialiser()
    return checkpointer
386
+
387
+
388
def _get_passthrough_store():
    """Return a langgraph_storage Store instance.

    NOTE(review): imported locally, presumably to defer the langgraph_storage
    dependency until first use — confirm against the package layout.
    """
    from langgraph_storage.store import Store

    return Store()
392
+
393
+
394
async def run_remote_checkpointer():
    """Serve checkpointer/store requests from the JS worker over ZeroMQ.

    Multiplexes three sockets with a poller until the task is cancelled:
    - remoteRouter: requests initiated by the JS side, dispatched to the
      nested handler whose name matches the request's "method";
    - clientDealer: responses to our own outgoing requests, routed to the
      per-request queue in REMOTE_MGS_MAP;
    - remoteMonitor: connect/disconnect events, logged only.
    """
    from langgraph_storage.database import connect

    async def checkpointer_list(payload: dict):
        """Search checkpoints"""

        result = []
        async with connect() as conn:
            checkpointer = _get_passthrough_checkpointer(conn)
            async for item in checkpointer.alist(
                config=payload.get("config"),
                limit=payload.get("limit"),
                before=payload.get("before"),
                filter=payload.get("filter"),
            ):
                result.append(item)

        return result

    async def checkpointer_put(payload: dict):
        """Put the new checkpoint metadata"""

        async with connect() as conn:
            checkpointer = _get_passthrough_checkpointer(conn)
            return await checkpointer.aput(
                payload["config"],
                payload["checkpoint"],
                payload["metadata"],
                payload.get("new_versions", {}),
            )

    async def checkpointer_get_tuple(payload: dict):
        """Get actual checkpoint values (reads)"""

        async with connect() as conn:
            checkpointer = _get_passthrough_checkpointer(conn)
            return await checkpointer.aget_tuple(config=payload["config"])

    async def checkpointer_put_writes(payload: dict):
        """Put actual checkpoint values (writes)"""

        async with connect() as conn:
            checkpointer = _get_passthrough_checkpointer(conn)
            return await checkpointer.aput_writes(
                payload["config"],
                payload["writes"],
                payload["taskId"],
            )

    async def store_batch(payload: dict):
        """Batch operations on the store"""
        operations = payload.get("operations", [])

        if not operations:
            raise ValueError("No operations provided")

        # Convert raw operations to proper objects.
        # Dispatch is by key presence: "value" -> PutOp, "namespace_prefix"
        # -> SearchOp, "namespace"+"key" -> GetOp, "match_conditions" ->
        # ListNamespacesOp. Order matters: a PutOp also has namespace/key.
        processed_operations = []
        for op in operations:
            if "value" in op:
                processed_operations.append(
                    PutOp(
                        namespace=tuple(op["namespace"]),
                        key=op["key"],
                        value=op["value"],
                    )
                )
            elif "namespace_prefix" in op:
                processed_operations.append(
                    SearchOp(
                        namespace_prefix=tuple(op["namespace_prefix"]),
                        filter=op.get("filter"),
                        limit=op.get("limit", 10),
                        offset=op.get("offset", 0),
                    )
                )

            elif "namespace" in op and "key" in op:
                processed_operations.append(
                    GetOp(namespace=tuple(op["namespace"]), key=op["key"])
                )
            elif "match_conditions" in op:
                processed_operations.append(
                    ListNamespacesOp(
                        match_conditions=tuple(op["match_conditions"]),
                        max_depth=op.get("max_depth"),
                        limit=op.get("limit", 100),
                        offset=op.get("offset", 0),
                    )
                )
            else:
                raise ValueError(f"Unknown operation type: {op}")

        store = _get_passthrough_store()
        results = await store.abatch(processed_operations)

        # Handle potentially undefined or non-dict results
        processed_results = []
        # Result is of type: Union[Item, list[Item], list[tuple[str, ...]], None]
        for result in results:
            if isinstance(result, Item):
                processed_results.append(result.dict())
            elif isinstance(result, dict):
                processed_results.append(result)
            elif isinstance(result, list):
                coerced = []
                for res in result:
                    if isinstance(res, Item):
                        coerced.append(res.dict())
                    elif isinstance(res, tuple):
                        coerced.append(list(res))
                    elif res is None:
                        coerced.append(res)
                    else:
                        coerced.append(str(res))
                processed_results.append(coerced)
            elif result is None:
                processed_results.append(None)
            else:
                processed_results.append(str(result))
        return processed_results

    async def store_get(payload: dict):
        """Get store data"""
        namespaces_str = payload.get("namespaces")
        key = payload.get("key")

        if not namespaces_str or not key:
            raise ValueError("Both namespaces and key are required")

        # Dotted string -> namespace path segments.
        namespaces = namespaces_str.split(".")

        store = _get_passthrough_store()
        result = await store.aget(namespaces, key)

        return result

    async def store_put(payload: dict):
        """Put the new store data"""

        namespace = tuple(payload["namespace"].split("."))
        key = payload["key"]
        value = payload["value"]
        index = payload.get("index")

        store = _get_passthrough_store()
        await store.aput(namespace, key, value, index=index)

        return {"success": True}

    async def store_search(payload: dict):
        """Search stores"""
        namespace_prefix = tuple(payload["namespace_prefix"])
        filter = payload.get("filter")
        limit = payload.get("limit", 10)
        offset = payload.get("offset", 0)
        query = payload.get("query")

        store = _get_passthrough_store()
        result = await store.asearch(
            namespace_prefix, filter=filter, limit=limit, offset=offset, query=query
        )

        return [item.dict() for item in result]

    async def store_delete(payload: dict):
        """Delete store data"""

        namespace = tuple(payload["namespace"])
        key = payload["key"]

        store = _get_passthrough_store()
        await store.adelete(namespace, key)

        return {"success": True}

    async def store_list_namespaces(payload: dict):
        """List all namespaces"""
        prefix = tuple(payload.get("prefix", [])) or None
        suffix = tuple(payload.get("suffix", [])) or None
        max_depth = payload.get("max_depth")
        limit = payload.get("limit", 100)
        offset = payload.get("offset", 0)

        store = _get_passthrough_store()
        result = await store.alist_namespaces(
            prefix=prefix,
            suffix=suffix,
            max_depth=max_depth,
            limit=limit,
            offset=offset,
        )

        return [list(ns) for ns in result]

    methods = {
        checkpointer_get_tuple,
        checkpointer_list,
        checkpointer_put,
        checkpointer_put_writes,
        store_get,
        store_put,
        store_delete,
        store_search,
        store_batch,
        store_list_namespaces,
    }

    # RPC dispatch table: request "method" string -> handler function.
    method_map: dict[str, Callable[[dict], Any]] = {
        method.__name__: method for method in methods
    }

    with (
        clientDealer.connect(CLIENT_ADDR),
        remoteMonitor.connect(MONITOR_ADDR),
        remoteRouter.bind(REMOTE_ADDR),
    ):
        poller = zmq.asyncio.Poller()
        poller.register(remoteRouter, zmq.POLLIN)
        poller.register(clientDealer, zmq.POLLIN)
        poller.register(remoteMonitor, zmq.POLLIN)

        while not asyncio.current_task().cancelled():
            events = dict(await poller.poll())

            if remoteRouter in events:
                # ROUTER framing: first frame is the peer identity, second
                # is the request body.
                identity, raw_req = await remoteRouter.recv_multipart()
                req: RequestPayload = json_loads(raw_req)

                method = req.get("method")
                # shadows the builtin `id`; kept as-is (doc-only change)
                id = req.get("id")

                try:
                    if fn := method_map.get(method):
                        data = await fn(req.get("data"))
                    else:
                        raise ValueError(f"Unknown method {method}")

                    resp = {"method": method, "id": id, "success": True, "data": data}
                    await remoteRouter.send_multipart([identity, json_dumpb(resp)])

                except BaseException as exc:
                    # NOTE(review): BaseException is caught so every failure is
                    # reported back to the JS peer — confirm cancellation should
                    # also be swallowed here rather than re-raised.
                    await logger.aexception(
                        f"Error in remote method {method}", exc_info=exc
                    )

                    resp = {"method": method, "id": id, "success": False, "data": exc}
                    await remoteRouter.send_multipart([identity, json_dumpb(resp)])

            if clientDealer in events:
                response: ResponsePayload = json_loads(await clientDealer.recv())
                queue = REMOTE_MGS_MAP.get(response["id"])

                # Only deliver if the request is still in flight; otherwise
                # the response is silently dropped.
                if queue:
                    # A missing "success" field is the heartbeat signal.
                    if response.get("success") is None:
                        await queue.put(HeartbeatPing())

                    elif response["success"]:
                        data: Any = response["data"]
                        await queue.put(data)
                    else:
                        data: ErrorData = response["data"]
                        await queue.put(RemoteException(data["error"], data["message"]))

            if remoteMonitor in events:
                msg = await recv_monitor_message(remoteMonitor)
                if msg["event"] == zmq.EVENT_HANDSHAKE_SUCCEEDED:
                    logger.info("JS worker connected")
                elif msg["event"] == zmq.EVENT_DISCONNECTED:
                    logger.info("JS worker disconnected")
664
+
665
+
666
async def wait_until_js_ready():
    """Poll the JS worker's "ok" endpoint until it answers (~2 min budget).

    Raises the last error once the retry budget (240 attempts at 0.5s)
    is exhausted.
    """
    attempt = 0
    while not asyncio.current_task().cancelled():
        try:
            await _client_invoke("ok", {})
            return
        except (RemoteException, RequestTimeout, TimeoutError, zmq.error.ZMQBaseError):
            # RequestTimeout is listed explicitly: it subclasses Exception,
            # not TimeoutError, so previously a timed-out ping escaped this
            # handler and aborted startup while the JS worker was still
            # booting instead of being retried.
            if attempt > 240:
                raise
            else:
                attempt += 1
                await asyncio.sleep(0.5)
678
+
679
+
680
async def js_healthcheck():
    """Return True if the JS worker answers an "ok" request, else False."""
    try:
        await _client_invoke("ok", {})
        return True
    except (RemoteException, RequestTimeout, TimeoutError, zmq.error.ZMQBaseError):
        # RequestTimeout does not derive from TimeoutError (it subclasses
        # Exception), so it must be listed explicitly — otherwise a timed-out
        # ping raised out of the healthcheck instead of reporting unhealthy.
        return False