langgraph-executor 0.0.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+ __version__ = "0.0.1a0"  # package version (PEP 440 alpha pre-release)
@@ -0,0 +1,395 @@
1
+ import traceback
2
+ from collections.abc import Mapping, Sequence
3
+ from collections.abc import Sequence as SequenceType
4
+ from typing import Any, cast
5
+
6
+ from google.protobuf.json_format import MessageToDict
7
+ from langchain_core.runnables import RunnableConfig
8
+ from langgraph._internal._constants import (
9
+ CONFIG_KEY_CHECKPOINT_ID,
10
+ CONFIG_KEY_CHECKPOINT_MAP,
11
+ CONFIG_KEY_CHECKPOINT_NS,
12
+ CONFIG_KEY_DURABILITY,
13
+ CONFIG_KEY_RESUME_MAP,
14
+ CONFIG_KEY_RESUMING,
15
+ CONFIG_KEY_TASK_ID,
16
+ CONFIG_KEY_THREAD_ID,
17
+ TASKS,
18
+ )
19
+ from langgraph._internal._scratchpad import PregelScratchpad
20
+ from langgraph._internal._typing import MISSING
21
+ from langgraph.channels.base import BaseChannel, EmptyChannelError
22
+ from langgraph.checkpoint.base import Checkpoint, PendingWrite
23
+ from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
24
+ from langgraph.errors import GraphBubbleUp, GraphInterrupt
25
+ from langgraph.managed.base import ManagedValue, ManagedValueMapping
26
+ from langgraph.pregel import Pregel
27
+ from langgraph.pregel._algo import PregelTaskWrites
28
+ from langgraph.pregel._read import PregelNode
29
+ from langgraph.types import Command, Interrupt, Send
30
+
31
+ from langgraph_executor.pb import types_pb2
32
+
33
+
34
def map_reserved_configurable(
    reserved_configurable: types_pb2.ReservedConfigurable,
) -> dict[str, Any]:
    """Translate a protobuf ReservedConfigurable message into the reserved
    ``configurable`` keys LangGraph expects inside a RunnableConfig."""
    resume_map = None
    if reserved_configurable.resume_map:
        resume_map = {
            key: pb_to_val(val)
            for key, val in reserved_configurable.resume_map.items()
        }
    return {
        CONFIG_KEY_RESUMING: reserved_configurable.resuming,
        CONFIG_KEY_TASK_ID: reserved_configurable.task_id,
        CONFIG_KEY_THREAD_ID: reserved_configurable.thread_id,
        CONFIG_KEY_CHECKPOINT_MAP: reserved_configurable.checkpoint_map,
        CONFIG_KEY_CHECKPOINT_ID: reserved_configurable.checkpoint_id,
        CONFIG_KEY_CHECKPOINT_NS: reserved_configurable.checkpoint_ns,
        CONFIG_KEY_RESUME_MAP: resume_map,
        CONFIG_KEY_DURABILITY: reserved_configurable.durability,
    }
62
+
63
+
64
def reconstruct_config(pb_config: types_pb2.RunnableConfig) -> RunnableConfig:
    """Rebuild a RunnableConfig from its protobuf representation.

    Reserved configurable keys are merged over the user-provided
    ``configurable`` dict: a truthy reserved value always overrides, and a
    falsy one only fills in a key missing from the user dict.
    """
    configurable = MessageToDict(pb_config.configurable)
    for k, v in map_reserved_configurable(pb_config.reserved_configurable).items():
        # Truthy reserved values win; falsy ones only fill gaps.
        if v or k not in configurable:
            configurable[k] = v
    return RunnableConfig(
        tags=list(pb_config.tags),
        metadata=MessageToDict(pb_config.metadata),
        run_name=pb_config.run_name,
        run_id=pb_config.run_id,
        max_concurrency=pb_config.max_concurrency,
        recursion_limit=pb_config.recursion_limit,
        configurable=configurable,
    )
78
+
79
+
80
def revive_channel(channel: BaseChannel, channel_pb: types_pb2.Channel) -> BaseChannel:
    """Return a fresh copy of ``channel`` restored from its serialized checkpoint."""
    checkpoint_value = pb_to_val(channel_pb.checkpoint_result)
    return channel.copy().from_checkpoint(checkpoint_value)
