langgraph-executor 0.0.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_executor/__init__.py +1 -0
- langgraph_executor/common.py +395 -0
- langgraph_executor/example.py +29 -0
- langgraph_executor/execute_task.py +239 -0
- langgraph_executor/executor.py +341 -0
- langgraph_executor/extract_graph.py +178 -0
- langgraph_executor/info_logger.py +111 -0
- langgraph_executor/pb/__init__.py +0 -0
- langgraph_executor/pb/executor_pb2.py +79 -0
- langgraph_executor/pb/executor_pb2.pyi +415 -0
- langgraph_executor/pb/executor_pb2_grpc.py +321 -0
- langgraph_executor/pb/executor_pb2_grpc.pyi +150 -0
- langgraph_executor/pb/graph_pb2.py +55 -0
- langgraph_executor/pb/graph_pb2.pyi +230 -0
- langgraph_executor/pb/graph_pb2_grpc.py +24 -0
- langgraph_executor/pb/graph_pb2_grpc.pyi +17 -0
- langgraph_executor/pb/runtime_pb2.py +68 -0
- langgraph_executor/pb/runtime_pb2.pyi +364 -0
- langgraph_executor/pb/runtime_pb2_grpc.py +322 -0
- langgraph_executor/pb/runtime_pb2_grpc.pyi +151 -0
- langgraph_executor/pb/types_pb2.py +144 -0
- langgraph_executor/pb/types_pb2.pyi +1044 -0
- langgraph_executor/pb/types_pb2_grpc.py +24 -0
- langgraph_executor/pb/types_pb2_grpc.pyi +17 -0
- langgraph_executor/py.typed +0 -0
- langgraph_executor/server.py +186 -0
- langgraph_executor/setup.sh +29 -0
- langgraph_executor/stream_utils.py +96 -0
- langgraph_executor-0.0.1a0.dist-info/METADATA +14 -0
- langgraph_executor-0.0.1a0.dist-info/RECORD +31 -0
- langgraph_executor-0.0.1a0.dist-info/WHEEL +4 -0
@@ -0,0 +1 @@
|
|
1
|
+
# Version string of the langgraph_executor package (PEP 440 alpha pre-release).
__version__ = "0.0.1a0"
|
@@ -0,0 +1,395 @@
|
|
1
|
+
import traceback
|
2
|
+
from collections.abc import Mapping, Sequence
|
3
|
+
from collections.abc import Sequence as SequenceType
|
4
|
+
from typing import Any, cast
|
5
|
+
|
6
|
+
from google.protobuf.json_format import MessageToDict
|
7
|
+
from langchain_core.runnables import RunnableConfig
|
8
|
+
from langgraph._internal._constants import (
|
9
|
+
CONFIG_KEY_CHECKPOINT_ID,
|
10
|
+
CONFIG_KEY_CHECKPOINT_MAP,
|
11
|
+
CONFIG_KEY_CHECKPOINT_NS,
|
12
|
+
CONFIG_KEY_DURABILITY,
|
13
|
+
CONFIG_KEY_RESUME_MAP,
|
14
|
+
CONFIG_KEY_RESUMING,
|
15
|
+
CONFIG_KEY_TASK_ID,
|
16
|
+
CONFIG_KEY_THREAD_ID,
|
17
|
+
TASKS,
|
18
|
+
)
|
19
|
+
from langgraph._internal._scratchpad import PregelScratchpad
|
20
|
+
from langgraph._internal._typing import MISSING
|
21
|
+
from langgraph.channels.base import BaseChannel, EmptyChannelError
|
22
|
+
from langgraph.checkpoint.base import Checkpoint, PendingWrite
|
23
|
+
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
24
|
+
from langgraph.errors import GraphBubbleUp, GraphInterrupt
|
25
|
+
from langgraph.managed.base import ManagedValue, ManagedValueMapping
|
26
|
+
from langgraph.pregel import Pregel
|
27
|
+
from langgraph.pregel._algo import PregelTaskWrites
|
28
|
+
from langgraph.pregel._read import PregelNode
|
29
|
+
from langgraph.types import Command, Interrupt, Send
|
30
|
+
|
31
|
+
from langgraph_executor.pb import types_pb2
|
32
|
+
|
33
|
+
|
34
|
+
def map_reserved_configurable(
    reserved_configurable: types_pb2.ReservedConfigurable,
) -> dict[str, Any]:
    """Translate a protobuf ``ReservedConfigurable`` into langgraph config keys.

    Every reserved field is copied under its ``CONFIG_KEY_*`` constant. The
    resume map is decoded entry-by-entry with ``pb_to_val``; an empty resume
    map is represented as ``None``.

    Args:
        reserved_configurable: The reserved part of the protobuf config.

    Returns:
        A dict of reserved configurable keys to their decoded values.
    """
    rc = reserved_configurable
    decoded_resume_map = (
        {key: pb_to_val(val) for key, val in rc.resume_map.items()}
        if rc.resume_map
        else None
    )
    # TODO: CONFIG_KEY_PREVIOUS is not reconstructed yet; it would need
    # JsonPlusSerializer().loads_typed((previous.method, previous.value))
    # whenever previous.method != "missing".
    return {
        CONFIG_KEY_RESUMING: rc.resuming,
        CONFIG_KEY_TASK_ID: rc.task_id,
        CONFIG_KEY_THREAD_ID: rc.thread_id,
        CONFIG_KEY_CHECKPOINT_MAP: rc.checkpoint_map,
        CONFIG_KEY_CHECKPOINT_ID: rc.checkpoint_id,
        CONFIG_KEY_CHECKPOINT_NS: rc.checkpoint_ns,
        CONFIG_KEY_RESUME_MAP: decoded_resume_map,
        CONFIG_KEY_DURABILITY: rc.durability,
    }