85
+
86
+
87
def reconstruct_channels(
    channels_pb: dict[str, types_pb2.Channel],
    graph: Pregel,
    scratchpad: PregelScratchpad,
) -> tuple[dict[str, BaseChannel], ManagedValueMapping]:
    """Rebuild the graph's channels and managed values from protobuf state.

    Args:
        channels_pb: Serialized channel state keyed by channel name.
        graph: Pregel graph whose channel specs drive reconstruction.
        scratchpad: Scratchpad handed to managed values; when None,
            managed values cannot be materialized.

    Returns:
        A ``(channels, managed)`` pair mirroring ``graph.channels``.

    Raises:
        ValueError: If a BaseChannel has no serialized state in ``channels_pb``.
        NotImplementedError: If a channel spec is of an unsupported type.
    """
    channels: dict[str, BaseChannel] = {}
    managed: dict[str, Any] = {}
    for k, v in graph.channels.items():
        if isinstance(v, BaseChannel):
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would silently drop this validation.
            if k not in channels_pb:
                raise ValueError(f"Missing serialized state for channel {k!r}")
            channels[k] = revive_channel(v, channels_pb[k])
        elif isinstance(v, ManagedValue) and scratchpad is not None:  # managed values
            managed[k] = v.get(scratchpad)
        else:
            raise NotImplementedError(f"Unrecognized channel value: {type(v)}")

    return channels, managed
104
+
105
+
106
def reconstruct_checkpoint(request_checkpoint: types_pb2.Checkpoint) -> Checkpoint:
    """Rebuild a Checkpoint (version bookkeeping only, no channel values)."""
    seen = {
        node: dict(versions.channel_versions)
        for node, versions in request_checkpoint.versions_seen.items()
    }
    return Checkpoint(
        v=request_checkpoint.v,
        id=request_checkpoint.id,
        channel_versions=dict(request_checkpoint.channel_versions),
        versions_seen=seen,
        ts=request_checkpoint.ts,
    )
121
+
122
+
123
def reconstruct_task_writes(
    request_tasks: SequenceType[Any],
) -> SequenceType[PregelTaskWrites]:
    """Rebuild PregelTaskWrites tuples from protobuf task messages."""
    tasks = []
    for task in request_tasks:
        writes = [(write.channel, pb_to_val(write.value)) for write in task.writes]
        tasks.append(
            PregelTaskWrites(tuple(task.task_path), task.name, writes, task.triggers)
        )
    return tasks
136
+
137
+
138
def checkpoint_to_proto(checkpoint: Checkpoint) -> types_pb2.Checkpoint:
    """Serialize a checkpoint's version bookkeeping into a protobuf Checkpoint."""
    proto = types_pb2.Checkpoint()
    proto.channel_versions.update(checkpoint["channel_versions"])
    for node, seen in checkpoint["versions_seen"].items():
        proto.versions_seen[node].channel_versions.update(seen)
    return proto
145
+
146
+
147
def updates_to_proto(
    checkpoint_proto: types_pb2.Checkpoint,
    updated_channel_names: Sequence[str],
    channels: types_pb2.Channels,
) -> types_pb2.Updates:
    """Bundle a checkpoint, updated channel names, and channel state."""
    updates = types_pb2.Updates()
    updates.checkpoint.CopyFrom(checkpoint_proto)
    updates.updated_channels.extend(updated_channel_names)
    updates.channels.CopyFrom(channels)
    return updates
157
+
158
+
159
def get_graph(
    graph_name: str,
    graphs: dict[str, Pregel],
) -> Pregel:
    """Look up ``graph_name`` in ``graphs``, raising ValueError when absent."""
    try:
        return graphs[graph_name]
    except KeyError:
        raise ValueError(f"Graph {graph_name} not supported") from None
166
+
167
+
168
def get_node(node_name: str, graph: Pregel, graph_name: str) -> PregelNode:
    """Look up ``node_name`` on ``graph``, raising ValueError when absent."""
    try:
        return graph.nodes[node_name]
    except KeyError:
        raise ValueError(
            f"Node {node_name} not found in graph {graph_name}"
        ) from None