|
62
|
+
|
63
|
+
|
64
|
+
def reconstruct_config(pb_config: types_pb2.RunnableConfig) -> RunnableConfig:
    """Rebuild a langchain ``RunnableConfig`` from its protobuf representation.

    User-supplied ``configurable`` entries are decoded with ``MessageToDict``.
    Reserved entries are then merged in:
    - a truthy reserved value always overrides the user's value;
    - a falsy reserved value is only used when the user did not set that key.
    """
    configurable = MessageToDict(pb_config.configurable)
    reserved = map_reserved_configurable(pb_config.reserved_configurable)
    for key, value in reserved.items():
        # Keep the user's value when the reserved one is empty/falsy.
        if not value and key in configurable:
            continue
        configurable[key] = value

    return RunnableConfig(
        tags=list(pb_config.tags),
        metadata=MessageToDict(pb_config.metadata),
        run_name=pb_config.run_name,
        run_id=pb_config.run_id,
        max_concurrency=pb_config.max_concurrency,
        recursion_limit=pb_config.recursion_limit,
        configurable=configurable,
    )
|
78
|
+
|
79
|
+
|
80
|
+
def revive_channel(channel: BaseChannel, channel_pb: types_pb2.Channel) -> BaseChannel:
    """Return a copy of ``channel`` restored from the serialized checkpoint in ``channel_pb``.

    The original ``channel`` object is not modified; ``from_checkpoint`` is
    called on a ``copy()`` of it.
    """
    checkpoint_value = pb_to_val(channel_pb.checkpoint_result)
    fresh_channel = channel.copy()
    return fresh_channel.from_checkpoint(checkpoint_value)
|
85
|
+
|
86
|
+
|
87
|
+
def reconstruct_channels(
    channels_pb: dict[str, types_pb2.Channel],
    graph: Pregel,
    scratchpad: PregelScratchpad,
) -> tuple[dict[str, BaseChannel], ManagedValueMapping]:
    """Rebuild a graph's channels and managed values for one task execution.

    Args:
        channels_pb: Serialized channel state keyed by channel name.
        graph: Graph whose ``channels`` mapping describes what to rebuild.
        scratchpad: Scratchpad handed to each managed value's ``get``.

    Returns:
        ``(channels, managed)``:
        - ``channels`` holds revived copies of every ``BaseChannel``;
        - ``managed`` holds the value produced by each managed-value spec.

    Raises:
        ValueError: A channel declared by the graph has no serialized state
            in ``channels_pb``.
        NotImplementedError: An entry of ``graph.channels`` is neither a
            channel nor a managed value, or a managed value was found while
            ``scratchpad`` is None.
    """
    channels: dict[str, BaseChannel] = {}
    managed: dict[str, Any] = {}
    for name, spec in graph.channels.items():
        if isinstance(spec, BaseChannel):
            # Explicit check instead of `assert`, which is stripped under -O.
            if name not in channels_pb:
                raise ValueError(f"Missing serialized state for channel {name!r}")
            channels[name] = revive_channel(spec, channels_pb[name])
            continue

        # Graphs store managed values as *classes* (type[ManagedValue]) whose
        # `get` is called with the scratchpad, so an isinstance() check alone
        # never matches them. Instances are still accepted for compatibility.
        is_managed = isinstance(spec, ManagedValue) or (
            isinstance(spec, type) and issubclass(spec, ManagedValue)
        )
        if is_managed and scratchpad is not None:
            managed[name] = spec.get(scratchpad)
        else:
            raise NotImplementedError(f"Unrecognized channel value: {type(spec)}")

    return channels, managed
|
104
|
+
|
105
|
+
|
106
|
+
def reconstruct_checkpoint(request_checkpoint: types_pb2.Checkpoint) -> Checkpoint:
    """Convert a protobuf ``Checkpoint`` into langgraph's ``Checkpoint`` mapping.

    Only ``v``, ``id``, ``ts``, ``channel_versions`` and ``versions_seen`` are
    populated.

    TODO: channel values are not reconstructed here.
    """
    pb = request_checkpoint
    seen_by_node = {}
    for node, seen in pb.versions_seen.items():
        seen_by_node[node] = dict(seen.channel_versions)

    return Checkpoint(
        v=pb.v,
        id=pb.id,
        channel_versions=dict(pb.channel_versions),
        versions_seen=seen_by_node,
        ts=pb.ts,
    )