172
+
173
+
174
def pb_to_val(value: types_pb2.Value) -> Any:
    """Deserialize a protobuf Value into its Python equivalent.

    Dispatches on the ``message`` oneof: ``base_value`` (serde round-trip),
    ``sends`` (list of Send), ``missing`` (the MISSING sentinel), and
    ``command`` (a reconstructed Command).

    Raises:
        NotImplementedError: For an unrecognized oneof kind.
    """
    serde = JsonPlusSerializer()

    value_kind = value.WhichOneof("message")
    if value_kind == "base_value":
        return serde.loads_typed((value.base_value.method, value.base_value.value))
    if value_kind == "sends":
        sends = []
        for send in value.sends.sends:
            node = send.node
            # Send args are themselves serialized Values; recurse.
            arg = pb_to_val(send.arg)
            sends.append(Send(node, arg))
        return sends
    if value_kind == "missing":
        return MISSING
    if value_kind == "command":
        graph, update, resume, goto = None, None, None, ()
        # NOTE(review): protobuf message/map fields are never None, so both
        # `is not None` checks below look always-true; the field's own
        # truthiness is what effectively gates behavior -- confirm intent.
        if value.command.graph is not None:
            graph = value.command.graph
        if value.command.update is not None:
            # A lone "__root__" key marks a whole-state update rather than
            # a per-channel mapping (mirrors command_to_pb's wrapping).
            if (
                isinstance(value.command.update, dict)
                and len(value.command.update) == 1
                and "__root__" in value.command.update
            ):
                update = pb_to_val(value.command.update["__root__"])
            else:
                update = {k: pb_to_val(v) for k, v in value.command.update.items()}
        if value.command.resume:
            which = value.command.resume.WhichOneof("message")
            if which == "value":
                resume = pb_to_val(value.command.resume.value)
            else:
                # Map form: one resume value per interrupt key.
                resume_map = {
                    k: pb_to_val(v)
                    for k, v in value.command.resume.values.values.items()
                }
                resume = resume_map
        if value.command.gotos:
            gotos = []
            for g in value.command.gotos:
                which = g.WhichOneof("message")
                if which == "node_name":
                    gotos.append(g.node_name.name)
                else:
                    gotos.append(Send(g.send.node, pb_to_val(g.send.arg)))
            # A single goto is unwrapped to Command's scalar form.
            if len(gotos) == 1:
                gotos = gotos[0]
            goto = gotos
        return Command(graph=graph, update=update, resume=resume, goto=goto)
    raise NotImplementedError(f"Unrecognized value kind: {value_kind}")
225
+
226
+
227
def send_to_pb(send: Send) -> types_pb2.Send:
    """Serialize a Send; a nested Send arg is tagged with the TASKS channel."""
    channel = TASKS if isinstance(send.arg, Send) else None
    return types_pb2.Send(node=send.node, arg=val_to_pb(channel, send.arg))
232
+
233
+
234
def sends_to_pb(sends: list[Send]) -> types_pb2.Value:
    """Serialize a list of Send objects; an empty list maps to MISSING."""
    if not sends:
        return missing_to_pb()
    return types_pb2.Value(
        sends=types_pb2.Sends(sends=[send_to_pb(send) for send in sends])
    )
242
+
243
+
244
def command_to_pb(cmd: Command) -> types_pb2.Value:
    """Serialize a Command into a protobuf Value (``command`` oneof).

    Raises:
        ValueError: If ``cmd.graph`` is set to anything but Command.PARENT.
    """
    cmd_pb = types_pb2.Command()
    if cmd.graph:
        # Only the parent-graph target is representable on the wire.
        if not cmd.graph == Command.PARENT:
            raise ValueError("command graph must be null or parent")
        cmd_pb.graph = cmd.graph
    if cmd.update:
        if isinstance(cmd.update, dict):
            cmd_pb.update.update({k: val_to_pb(None, v) for k, v in cmd.update.items()})
        else:
            # Non-dict updates are wrapped under "__root__" so the map field
            # can carry them; pb_to_val unwraps the same key.
            cmd_pb.update.update({"__root__": val_to_pb(None, cmd.update)})
    if cmd.resume:
        if isinstance(cmd.resume, dict):
            cmd_pb.resume.CopyFrom(resume_map_to_pb(cmd.resume))
        else:
            resume_val = types_pb2.Resume(value=val_to_pb(None, cmd.resume))
            cmd_pb.resume.CopyFrom(resume_val)
    if cmd.goto:
        gotos = []
        goto = cmd.goto
        # goto may be a single target or a list; normalize to Goto messages.
        if isinstance(goto, list):
            for g in goto:
                gotos.append(goto_to_pb(g))
        else:
            gotos.append(goto_to_pb(cast(Send | str, goto)))
        cmd_pb.gotos.extend(gotos)

    return types_pb2.Value(command=cmd_pb)