|
121
|
+
|
122
|
+
|
123
|
+
def reconstruct_task_writes(
    request_tasks: SequenceType[Any],
) -> SequenceType[PregelTaskWrites]:
    """Decode protobuf task-write records into ``PregelTaskWrites`` objects.

    Each record contributes:
    - its task path (as a tuple);
    - its name;
    - its ``(channel, value)`` writes, with values decoded via ``pb_to_val``;
    - its triggers.
    """
    result: list[PregelTaskWrites] = []
    for task in request_tasks:
        path = tuple(task.task_path)
        decoded_writes = [(write.channel, pb_to_val(write.value)) for write in task.writes]
        result.append(PregelTaskWrites(path, task.name, decoded_writes, task.triggers))
    return result
|
136
|
+
|
137
|
+
|
138
|
+
def checkpoint_to_proto(checkpoint: Checkpoint) -> types_pb2.Checkpoint:
    """Encode a checkpoint's version bookkeeping into a protobuf ``Checkpoint``.

    Only ``channel_versions`` and ``versions_seen`` are copied. Other fields
    (``v``, ``id``, ``ts``, channel values) are left at their protobuf defaults.
    """
    proto = types_pb2.Checkpoint()
    proto.channel_versions.update(checkpoint["channel_versions"])
    seen_by_node = checkpoint["versions_seen"]
    for node_name in seen_by_node:
        proto.versions_seen[node_name].channel_versions.update(seen_by_node[node_name])
    return proto
|
145
|
+
|
146
|
+
|
147
|
+
def updates_to_proto(
    checkpoint_proto: types_pb2.Checkpoint,
    updated_channel_names: Sequence[str],
    channels: types_pb2.Channels,
) -> types_pb2.Updates:
    """Bundle a checkpoint, the names of updated channels and channel state into ``Updates``."""
    updates = types_pb2.Updates(
        checkpoint=checkpoint_proto,
        updated_channels=updated_channel_names,
        channels=channels,
    )
    return updates
|
157
|
+
|
158
|
+
|
159
|
+
def get_graph(
    graph_name: str,
    graphs: dict[str, Pregel],
) -> Pregel:
    """Look up a registered graph by name.

    Raises:
        ValueError: If ``graph_name`` is not a key of ``graphs``.
    """
    if graph_name in graphs:
        return graphs[graph_name]
    raise ValueError(f"Graph {graph_name} not supported")
|
166
|
+
|
167
|
+
|
168
|
+
def get_node(node_name: str, graph: Pregel, graph_name: str) -> PregelNode:
    """Look up a node of ``graph`` by name.

    ``graph_name`` is only used in the error message.

    Raises:
        ValueError: If the graph has no node called ``node_name``.
    """
    nodes = graph.nodes
    if node_name in nodes:
        return nodes[node_name]
    raise ValueError(f"Node {node_name} not found in graph {graph_name}")
|
172
|
+
|
173
|
+
|
174
|
+
def pb_to_val(value: types_pb2.Value) -> Any:
    """Decode a protobuf ``Value`` into the Python object it represents.

    Supported kinds of the ``message`` oneof:
    - ``base_value``: a JsonPlusSerializer ``(method, bytes)`` payload.
    - ``sends``: a list of ``Send`` objects, with args decoded recursively.
    - ``missing``: the ``MISSING`` sentinel.
    - ``command``: a ``Command`` (see ``_pb_to_command``).

    Raises:
        NotImplementedError: For any other, or an unset, value kind.
    """
    value_kind = value.WhichOneof("message")
    if value_kind == "base_value":
        serde = JsonPlusSerializer()
        return serde.loads_typed((value.base_value.method, value.base_value.value))
    if value_kind == "sends":
        return [Send(send.node, pb_to_val(send.arg)) for send in value.sends.sends]
    if value_kind == "missing":
        return MISSING
    if value_kind == "command":
        return _pb_to_command(value.command)
    raise NotImplementedError(f"Unrecognized value kind: {value_kind}")


def _pb_to_command(cmd_pb: types_pb2.Command) -> Command:
    """Decode a protobuf ``Command``; unset fields map back to ``Command`` defaults."""
    # Proto3 strings default to "" when unset; the Command default is None.
    graph = cmd_pb.graph or None

    # ``update`` is a protobuf map container: it is never None and never a
    # ``dict`` instance, so inspect its contents directly. Non-dict updates
    # are encoded under the single "__root__" key and must be unwrapped.
    update = None
    update_pb = cmd_pb.update
    if len(update_pb) == 1 and "__root__" in update_pb:
        update = pb_to_val(update_pb["__root__"])
    elif len(update_pb) > 0:
        update = {k: pb_to_val(v) for k, v in update_pb.items()}

    # The truthiness of a submessage does not tell whether it was set, so
    # dispatch on the oneof member instead. An unset resume stays None
    # rather than turning into an empty resume map.
    resume = None
    which_resume = cmd_pb.resume.WhichOneof("message")
    if which_resume == "value":
        resume = pb_to_val(cmd_pb.resume.value)
    elif which_resume == "values":
        resume = {k: pb_to_val(v) for k, v in cmd_pb.resume.values.values.items()}

    goto: Any = ()
    if cmd_pb.gotos:
        gotos = []
        for g in cmd_pb.gotos:
            if g.WhichOneof("message") == "node_name":
                gotos.append(g.node_name.name)
            else:
                gotos.append(Send(g.send.node, pb_to_val(g.send.arg)))
        # A single target is passed unwrapped, matching how it was encoded.
        goto = gotos[0] if len(gotos) == 1 else gotos

    return Command(graph=graph, update=update, resume=resume, goto=goto)
|
225
|
+
|
226
|
+
|
227
|
+
def send_to_pb(send: Send) -> types_pb2.Send:
    """Encode a ``Send`` as a protobuf ``Send``.

    A nested ``Send`` argument is encoded as if written to the ``TASKS``
    channel (i.e. as a ``sends`` value); any other argument is encoded as a
    plain value.
    """
    arg_channel = TASKS if isinstance(send.arg, Send) else None
    encoded_arg = val_to_pb(arg_channel, send.arg)
    return types_pb2.Send(node=send.node, arg=encoded_arg)
|
232
|
+
|
233
|
+
|
234
|
+
def sends_to_pb(sends: list[Send]) -> types_pb2.Value:
    """Encode a list of ``Send`` objects as a ``sends`` Value.

    An empty (falsy) list is encoded as the ``missing`` marker instead.
    """
    if sends:
        encoded = [send_to_pb(item) for item in sends]
        return types_pb2.Value(sends=types_pb2.Sends(sends=encoded))
    return missing_to_pb()
|
242
|
+
|
243
|
+
|
244
|
+
def command_to_pb(cmd: Command) -> types_pb2.Value:
    """Encode a langgraph ``Command`` as a protobuf ``Value``.

    Fields are encoded as follows:
    - ``graph`` must be unset or ``Command.PARENT``.
    - A dict ``update`` is encoded key-by-key. Any other truthy update is
      stored under the single key ``"__root__"``.
    - ``resume`` is encoded whenever it is not None: a dict becomes a resume
      map, anything else a single resume value.
    - ``goto`` may be a single ``Send``/node name or a sequence of them.

    Raises:
        ValueError: If ``cmd.graph`` is set to something other than
            ``Command.PARENT``, or a goto target is neither a ``Send`` nor a
            node name.
    """
    cmd_pb = types_pb2.Command()
    if cmd.graph:
        if cmd.graph != Command.PARENT:
            raise ValueError("command graph must be null or parent")
        cmd_pb.graph = cmd.graph

    if cmd.update:
        if isinstance(cmd.update, dict):
            cmd_pb.update.update({k: val_to_pb(None, v) for k, v in cmd.update.items()})
        else:
            cmd_pb.update.update({"__root__": val_to_pb(None, cmd.update)})

    # Compare against None rather than truthiness: falsy answers such as
    # False, 0 or "" are legitimate resume values and must not be dropped.
    if cmd.resume is not None:
        if isinstance(cmd.resume, dict):
            cmd_pb.resume.CopyFrom(resume_map_to_pb(cmd.resume))
        else:
            resume_val = types_pb2.Resume(value=val_to_pb(None, cmd.resume))
            cmd_pb.resume.CopyFrom(resume_val)

    if cmd.goto:
        goto = cmd.goto
        # A single Send or node name (a str is itself a Sequence) is wrapped.
        # Any other sequence (list *or* tuple) is expanded. Non-sequences fall
        # through to goto_to_pb, which raises ValueError for them.
        if isinstance(goto, (Send, str)) or not isinstance(goto, Sequence):
            targets = [goto]
        else:
            targets = list(goto)
        cmd_pb.gotos.extend([goto_to_pb(target) for target in targets])

    return types_pb2.Value(command=cmd_pb)
|
272
|
+
|
273
|
+
|
274
|
+
def resume_map_to_pb(resume: dict[str, Any] | Any) -> types_pb2.Resume:
    """Encode a resume mapping as a ``Resume`` message with ``values`` set.

    Each value is encoded with ``val_to_pb``. The keys are presumably interrupt
    ids (TODO confirm against callers).
    """
    encoded: dict[str, types_pb2.Value] = {}
    for key, resume_value in resume.items():
        encoded[key] = val_to_pb(None, resume_value)
    interrupt_values = types_pb2.InterruptValues(values=encoded)
    return types_pb2.Resume(values=interrupt_values)