272
+
273
+
274
def resume_map_to_pb(resume: dict[str, Any] | Any) -> types_pb2.Resume:
    """Serialize a key -> resume-value mapping into a Resume message."""
    serialized = {key: val_to_pb(None, value) for key, value in resume.items()}
    return types_pb2.Resume(values=types_pb2.InterruptValues(values=serialized))
277
+
278
+
279
def goto_to_pb(goto: Send | str) -> types_pb2.Goto:
    """Serialize a Command goto target (a node name or a Send)."""
    if isinstance(goto, str):
        return types_pb2.Goto(node_name=types_pb2.NodeName(name=goto))
    if isinstance(goto, Send):
        return types_pb2.Goto(send=send_to_pb(goto))
    raise ValueError("goto must be send or node name")
285
+
286
+
287
def missing_to_pb() -> types_pb2.Value:
    """Build a Value marking the MISSING sentinel.

    ``SetInParent`` materializes the empty ``missing`` submessage so the
    oneof is considered set even though it carries no fields.
    """
    pb = types_pb2.Value()
    pb.missing.SetInParent()
    return pb
291
+
292
+
293
def base_value_to_pb(value: Any) -> types_pb2.Value:
    """Wrap an arbitrary Python value as a serialized ``base_value``."""
    return types_pb2.Value(base_value=serialize_value(value))
297
+
298
+
299
def serialize_value(value: Any) -> types_pb2.SerializedValue:
    """Serialize ``value`` with JsonPlusSerializer into a SerializedValue."""
    method, payload = JsonPlusSerializer().dumps_typed(value)
    return types_pb2.SerializedValue(method=method, value=bytes(payload))
304
+
305
+
306
def val_to_pb(channel_name: str | None, value: Any) -> types_pb2.Value:
    """Serialize a channel value into a protobuf Value.

    Values written to the TASKS channel must be Send objects (or a list of
    them); everything else is routed to the MISSING / Command / base-value
    encodings.

    Args:
        channel_name: Channel the value belongs to, or None when unknown.
        value: The Python value to serialize.

    Raises:
        ValueError: If a TASKS value is not a Send or a list of Sends.
    """
    # MISSING is a sentinel: compare by identity, not equality, so values
    # with exotic __eq__ (e.g. numpy arrays) cannot break or hijack the
    # check. The original used ==/!=, which invokes the value's __eq__.
    if channel_name == TASKS and value is not MISSING:
        if not isinstance(value, list):
            if not isinstance(value, Send):
                raise ValueError(
                    "Task must be a Send object."
                    f" Got type={type(value)} value={value}",
                )
            value = [value]
        else:
            for v in value:
                if not isinstance(v, Send):
                    raise ValueError(
                        "Task must be a list of Send objects."
                        f" Got types={[type(v) for v in value]} values={value}",
                    )
        return sends_to_pb(value)
    if value is MISSING:
        return missing_to_pb()
    if isinstance(value, Command):
        return command_to_pb(value)
    return base_value_to_pb(value)
328
+
329
+
330
def extract_channel(name: str, channel: BaseChannel) -> types_pb2.Channel:
    """Snapshot a channel's get/is_available/checkpoint state into protobuf."""
    try:
        current = channel.get()
    except EmptyChannelError:
        # An empty channel is encoded as the MISSING sentinel.
        current = MISSING

    return types_pb2.Channel(
        get_result=val_to_pb(name, current),
        is_available_result=channel.is_available(),
        checkpoint_result=val_to_pb(name, channel.checkpoint()),
    )
341
+
342
+
343
def extract_channels(
    channels: Mapping[str, BaseChannel | type[ManagedValue]],
) -> types_pb2.Channels:
    """Serialize every BaseChannel in the mapping; managed values are skipped."""
    extracted = {
        name: extract_channel(name, channel)
        for name, channel in channels.items()
        if isinstance(channel, BaseChannel)
    }
    return types_pb2.Channels(channels=extracted)