|
277
|
+
|
278
|
+
|
279
|
+
def goto_to_pb(goto: Send | str) -> types_pb2.Goto:
|
280
|
+
if isinstance(goto, Send):
|
281
|
+
return types_pb2.Goto(send=send_to_pb(goto))
|
282
|
+
if isinstance(goto, str):
|
283
|
+
return types_pb2.Goto(node_name=types_pb2.NodeName(name=goto))
|
284
|
+
raise ValueError("goto must be send or node name")
|
285
|
+
|
286
|
+
|
287
|
+
def missing_to_pb() -> types_pb2.Value:
    """Build a ``Value`` whose oneof is set to the (field-less) ``missing`` marker."""
    marker = types_pb2.Value()
    # `missing` carries no data, so mark it present explicitly.
    marker.missing.SetInParent()
    return marker
|
291
|
+
|
292
|
+
|
293
|
+
def base_value_to_pb(value: Any) -> types_pb2.Value:
    """Wrap ``serialize_value(value)`` in a ``Value`` with ``base_value`` set."""
    return types_pb2.Value(base_value=serialize_value(value))
|
297
|
+
|
298
|
+
|
299
|
+
def serialize_value(value: Any) -> types_pb2.SerializedValue:
    """Serialize ``value`` with JsonPlusSerializer into a ``SerializedValue``.

    The serializer's method tag and payload bytes are stored as ``method``
    and ``value`` respectively.
    """
    method, payload = JsonPlusSerializer().dumps_typed(value)
    return types_pb2.SerializedValue(method=method, value=bytes(payload))
|
304
|
+
|
305
|
+
|
306
|
+
def val_to_pb(channel_name: str | None, value: Any) -> types_pb2.Value:
    """Encode a Python value as a protobuf ``Value``.

    Args:
        channel_name: Channel the value belongs to, or ``None``. A non-missing
            value for the ``TASKS`` channel must be a ``Send`` or a list of
            ``Send`` objects; it is encoded as ``sends``.
        value: The value to encode, handled by type:
            - ``MISSING`` becomes the ``missing`` marker;
            - a ``Command`` is encoded structurally;
            - anything else goes through JsonPlusSerializer.

    Raises:
        ValueError: If a ``TASKS`` value is not a ``Send`` or list of ``Send``.
    """
    # MISSING is a sentinel, so compare by identity. Using `==` would invoke
    # the value's own __eq__; for e.g. numpy arrays or DataFrames that returns
    # an array whose truth value is ambiguous and raises.
    is_missing = value is MISSING

    if channel_name == TASKS and not is_missing:
        if not isinstance(value, list):
            if not isinstance(value, Send):
                raise ValueError(
                    "Task must be a Send object."
                    f" Got type={type(value)} value={value}",
                )
            value = [value]
        else:
            for v in value:
                if not isinstance(v, Send):
                    raise ValueError(
                        "Task must be a list of Send objects."
                        f" Got types={[type(v) for v in value]} values={value}",
                    )
        return sends_to_pb(value)

    if is_missing:
        return missing_to_pb()
    if isinstance(value, Command):
        return command_to_pb(value)
    return base_value_to_pb(value)
|
328
|
+
|
329
|
+
|
330
|
+
def extract_channel(name: str, channel: BaseChannel) -> types_pb2.Channel:
    """Serialize a channel's current value, availability and checkpoint.

    An empty channel (``get`` raising ``EmptyChannelError``) is reported as
    ``MISSING``. Both values are encoded with ``val_to_pb`` using the
    channel's ``name``.
    """
    try:
        current = channel.get()
    except EmptyChannelError:
        current = MISSING

    current_pb = val_to_pb(name, current)
    available = channel.is_available()
    checkpoint_pb = val_to_pb(name, channel.checkpoint())
    return types_pb2.Channel(
        get_result=current_pb,
        is_available_result=available,
        checkpoint_result=checkpoint_pb,
    )
|
341
|
+
|
342
|
+
|
343
|
+
def extract_channels(
    channels: Mapping[str, BaseChannel | type[ManagedValue]],
) -> types_pb2.Channels:
    """Serialize every ``BaseChannel`` in ``channels`` into a ``Channels`` message.

    Entries that are not ``BaseChannel`` instances (managed-value specs) are skipped.
    """
    serialized = {
        name: extract_channel(name, channel)
        for name, channel in channels.items()
        if isinstance(channel, BaseChannel)
    }
    return types_pb2.Channels(channels=serialized)
|
351
|
+
|
352
|
+
|
353
|
+
def exception_to_pb(exc: Exception) -> types_pb2.ExecutorError:
    """Convert an exception raised by a task into an ``ExecutorError`` message.

    The exception is mapped as follows:
    - ``GraphInterrupt`` becomes ``graph_interrupt``, carrying its interrupts;
    - any other ``GraphBubbleUp`` becomes ``graph_bubble_up``;
    - everything else becomes ``base_error``, with its type, message and a
      serialized copy.

    The formatted traceback of ``exc`` is always attached.
    """
    if isinstance(exc, GraphInterrupt):
        raw_interrupts = exc.args[0]
        if raw_interrupts:
            interrupts = [interrupt_to_pb(interrupt) for interrupt in raw_interrupts]
            graph_interrupt_pb = types_pb2.GraphInterrupt(
                interrupts=interrupts,
                # brittle: a single interrupt is serialized unwrapped rather
                # than as a one-element sequence.
                interrupts_serialized=serialize_value(
                    raw_interrupts if len(raw_interrupts) != 1 else raw_interrupts[0],
                ),
            )
        else:
            graph_interrupt_pb = types_pb2.GraphInterrupt()
        executor_error_pb = types_pb2.ExecutorError(graph_interrupt=graph_interrupt_pb)
    elif isinstance(exc, GraphBubbleUp):
        bubbleup_pb = types_pb2.GraphBubbleUp()
        executor_error_pb = types_pb2.ExecutorError(graph_bubble_up=bubbleup_pb)
    else:
        base_error_pb = types_pb2.BaseError(
            error_type=str(type(exc)),
            error_message=str(exc),
            error_serialized=serialize_value(exc),
        )
        executor_error_pb = types_pb2.ExecutorError(base_error=base_error_pb)

    # Format the traceback carried by `exc` itself. traceback.format_exc()
    # reads sys.exc_info(), which yields "NoneType: None" outside an except
    # block, or describes a different exception if another one is being
    # handled. Inside the handler for `exc` the output is identical.
    executor_error_pb.traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__),
    )

    return executor_error_pb
|
380
|
+
|
381
|
+
|
382
|
+
def interrupt_to_pb(interrupt: Interrupt) -> types_pb2.Interrupt:
    """Encode an ``Interrupt`` (its value and id) as a protobuf ``Interrupt``."""
    encoded_value = val_to_pb(None, interrupt.value)
    return types_pb2.Interrupt(value=encoded_value, id=interrupt.id)
|
387
|
+
|
388
|
+
|
389
|
+
def pb_to_pending_writes(
    pb: SequenceType[types_pb2.PendingWrite],
) -> list[PendingWrite] | None:
    """Decode protobuf pending writes into ``(task_id, channel, value)`` tuples.

    Returns ``None`` (not an empty list) when ``pb`` is empty.
    """
    if not pb:
        return None

    decoded: list[PendingWrite] = []
    for pending in pb:
        decoded.append((pending.task_id, pending.channel, pb_to_val(pending.value)))
    return decoded
|
@@ -0,0 +1,29 @@
|
|
1
|
+
import grpc
|
2
|
+
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
3
|
+
|
4
|
+
from langgraph_executor.pb import runtime_pb2, runtime_pb2_grpc, types_pb2
|
5
|
+
|
6
|
+
# Address of the LangGraph runtime gRPC server this example talks to.
RUNTIME_SERVER_ADDRESS = "localhost:50051"

if __name__ == "__main__":
    # Serialize the graph input with JsonPlusSerializer, as the executor does.
    serde = JsonPlusSerializer()
    input_raw = {"messages": ["hi"], "count": 0}
    method, ser = serde.dumps_typed(input_raw)
    # Named `input_value` to avoid shadowing the builtin `input`.
    input_value = types_pb2.SerializedValue(method=method, value=bytes(ser))

    request = runtime_pb2.InvokeRequest(
        graph_name="example",
        input=input_value,
        config=types_pb2.RunnableConfig(
            recursion_limit=25,
            max_concurrency=1,
            reserved_configurable=types_pb2.ReservedConfigurable(),
        ),
    )

    # Use the channel as a context manager so it is closed after the call.
    with grpc.insecure_channel(RUNTIME_SERVER_ADDRESS) as channel:
        stub = runtime_pb2_grpc.LangGraphRuntimeStub(channel)
        response = stub.Invoke(request)

    print("response: ", response)