351
+
352
+
353
def exception_to_pb(exc: Exception) -> types_pb2.ExecutorError:
    """Serialize an exception into an ExecutorError protobuf.

    GraphInterrupt and GraphBubbleUp get dedicated representations; any
    other exception is captured as a BaseError with its type, message, and
    serialized payload.

    NOTE(review): ``traceback.format_exc()`` only yields a useful traceback
    when this runs inside an ``except`` block handling ``exc`` -- confirm
    all call sites do so.
    """
    executor_error_pb = None
    if isinstance(exc, GraphInterrupt):
        # GraphInterrupt carries its interrupts as args[0].
        if exc.args[0]:
            interrupts = [interrupt_to_pb(interrupt) for interrupt in exc.args[0]]
            graph_interrupt_pb = types_pb2.GraphInterrupt(
                interrupts=interrupts,
                # A lone interrupt is serialized unwrapped rather than as a
                # one-element sequence.
                interrupts_serialized=serialize_value(
                    exc.args[0] if len(exc.args[0]) != 1 else exc.args[0][0],
                ),  # brittle fix
            )
        else:
            graph_interrupt_pb = types_pb2.GraphInterrupt()
        executor_error_pb = types_pb2.ExecutorError(graph_interrupt=graph_interrupt_pb)
    elif isinstance(exc, GraphBubbleUp):
        bubbleup_pb = types_pb2.GraphBubbleUp()
        executor_error_pb = types_pb2.ExecutorError(graph_bubble_up=bubbleup_pb)
    else:
        base_error_pb = types_pb2.BaseError(
            error_type=str(type(exc)),
            error_message=str(exc),
            error_serialized=serialize_value(exc),
        )
        executor_error_pb = types_pb2.ExecutorError(base_error=base_error_pb)
    executor_error_pb.traceback = traceback.format_exc()

    return executor_error_pb
380
+
381
+
382
def interrupt_to_pb(interrupt: Interrupt) -> types_pb2.Interrupt:
    """Serialize an Interrupt's value and id into protobuf."""
    value_pb = val_to_pb(None, interrupt.value)
    return types_pb2.Interrupt(value=value_pb, id=interrupt.id)
387
+
388
+
389
def pb_to_pending_writes(
    pb: SequenceType[types_pb2.PendingWrite],
) -> list[PendingWrite] | None:
    """Deserialize pending writes; an empty sequence maps to None."""
    if not pb:
        return None

    writes: list[PendingWrite] = []
    for pending in pb:
        writes.append((pending.task_id, pending.channel, pb_to_val(pending.value)))
    return writes
@@ -0,0 +1,29 @@
1
+ import grpc
2
+ from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
3
+
4
+ from langgraph_executor.pb import runtime_pb2, runtime_pb2_grpc, types_pb2
5
+
6
+ RUNTIME_SERVER_ADDRESS = "localhost:50051"
7
+
8
if __name__ == "__main__":
    # Connect to the local runtime server and issue a single Invoke call.
    channel = grpc.insecure_channel(RUNTIME_SERVER_ADDRESS)
    stub = runtime_pb2_grpc.LangGraphRuntimeStub(channel)

    # Serialize the example graph input with the same serde the server uses.
    serde = JsonPlusSerializer()
    input_raw = {"messages": ["hi"], "count": 0}
    method, ser = serde.dumps_typed(input_raw)
    # Renamed from `input` to avoid shadowing the builtin.
    serialized_input = types_pb2.SerializedValue(method=method, value=bytes(ser))

    request = runtime_pb2.InvokeRequest(
        graph_name="example",
        input=serialized_input,
        config=types_pb2.RunnableConfig(
            recursion_limit=25,
            max_concurrency=1,
            reserved_configurable=types_pb2.ReservedConfigurable(),
        ),
    )

    response = stub.Invoke(request)

    print("response: ", response)