|
@@ -0,0 +1,239 @@
|
|
1
|
+
from collections import deque
|
2
|
+
from dataclasses import is_dataclass
|
3
|
+
from functools import partial
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
from langchain_core.runnables import RunnableConfig
|
7
|
+
from langgraph._internal._config import patch_config
|
8
|
+
from langgraph._internal._constants import (
|
9
|
+
CACHE_NS_WRITES,
|
10
|
+
CONF,
|
11
|
+
CONFIG_KEY_CHECKPOINT_NS,
|
12
|
+
CONFIG_KEY_READ,
|
13
|
+
CONFIG_KEY_RESUME_MAP,
|
14
|
+
CONFIG_KEY_RUNTIME,
|
15
|
+
CONFIG_KEY_SCRATCHPAD,
|
16
|
+
CONFIG_KEY_SEND,
|
17
|
+
PULL,
|
18
|
+
PUSH,
|
19
|
+
)
|
20
|
+
from langgraph._internal._scratchpad import PregelScratchpad
|
21
|
+
from langgraph.pregel import Pregel
|
22
|
+
from langgraph.pregel._algo import (
|
23
|
+
PregelTaskWrites,
|
24
|
+
_proc_input,
|
25
|
+
_scratchpad,
|
26
|
+
local_read,
|
27
|
+
)
|
28
|
+
from langgraph.pregel._call import identifier
|
29
|
+
from langgraph.runtime import DEFAULT_RUNTIME, Runtime
|
30
|
+
from langgraph.store.base import BaseStore
|
31
|
+
from langgraph.types import CacheKey, PregelExecutableTask
|
32
|
+
from pydantic import BaseModel
|
33
|
+
from xxhash import xxh3_128_hexdigest
|
34
|
+
|
35
|
+
from langgraph_executor.common import (
|
36
|
+
get_node,
|
37
|
+
pb_to_pending_writes,
|
38
|
+
pb_to_val,
|
39
|
+
reconstruct_channels,
|
40
|
+
reconstruct_config,
|
41
|
+
val_to_pb,
|
42
|
+
)
|
43
|
+
from langgraph_executor.pb import types_pb2
|
44
|
+
|
45
|
+
|
46
|
+
def get_init_request(request_iterator):
    """Read the first message of a request stream and return its ``init`` payload.

    Args:
        request_iterator: Iterator over incoming request messages.

    Returns:
        The ``init`` field of the first message.

    Raises:
        ValueError: If the stream is empty, or its first message does not
            have ``init`` set.
    """
    request = next(request_iterator, None)
    # hasattr() is always true for a declared protobuf field, so it cannot
    # detect a missing `init`; HasField checks whether it was actually set.
    if request is None or not request.HasField("init"):
        raise ValueError("First message must be init")

    return request.init
|
53
|
+
|
54
|
+
|
55
|
+
def reconstruct_task(
    request,
    graph: Pregel,
    *,
    store: BaseStore | None = None,
    config: RunnableConfig | None = None,
) -> PregelExecutableTask:
    """Rebuild an executable Pregel task from an execute-task request.

    Args:
        request: Request carrying ``task``, ``step``, ``stop`` and the
            serialized ``channels``.
        graph: Graph that owns the task's node.
        store: Store attached to the task's ``Runtime``.
        config: Pre-built config. When ``None`` it is reconstructed from
            ``request.task.config``.

    Returns:
        A ``PregelExecutableTask`` whose writes are collected into a fresh deque.

    Raises:
        ValueError: If the node is unknown, or the task path is empty or
            starts with neither ``PULL`` nor ``PUSH``.
    """
    pb_task = request.task

    proc = get_node(pb_task.name, graph, pb_task.graph_name)
    if config is None:
        config = reconstruct_config(pb_task.config)
    configurable = config.get(CONF, {})

    scratchpad = create_scratchpad(
        config,
        pb_task,
        request.step,
        request.stop,
    )
    channels, managed = reconstruct_channels(
        request.channels.channels,
        graph,
        scratchpad,
    )

    # Convert the repeated proto field to a real tuple: task paths are
    # sliced and compared as tuples, and proto containers are unhashable.
    task_path = tuple(pb_task.task_path)
    if not task_path:
        raise ValueError(f"Task {pb_task.name!r} has an empty task path")
    if task_path[0] == PULL:
        val = _proc_input(
            proc,
            managed,
            channels,
            for_execution=True,
            scratchpad=scratchpad,
            input_cache=None,
        )
    elif task_path[0] == PUSH:
        val = pb_to_val(pb_task.input["PUSH_INPUT"])
    else:
        # Previously fell through and failed with UnboundLocalError on `val`.
        raise ValueError(f"Unsupported task path kind: {task_path[0]!r}")

    writes = deque()
    runtime = ensure_runtime(configurable, store, graph)

    # Generate a cache key if the node has a cache policy.
    cache_policy = getattr(proc, "cache_policy", None)
    cache_key = None
    if cache_policy:
        args_key = cache_policy.key_func(
            *([val] if not isinstance(val, list | tuple) else val),
        )
        cache_key = CacheKey(
            (CACHE_NS_WRITES, identifier(proc.node) or "__dynamic__"),
            xxh3_128_hexdigest(
                args_key.encode() if isinstance(args_key, str) else args_key,
            ),
            cache_policy.ttl,
        )

    return PregelExecutableTask(
        name=pb_task.name,
        input=val,
        proc=proc.node,
        writes=writes,
        config=patch_config(
            config,
            configurable={
                # Node writes are collected into `writes`.
                CONFIG_KEY_SEND: writes.extend,
                # Reads see the reconstructed channels plus this task's own
                # pending writes.
                CONFIG_KEY_READ: partial(
                    local_read,
                    scratchpad,
                    channels,
                    managed,
                    PregelTaskWrites(
                        task_path[:3],
                        pb_task.name,
                        writes,
                        pb_task.triggers,
                    ),
                ),
                CONFIG_KEY_RUNTIME: runtime,
                CONFIG_KEY_SCRATCHPAD: scratchpad,
            },
        ),
        triggers=pb_task.triggers,
        id=pb_task.id,
        path=task_path,
        retry_policy=proc.retry_policy or [],  # TODO support
        cache_key=cache_key,  # TODO support
        writers=proc.flat_writers,
        subgraphs=proc.subgraphs,
    )