@@ -0,0 +1,239 @@
1
+ from collections import deque
2
+ from dataclasses import is_dataclass
3
+ from functools import partial
4
+ from typing import Any
5
+
6
+ from langchain_core.runnables import RunnableConfig
7
+ from langgraph._internal._config import patch_config
8
+ from langgraph._internal._constants import (
9
+ CACHE_NS_WRITES,
10
+ CONF,
11
+ CONFIG_KEY_CHECKPOINT_NS,
12
+ CONFIG_KEY_READ,
13
+ CONFIG_KEY_RESUME_MAP,
14
+ CONFIG_KEY_RUNTIME,
15
+ CONFIG_KEY_SCRATCHPAD,
16
+ CONFIG_KEY_SEND,
17
+ PULL,
18
+ PUSH,
19
+ )
20
+ from langgraph._internal._scratchpad import PregelScratchpad
21
+ from langgraph.pregel import Pregel
22
+ from langgraph.pregel._algo import (
23
+ PregelTaskWrites,
24
+ _proc_input,
25
+ _scratchpad,
26
+ local_read,
27
+ )
28
+ from langgraph.pregel._call import identifier
29
+ from langgraph.runtime import DEFAULT_RUNTIME, Runtime
30
+ from langgraph.store.base import BaseStore
31
+ from langgraph.types import CacheKey, PregelExecutableTask
32
+ from pydantic import BaseModel
33
+ from xxhash import xxh3_128_hexdigest
34
+
35
+ from langgraph_executor.common import (
36
+ get_node,
37
+ pb_to_pending_writes,
38
+ pb_to_val,
39
+ reconstruct_channels,
40
+ reconstruct_config,
41
+ val_to_pb,
42
+ )
43
+ from langgraph_executor.pb import types_pb2
44
+
45
+
46
def get_init_request(request_iterator):
    """Consume and return the ``init`` payload of the first streamed message.

    Args:
        request_iterator: Iterator over incoming request messages.

    Raises:
        ValueError: If the first message does not carry a populated
            ``init`` field.
    """
    request = next(request_iterator)

    # `hasattr(request, "init")` is always True on a protobuf message whose
    # schema declares the field, so the original check never rejected bad
    # messages. HasField verifies the field is actually populated.
    if not request.HasField("init"):
        raise ValueError("First message must be init")

    return request.init
53
+
54
+
55
def reconstruct_task(
    request,
    graph: Pregel,
    *,
    store: BaseStore | None = None,
    config: RunnableConfig | None = None,
) -> PregelExecutableTask:
    """Rebuild a PregelExecutableTask from a protobuf task request.

    Args:
        request: Protobuf request carrying the task, channel state, and
            step/stop bounds.
        graph: Graph the task's node belongs to.
        store: Optional store injected into the task runtime.
        config: Pre-built config; reconstructed from the task when None.

    Returns:
        The executable task, wired with write/read callbacks and runtime.

    Raises:
        ValueError: If the node is unknown or the task path type is not
            PULL or PUSH.
    """
    pb_task = request.task

    # NOTE: the original wrapped everything below in
    # `try: ... except Exception as e: raise e`, which changed nothing but
    # truncated tracebacks; it has been removed.
    proc = get_node(pb_task.name, graph, pb_task.graph_name)
    if config is None:
        config = reconstruct_config(pb_task.config)
    configurable = config.get(CONF, {})

    scratchpad = create_scratchpad(
        config,
        pb_task,
        request.step,
        request.stop,
    )
    channels, managed = reconstruct_channels(
        request.channels.channels,
        graph,
        scratchpad,
    )
    if pb_task.task_path[0] == PULL:
        val = _proc_input(
            proc,
            managed,
            channels,
            for_execution=True,
            scratchpad=scratchpad,
            input_cache=None,
        )
    elif pb_task.task_path[0] == PUSH:
        val = pb_to_val(pb_task.input["PUSH_INPUT"])
    else:
        # Previously fell through with `val` unbound, surfacing later as an
        # UnboundLocalError; fail with a clear error instead.
        raise ValueError(f"Unsupported task path type: {pb_task.task_path[0]}")

    writes = deque()
    runtime = ensure_runtime(configurable, store, graph)

    # Generate cache key if cache policy exists
    cache_policy = getattr(proc, "cache_policy", None)
    cache_key = None
    if cache_policy:
        args_key = cache_policy.key_func(
            *([val] if not isinstance(val, list | tuple) else val),
        )
        cache_key = CacheKey(
            (CACHE_NS_WRITES, identifier(proc.node) or "__dynamic__"),
            xxh3_128_hexdigest(
                args_key.encode() if isinstance(args_key, str) else args_key,
            ),
            cache_policy.ttl,
        )

    return PregelExecutableTask(
        name=pb_task.name,
        input=val,
        proc=proc.node,
        writes=writes,
        config=patch_config(
            config,
            configurable={
                # Node writes are appended straight onto this task's deque.
                CONFIG_KEY_SEND: writes.extend,
                CONFIG_KEY_READ: partial(
                    local_read,
                    scratchpad,
                    channels,
                    managed,
                    PregelTaskWrites(
                        tuple(pb_task.task_path)[:3],
                        pb_task.name,
                        writes,
                        pb_task.triggers,
                    ),
                ),
                CONFIG_KEY_RUNTIME: runtime,
                CONFIG_KEY_SCRATCHPAD: scratchpad,
            },
        ),
        triggers=pb_task.triggers,
        id=pb_task.id,
        path=pb_task.task_path,
        retry_policy=proc.retry_policy or [],  # TODO support
        cache_key=cache_key,  # TODO support
        writers=proc.flat_writers,
        subgraphs=proc.subgraphs,
    )