|
149
|
+
|
150
|
+
|
151
|
+
def extract_writes(writes) -> list[types_pb2.Write]:
    """Encode ``(channel, value)`` write pairs as protobuf ``Write`` messages.

    Each value is encoded with ``val_to_pb`` using its channel name.
    """
    return [
        types_pb2.Write(channel=channel, value=val_to_pb(channel, val))
        for channel, val in writes
    ]
|
159
|
+
|
160
|
+
|
161
|
+
def create_scratchpad(
    config: RunnableConfig,
    pb_task: types_pb2.Task,
    step: int,
    stop: int,
) -> PregelScratchpad:
    """Build the ``PregelScratchpad`` for one task execution.

    The following are read from ``config[CONF]``:
    - the parent scratchpad;
    - the resume map;
    - the checkpoint namespace, which is hashed with xxh3-128.

    Pending writes come from ``pb_task.pending_writes`` when that field
    exists and is non-empty; otherwise an empty list is used.
    """
    conf = config[CONF]
    checkpoint_ns: str = conf.get(CONFIG_KEY_CHECKPOINT_NS) or ""

    raw_pending = getattr(pb_task, "pending_writes", ())
    pending = pb_to_pending_writes(raw_pending) if len(raw_pending) > 0 else []

    return _scratchpad(
        conf.get(CONFIG_KEY_SCRATCHPAD),
        pending or [],
        pb_task.id,
        xxh3_128_hexdigest(checkpoint_ns.encode()),
        conf.get(CONFIG_KEY_RESUME_MAP),
        step,
        stop,
    )
|
185
|
+
|
186
|
+
|
187
|
+
def ensure_runtime(configurable, store, graph):
|
188
|
+
runtime = configurable.get(CONFIG_KEY_RUNTIME)
|
189
|
+
if runtime is None:
|
190
|
+
return DEFAULT_RUNTIME.override(store=store)
|
191
|
+
if isinstance(runtime, Runtime):
|
192
|
+
return runtime.override(store=store)
|
193
|
+
if isinstance(runtime, dict):
|
194
|
+
context = _coerce_context(graph, runtime.get("context"))
|
195
|
+
return Runtime(**(runtime | {"store": store, "context": context}))
|
196
|
+
raise ValueError("Invalid runtime")
|
197
|
+
|
198
|
+
|
199
|
+
def _coerce_context(graph: Pregel, context: Any) -> Any:
    """Instantiate the graph's context schema from a plain-dict context.

    ``context`` is returned unchanged when any of these holds:
    - it is not a dict;
    - the graph has no ``context_schema``;
    - the schema is neither a pydantic model nor a dataclass.

    Otherwise the schema is constructed from the context keys that appear in
    the schema's properties. ``None`` stays ``None``.
    """
    if context is None:
        return None

    schema = graph.context_schema
    if schema is None:
        return context

    builds_instances = issubclass(schema, BaseModel) or is_dataclass(schema)
    if not (isinstance(context, dict) and builds_instances):
        return context
    return schema(**_filter_context_by_schema(context, graph))
|
214
|
+
|
215
|
+
|
216
|
+
_CACHE = {}
|
217
|
+
|
218
|
+
|
219
|
+
def _filter_context_by_schema(context: dict[str, Any], graph: Pregel) -> dict[str, Any]:
|
220
|
+
if graph not in _CACHE:
|
221
|
+
_CACHE[graph] = graph.get_context_jsonschema()
|
222
|
+
if len(_CACHE) > 500:
|
223
|
+
_CACHE.popitem()
|
224
|
+
json_schema = _CACHE[graph]
|
225
|
+
if not json_schema or not context:
|
226
|
+
return context
|
227
|
+
|
228
|
+
# Extract valid properties from the schema
|
229
|
+
properties = json_schema.get("properties", {})
|
230
|
+
if not properties:
|
231
|
+
return context
|
232
|
+
|
233
|
+
# Filter context to only include parameters defined in the schema
|
234
|
+
filtered_context = {}
|
235
|
+
for key, value in context.items():
|
236
|
+
if key in properties:
|
237
|
+
filtered_context[key] = value
|
238
|
+
|
239
|
+
return filtered_context
|