149
+
150
+
151
def extract_writes(writes) -> list[types_pb2.Write]:
    """Serialize (channel, value) pairs into protobuf Write messages."""
    return [
        types_pb2.Write(channel=channel, value=val_to_pb(channel, value))
        for channel, value in writes
    ]
159
+
160
+
161
def create_scratchpad(
    config: RunnableConfig,
    pb_task: types_pb2.Task,
    step: int,
    stop: int,
) -> PregelScratchpad:
    """Build the task scratchpad from config and serialized pending writes."""
    task_checkpoint_ns: str = config[CONF].get(CONFIG_KEY_CHECKPOINT_NS) or ""
    if getattr(pb_task, "pending_writes", None) and len(pb_task.pending_writes) > 0:
        pending_writes = pb_to_pending_writes(pb_task.pending_writes)
    else:
        pending_writes = []

    return _scratchpad(
        config[CONF].get(CONFIG_KEY_SCRATCHPAD),
        pending_writes or [],
        pb_task.id,
        xxh3_128_hexdigest(task_checkpoint_ns.encode()),
        config[CONF].get(CONFIG_KEY_RESUME_MAP),
        step,
        stop,
    )
185
+
186
+
187
def ensure_runtime(configurable, store, graph):
    """Resolve the task Runtime from config, overriding its store."""
    existing = configurable.get(CONFIG_KEY_RUNTIME)
    if existing is None:
        return DEFAULT_RUNTIME.override(store=store)
    if isinstance(existing, Runtime):
        return existing.override(store=store)
    if isinstance(existing, dict):
        # A dict runtime is re-hydrated; its context is coerced to the
        # graph's declared context schema when one exists.
        coerced = _coerce_context(graph, existing.get("context"))
        return Runtime(**(existing | {"store": store, "context": coerced}))
    raise ValueError("Invalid runtime")
197
+
198
+
199
def _coerce_context(graph: Pregel, context: Any) -> Any:
    """Coerce a dict context into the graph's declared context schema."""
    if context is None:
        return None

    schema = graph.context_schema
    if schema is None:
        return context

    # Only instantiate when the schema is a constructible class (pydantic
    # model or dataclass) and the context arrived as a plain dict.
    if isinstance(context, dict) and (
        issubclass(schema, BaseModel) or is_dataclass(schema)
    ):
        return schema(**_filter_context_by_schema(context, graph))

    return context
214
+
215
+
216
# Cache of graph -> context JSON schema (schema derivation is not free).
_CACHE: dict[Any, Any] = {}
# Soft cap on cached schemas.
_CACHE_MAX = 500


def _filter_context_by_schema(context: dict[str, Any], graph: Pregel) -> dict[str, Any]:
    """Drop context keys not declared in the graph's context schema.

    Returns ``context`` unchanged when the graph has no usable schema or
    the context is empty.
    """
    if graph not in _CACHE:
        _CACHE[graph] = graph.get_context_jsonschema()
        if len(_CACHE) > _CACHE_MAX:
            # Evict the OLDEST entry (dicts preserve insertion order). The
            # original used popitem(), which removes the *newest* entry --
            # i.e. the schema just inserted -- so once full the cache never
            # retained anything new.
            del _CACHE[next(iter(_CACHE))]
    json_schema = _CACHE[graph]
    if not json_schema or not context:
        return context

    # Extract valid properties from the schema
    properties = json_schema.get("properties", {})
    if not properties:
        return context

    # Filter context to only include parameters defined in the schema
    return {key: value for key, value in context.items() if key in properties}