langgraph-executor 0.0.1a1__tar.gz → 0.0.1a2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36):
  1. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/PKG-INFO +1 -1
  2. langgraph_executor-0.0.1a2/langgraph_executor/__init__.py +1 -0
  3. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/execute_task.py +0 -9
  4. langgraph_executor-0.0.1a2/langgraph_executor/executor.py +162 -0
  5. langgraph_executor-0.0.1a2/langgraph_executor/executor_base.py +473 -0
  6. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/info_logger.py +3 -3
  7. langgraph_executor-0.0.1a2/langgraph_executor/pb/executor_pb2.py +84 -0
  8. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/executor_pb2.pyi +24 -2
  9. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/executor_pb2_grpc.py +44 -0
  10. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/executor_pb2_grpc.pyi +20 -0
  11. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/server.py +22 -25
  12. langgraph_executor-0.0.1a1/langgraph_executor/__init__.py +0 -1
  13. langgraph_executor-0.0.1a1/langgraph_executor/executor.py +0 -376
  14. langgraph_executor-0.0.1a1/langgraph_executor/pb/executor_pb2.py +0 -82
  15. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/.gitignore +0 -0
  16. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/README.md +0 -0
  17. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/common.py +0 -0
  18. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/example.py +0 -0
  19. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/extract_graph.py +0 -0
  20. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/__init__.py +0 -0
  21. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/graph_pb2.py +0 -0
  22. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/graph_pb2.pyi +0 -0
  23. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/graph_pb2_grpc.py +0 -0
  24. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/graph_pb2_grpc.pyi +0 -0
  25. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/runtime_pb2.py +0 -0
  26. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/runtime_pb2.pyi +0 -0
  27. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/runtime_pb2_grpc.py +0 -0
  28. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/runtime_pb2_grpc.pyi +0 -0
  29. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/types_pb2.py +0 -0
  30. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/types_pb2.pyi +0 -0
  31. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/types_pb2_grpc.py +0 -0
  32. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/pb/types_pb2_grpc.pyi +0 -0
  33. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/py.typed +0 -0
  34. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/setup.sh +0 -0
  35. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/langgraph_executor/stream_utils.py +0 -0
  36. {langgraph_executor-0.0.1a1 → langgraph_executor-0.0.1a2}/pyproject.toml +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langgraph-executor
3
- Version: 0.0.1a1
3
+ Version: 0.0.1a2
4
4
  Summary: LangGraph python RPC server executable by the langgraph-go orchestrator.
5
5
  Requires-Python: >=3.11
6
6
  Requires-Dist: grpcio>=1.73.1
@@ -0,0 +1 @@
1
# Package version string (PEP 440 pre-release: alpha 2 of 0.0.1).
__version__: str = "0.0.1a2"
@@ -43,15 +43,6 @@ from langgraph_executor.common import (
43
43
  from langgraph_executor.pb import types_pb2
44
44
 
45
45
 
46
def get_init_request(request_iterator):
    """Return the ``init`` payload carried by the first message of a request stream.

    Args:
        request_iterator: Synchronous iterator of request messages.

    Returns:
        The ``init`` field of the first message.

    Raises:
        ValueError: If the stream is empty, or its first message does not have
            ``init`` set.
    """
    try:
        request = next(request_iterator)
    except StopIteration:
        raise ValueError("Request stream ended before the init message") from None

    # Generated protobuf messages expose every declared field as an attribute,
    # so ``hasattr`` is always true; ``HasField`` reports whether it was sent.
    if not request.HasField("init"):
        raise ValueError("First message must be init")

    return request.init
53
-
54
-
55
46
  def reconstruct_task(
56
47
  request,
57
48
  graph: Pregel,
@@ -0,0 +1,162 @@
1
+ import contextlib
2
+ import functools
3
+ import logging
4
+ from typing import Any
5
+
6
+ import grpc
7
+ import grpc.aio
8
+ from langgraph._internal._constants import NS_SEP
9
+ from langgraph.pregel import Pregel
10
+
11
+ from langgraph_executor.executor_base import LangGraphExecutorServicer
12
+ from langgraph_executor.pb.executor_pb2_grpc import (
13
+ add_LangGraphExecutorServicer_to_server,
14
+ )
15
+
16
# Module-level logger used by the server-setup helpers in this module.
LOGGER: logging.Logger = logging.getLogger(__name__)
18
+
19
+
20
def create_server(graphs: dict[str, Pregel], address: str) -> grpc.aio.Server:
    """Build a ``grpc.aio`` server that serves ``graphs`` through the executor servicer.

    The graphs (and their subgraphs) are loaded via ``_load_graphs``, the
    servicer is registered, and an insecure port is bound to ``address``. The
    server is not started here.

    Args:
        graphs: Root graphs keyed by the name they are served under.
        address: Address passed to ``add_insecure_port``.

    Returns:
        The configured (unstarted) server.
    """
    loaded_graphs, parents = _load_graphs(graphs)

    # Be permissive: allow client pings without active RPCs and accept
    # intervals as low as 50s. Our clients still default to ~5m, but this
    # avoids penalizing other, more frequent clients.
    keepalive_options = [
        ("grpc.keepalive_permit_without_calls", 1),
        ("grpc.http2.min_recv_ping_interval_without_data_ms", 50_000),  # 50s
        ("grpc.http2.max_ping_strikes", 2),
    ]
    server = grpc.aio.server(options=keepalive_options)

    graph_getter = functools.partial(get_graph, graphs=loaded_graphs)
    servicer = LangGraphExecutorServicer(
        loaded_graphs,
        subgraph_map=parents,
        get_graph=graph_getter,
    )
    add_LangGraphExecutorServicer_to_server(servicer, server)

    server.add_insecure_port(address)
    return server
39
+
40
+
41
@contextlib.asynccontextmanager
async def get_graph(graph_name: str, config: Any, *, graphs: dict[str, Pregel]):
    """Async context manager yielding the pre-loaded graph named ``graph_name``.

    ``config`` is accepted to match the ``GetGraph`` callable shape but is not
    used. An unknown name raises ``KeyError`` on entry.
    """
    selected = graphs[graph_name]
    yield selected
44
+
45
+
46
def _load_graphs(graphs: dict[str, Pregel]) -> tuple[dict[str, Pregel], dict[str, str]]:
    """Load root graphs and all nested subgraphs, parents before children.

    Every subgraph is renamed to ``{root}{NS_SEP}{node}{NS_SEP}{idx}``, where
    ``root`` is the key the root graph is registered under in ``graphs``.

    Args:
        graphs: Root graphs keyed by the name they are served under. The dict
            itself is not modified (subgraph objects are renamed in place).

    Returns:
        ``(loaded, subgraph_map)``: a new dict containing every root and
        subgraph (each parent inserted before its children), and a mapping of
        subgraph name -> parent graph name.

    Raises:
        ValueError: If two graphs would end up with the same name.
    """
    _ensure_unique_root_names(graphs)

    all_subgraphs: dict[str, Pregel] = {}
    subgraph_map: dict[str, str] = {}
    for root_name, root_graph in graphs.items():
        # Use the registry key rather than ``root_graph.name``: the two may
        # differ, and children must reference the key under which the root is
        # loaded and listed, otherwise they are never attached below it.
        subgraphs, mappings = _collect_subgraphs(
            root_graph, root_name, parent_name=root_name
        )
        clashes = sorted(subgraphs.keys() & (all_subgraphs.keys() | graphs.keys()))
        if clashes:
            raise ValueError(f"Subgraph name conflict detected: {', '.join(clashes)}")
        all_subgraphs.update(subgraphs)
        subgraph_map.update(mappings)

    # Build a fresh dict so the hierarchical insertion order actually holds
    # (re-assigning an existing key would keep its original position).
    combined = {**graphs, **all_subgraphs}
    loaded: dict[str, Pregel] = {}
    for root_name in sorted(graphs):
        _load_graph_and_children(root_name, loaded, combined, subgraph_map)

    _log_supported_graphs(loaded, subgraph_map)
    return loaded, subgraph_map


def _ensure_unique_root_names(graphs: dict[str, Pregel]) -> None:
    """Ensure all root graph names are unique.

    NOTE(review): for a ``dict`` argument the keys are unique by construction,
    so this check cannot fail; it is kept as a guard for the public contract.
    """
    seen_names = set()

    for name in graphs:
        if name in seen_names:
            raise ValueError(
                f"Root graph name conflict detected: {name}. Root graphs must have unique names"
            )
        seen_names.add(name)


def _collect_subgraphs(
    graph: Pregel, namespace: str, parent_name: str | None = None
) -> tuple[dict[str, Pregel], dict[str, str]]:
    """Recursively collect and rename every subgraph beneath ``graph``.

    Nested levels reuse ``namespace`` (the root's name), so a generated name
    is distinguished only by node name and position. A repeated name is
    reported instead of silently overwriting an earlier subgraph.

    Args:
        graph: Graph whose subgraphs are collected.
        namespace: Prefix used for every generated subgraph name.
        parent_name: Name recorded as the parent of ``graph``'s direct
            children; defaults to ``graph.name``.

    Returns:
        ``(subgraphs, mappings)``: generated name -> subgraph, and generated
        name -> parent graph name.

    Raises:
        ValueError: If the same generated name is produced twice.
    """
    if parent_name is None:
        parent_name = graph.name
    subgraphs: dict[str, Pregel] = {}
    mappings: dict[str, str] = {}

    for idx, (node_name, subgraph) in enumerate(graph.get_subgraphs(recurse=False)):
        name = f"{namespace}{NS_SEP}{node_name}{NS_SEP}{idx}"
        if name in subgraphs:
            raise ValueError(f"Subgraph name conflict detected: {name}")
        subgraph.name = name
        subgraphs[name] = subgraph
        mappings[name] = parent_name

        # Children of this subgraph point at its (just assigned) name.
        nested_subgraphs, nested_mappings = _collect_subgraphs(subgraph, namespace)
        clashes = sorted(nested_subgraphs.keys() & subgraphs.keys())
        if clashes:
            raise ValueError(f"Subgraph name conflict detected: {', '.join(clashes)}")
        subgraphs.update(nested_subgraphs)
        mappings.update(nested_mappings)

    return subgraphs, mappings


def _load_graph_and_children(
    graph_name: str,
    graphs: dict[str, Pregel],
    all_graphs: dict[str, Pregel],
    subgraph_map: dict[str, str],
) -> None:
    """Insert ``graph_name`` into ``graphs``, then its descendants depth-first.

    Args:
        graph_name: Name of the graph to insert.
        graphs: Destination dict; new keys are inserted parent before child.
        all_graphs: Lookup of every known root and subgraph by name.
        subgraph_map: Mapping of child name -> parent name.
    """
    graphs[graph_name] = all_graphs[graph_name]

    # Sorted so the resulting order is deterministic.
    children = sorted(
        child for child, parent in subgraph_map.items() if parent == graph_name
    )
    for child_name in children:
        _load_graph_and_children(child_name, graphs, all_graphs, subgraph_map)
134
+
135
+
136
def _log_supported_graphs(
    graphs: dict[str, Pregel], subgraph_map: dict[str, str]
) -> None:
    """Log every loaded graph as a tree, one line per graph, roots first."""
    LOGGER.info("Loaded graphs:")

    # Roots are the loaded graphs that have no recorded parent.
    roots = set(graphs).difference(subgraph_map)
    for root in sorted(roots):
        LOGGER.info(f" {root}")
        _log_graph_children(root, subgraph_map, indent=2)
148
+
149
+
150
def _log_graph_children(
    parent_name: str, subgraph_map: dict[str, str], *, indent: int = 0
) -> None:
    """Log each direct child of ``parent_name``, then its descendants depth-first.

    Each nesting level is indented by one additional space.
    """
    line_prefix = " " * indent + "└─ "
    kids = sorted(name for name, parent in subgraph_map.items() if parent == parent_name)
    for kid in kids:
        LOGGER.info(f"{line_prefix}{kid}")
        _log_graph_children(kid, subgraph_map, indent=indent + 1)
@@ -0,0 +1,473 @@
1
+ import asyncio
2
+ import contextlib
3
+ import functools
4
+ import logging
5
+ import uuid
6
+ from collections.abc import AsyncIterator, Callable, Collection, Iterator, Sequence
7
+ from typing import Any, Protocol, cast
8
+
9
+ import grpc
10
+ import grpc.aio
11
+ from google.protobuf.struct_pb2 import Struct # type: ignore[import-untyped]
12
+ from langchain_core.messages import BaseMessage, BaseMessageChunk
13
+ from langchain_core.runnables import RunnableConfig
14
+ from langgraph.checkpoint.base import Checkpoint
15
+ from langgraph.errors import GraphBubbleUp, GraphInterrupt
16
+ from langgraph.pregel import Pregel
17
+ from langgraph.pregel._algo import apply_writes
18
+ from langgraph.pregel._checkpoint import channels_from_checkpoint
19
+ from langgraph.pregel._retry import arun_with_retry
20
+ from langgraph.types import PregelExecutableTask
21
+
22
+ from langgraph_executor.common import (
23
+ checkpoint_to_proto,
24
+ exception_to_pb,
25
+ extract_channels,
26
+ pb_to_val,
27
+ reconstruct_channels,
28
+ reconstruct_checkpoint,
29
+ reconstruct_config,
30
+ reconstruct_task_writes,
31
+ updates_to_proto,
32
+ )
33
+ from langgraph_executor.execute_task import (
34
+ extract_writes,
35
+ reconstruct_task,
36
+ )
37
+ from langgraph_executor.extract_graph import extract_graph
38
+ from langgraph_executor.pb import executor_pb2, executor_pb2_grpc, types_pb2
39
+ from langgraph_executor.stream_utils import ExecutorStreamHandler
40
+
41
+
42
class Logger(Protocol):
    """Structural type for the logger accepted by the executor servicer.

    Any object providing these ``logging.Logger``-style methods satisfies it.
    """

    def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: ...

    def info(self, msg: str, *args: Any, **kwargs: Any) -> None: ...

    def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: ...

    def error(self, msg: str, *args: Any, **kwargs: Any) -> None: ...

    def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: ...

    def critical(self, msg: str, *args: Any, **kwargs: Any) -> None: ...
49
+
50
+
51
# Default logger for this module; used when no logger is injected.
LOGGER = logging.getLogger(__name__)

# Unique marker placed on the stream queue to signal that a task has finished.
# The cast only lets the queue be typed as holding ExecuteTaskResponse items.
SENTINEL = cast(executor_pb2.ExecuteTaskResponse, object())

# Callable taking (graph name, config) and returning an async context manager
# that yields the compiled graph.
GetGraph = Callable[
    [str, RunnableConfig],
    contextlib.AbstractAsyncContextManager[Pregel],
]
56
+
57
+
58
class LangGraphExecutorServicer(executor_pb2_grpc.LangGraphExecutorServicer):
    """gRPC servicer for LangGraph runtime execution operations.

    Graphs are resolved per call through the injected ``get_graph`` factory.
    Unexpected errors inside an RPC are logged and turned into an INTERNAL abort.
    """

    def __init__(
        self,
        graphs: Collection[str],
        *,
        subgraph_map: dict[str, str],
        get_graph: GetGraph,
        logger: Logger | None = None,
    ):
        """Initialize the servicer with the names of the graphs it serves.

        Args:
            graphs: Names of the graphs (roots and subgraphs) this servicer
                exposes. Any collection of names works; passing a dict keyed
                by graph name contributes its keys.
            subgraph_map: Dictionary mapping subgraph names to parent graph names
            get_graph: Callable returning an async context manager that yields
                the compiled graph for a given name and config
            logger: Optional logger; the module logger is used when omitted

        """
        self.logger = logger or LOGGER
        self.graphs = set(graphs)
        self.graph_names = sorted(self.graphs)
        self.subgraph_map = subgraph_map
        self.get_graph = get_graph
        # Give every BaseMessage an id on creation (the patch is applied once).
        _patch_base_message_with_ids()
        # Memoized GetGraphResponse per graph name.
        self._graph_definition_cache: dict[str, executor_pb2.GetGraphResponse] = {}

    async def ListGraphs(
        self, request: Any, context: grpc.aio.ServicerContext
    ) -> executor_pb2.ListGraphsResponse:  # type: ignore[name-defined]
        """List available graphs (sorted by name)."""
        return executor_pb2.ListGraphsResponse(
            graph_names=self.graph_names,
        )

    async def GetGraph(
        self, request: Any, context: grpc.aio.ServicerContext
    ) -> executor_pb2.GetGraphResponse:  # type: ignore[name-defined]
        """Get the definition of the graph named in ``request.graph_name``."""
        try:
            self.logger.debug("GetGraph called")
            graph_name: str = request.graph_name
            return await self._get_graph_definition(graph_name)

        except Exception as e:
            self.logger.error(f"GetGraph Error: {e}", exc_info=True)
            await context.abort(grpc.StatusCode.INTERNAL, str(e))

    async def _get_graph_definition(self, name: str) -> executor_pb2.GetGraphResponse:
        """Build (and memoize by name) the GetGraphResponse for graph ``name``."""
        if (resp := self._graph_definition_cache.get(name)) is not None:
            return resp
        async with self.get_graph(name, RunnableConfig()) as graph:
            graph_definition = extract_graph(graph)

        resp = executor_pb2.GetGraphResponse(
            graph_definition=graph_definition,
            parent_name=self.subgraph_map.get(name, None),
            checkpointer=graph.checkpointer is not None,
        )
        self._graph_definition_cache[name] = resp
        return resp

    async def GetAllGraphs(
        self,
        request: executor_pb2.GetAllGraphsRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncIterator[executor_pb2.GetGraphResponse]:
        """Stream the definition of every served graph, in sorted name order."""
        try:
            self.logger.debug("GetAllGraphs called")
            for name in self.graph_names:
                yield await self._get_graph_definition(name)

        except Exception as e:
            self.logger.error(f"GetAllGraphs Error: {e}", exc_info=True)
            await context.abort(grpc.StatusCode.INTERNAL, str(e))

    async def ChannelsFromCheckpoint(
        self, request: Any, context: grpc.aio.ServicerContext
    ) -> executor_pb2.ChannelsFromCheckpointResponse:  # type: ignore[name-defined]
        """Rebuild channels from their specs plus checkpointed channel values."""
        try:
            self.logger.debug("ChannelsFromCheckpoint called")
            async with self.get_graph(request.graph_name, RunnableConfig()) as graph:
                # reconstruct specs
                specs, _ = reconstruct_channels(
                    request.specs.channels,
                    graph,
                    scratchpad=None,  # type: ignore[invalid-arg-type]
                )

                # initialize channels from specs and checkpoint channel values
                checkpoint_dummy = Checkpoint(  # type: ignore[typeddict-item]
                    channel_values={
                        k: pb_to_val(v)
                        for k, v in request.checkpoint_channel_values.items()
                    },
                )
                channels, _ = channels_from_checkpoint(specs, checkpoint_dummy)

                # channels to pb
                channels = extract_channels(channels)

                return executor_pb2.ChannelsFromCheckpointResponse(channels=channels)

        except Exception as e:
            self.logger.error(f"ChannelsFromCheckpoint Error: {e}", exc_info=True)
            await context.abort(grpc.StatusCode.INTERNAL, str(e))

    async def ExecuteTask(
        self,
        request_iterator: AsyncIterator[executor_pb2.ExecuteTaskRequest],  # type: ignore[name-defined]
        context: grpc.aio.ServicerContext,
    ) -> AsyncIterator[executor_pb2.ExecuteTaskResponse]:  # type: ignore[name-defined]
        """Execute one task, streaming messages/custom events and the final result.

        Protocol on the bidirectional stream:
        1. The first client message must carry ``init``.
        2. If the task has a cache key, a ``cache_check_request`` is yielded
           and the next client message is read as the cache decision.
        3. Stream events are yielded while the task runs, then any chat
           messages found in the writes, then a final ``task_result``.
        """
        self.logger.debug("ExecuteTask called")

        # Right now, only handle task execution without interrupts, etc
        try:
            request = await _get_init_request(request_iterator)
            config = reconstruct_config(request.task.config)
            async with self.get_graph(request.graph_name, config) as graph:
                stream_messages = "messages" in request.stream_modes
                stream_custom = "custom" in request.stream_modes

                stream_queue = asyncio.Queue()

                custom_stream_writer = (
                    _create_custom_stream_writer(stream_queue, self.logger)
                    if stream_custom
                    else None
                )

                task = reconstruct_task(
                    request, graph, custom_stream_writer=custom_stream_writer
                )
                if stream_messages:
                    # Create and inject callback handler
                    stream_handler = ExecutorStreamHandler(
                        functools.partial(
                            stream_callback,
                            logger=self.logger,
                            stream_queue=stream_queue,
                        ),
                        task.id,
                    )

                    # Add handler to task config callbacks
                    if "callbacks" not in task.config:
                        task.config["callbacks"] = []
                    task.config["callbacks"].append(stream_handler)  # type: ignore[union-attr]

                # Check cache if task has cache key - send request to Go orchestrator
                should_execute = True
                if task.cache_key:
                    self.logger.debug(
                        f"Task {task.id} has cache key, sending cache check request to Go",
                    )

                    # Send cache check request to Go runtime
                    cache_check_request = executor_pb2.CacheCheckRequest(
                        cache_namespace=list(task.cache_key.ns),
                        cache_key=task.cache_key.key,
                        ttl=task.cache_key.ttl,
                    )

                    yield executor_pb2.ExecuteTaskResponse(
                        cache_check_request=cache_check_request,
                    )

                    # Wait for Go's response via the bidirectional stream
                    should_execute = await self._receive_cache_decision(
                        request_iterator, task.id
                    )

                # TODO patch retry policy
                # TODO configurable to deal with _call and the functional api
                exception_pb = None
                if should_execute:
                    runner_task = asyncio.create_task(
                        _run_task(task, logger=self.logger, stream_queue=stream_queue)
                    )
                    try:
                        # Drain the queue and stream responses to client until
                        # _run_task signals completion with SENTINEL.
                        while True:
                            item = await stream_queue.get()
                            if item is SENTINEL:
                                break
                            yield item
                        exception_pb = await runner_task
                    finally:
                        # If the stream is closed or cancelled mid-drain, do not
                        # leave the task running unobserved in the background.
                        if not runner_task.done():
                            runner_task.cancel()

                # Ensure the final chat messages are emitted (if any)
                final_messages = _extract_output_messages(task.writes)
                for message in final_messages:
                    yield executor_pb2.ExecuteTaskResponse(
                        message_or_message_chunk=message
                    )

                # Final task result
                yield executor_pb2.ExecuteTaskResponse(
                    task_result=executor_pb2.TaskResult(
                        error=exception_pb, writes=extract_writes(task.writes)
                    )
                )

        except Exception as e:
            self.logger.exception(f"ExecuteTask error: {e}")
            await context.abort(grpc.StatusCode.INTERNAL, str(e))

    async def _receive_cache_decision(self, request_iterator, task_id: str) -> bool:
        """Read the orchestrator's cache-check reply; return whether to execute.

        Returns False only on an explicit cache hit. A missing or unexpected
        reply falls back to executing the task.
        """
        try:
            # grpc.aio hands stream handlers an *async* iterator: use anext /
            # StopAsyncIteration (sync next() on it raises TypeError).
            reply = await anext(request_iterator)
        except StopAsyncIteration:
            self.logger.warning(
                f"No cache response received for task {task_id}, defaulting to execution",
            )
            return True  # Default to execution if no response

        # hasattr() is always true for declared protobuf fields; HasField()
        # tells whether the orchestrator actually sent a cache_check_response.
        if not reply.HasField("cache_check_response"):
            self.logger.warning(
                f"Expected cache_check_response for task {task_id}, got unexpected message type",
            )
            return True  # Default to execution if unexpected response

        cache_response = reply.cache_check_response
        self.logger.debug(
            f"Received cache response for task {task_id}: cache_hit={cache_response.cache_hit}",
        )
        return not cache_response.cache_hit

    async def ApplyWrites(
        self, request: Any, context: grpc.aio.ServicerContext
    ) -> executor_pb2.ApplyWritesResponse:  # type: ignore[name-defined]
        """Apply task writes to the given channels/checkpoint and return the updates."""
        # get graph
        self.logger.debug("ApplyWrites called")
        try:
            async with self.get_graph(request.graph_name, RunnableConfig()) as graph:
                channels, _ = reconstruct_channels(
                    request.channels.channels,
                    graph,
                    # TODO: figure this out
                    scratchpad=None,  # type: ignore[invalid-arg-type]
                )
                checkpoint = reconstruct_checkpoint(request.checkpoint)
                tasks = reconstruct_task_writes(request.tasks)

                # apply writes
                updated_channel_names_set = apply_writes(
                    checkpoint,
                    channels,
                    tasks,
                    lambda *args: request.next_version,
                    graph.trigger_to_nodes,
                )
                updated_channel_names = list(updated_channel_names_set)

                # Reconstruct protos
                updated_channels = extract_channels(channels)
                checkpoint_proto = checkpoint_to_proto(checkpoint)

                # Respond with updates
                return executor_pb2.ApplyWritesResponse(
                    updates=updates_to_proto(
                        checkpoint_proto,
                        updated_channel_names,
                        updated_channels,
                    ),
                )

        except Exception as e:
            self.logger.exception(f"ApplyWrites error: {e}")
            await context.abort(grpc.StatusCode.INTERNAL, str(e))

    async def GenerateCacheKey(
        self,
        request: executor_pb2.GenerateCacheKeyRequest,
        context: grpc.aio.ServicerContext,
    ) -> executor_pb2.GenerateCacheKeyResponse:
        """Generate cache key for a node execution (not implemented yet).

        Aborts with UNIMPLEMENTED so clients receive a meaningful status code,
        instead of the generic error status an uncaught exception produces.
        """
        await context.abort(
            grpc.StatusCode.UNIMPLEMENTED, "GenerateCacheKey not implemented"
        )
333
+
334
+
335
+ # Helpers
336
+
337
+
338
async def _run_task(
    task: PregelExecutableTask,
    *,
    logger: Logger,
    stream_queue: asyncio.Queue[executor_pb2.ExecuteTaskResponse],
) -> types_pb2.ExecutorError | None:
    """Run ``task`` once (no retry policy) and report any failure as protobuf.

    Returns None on success, otherwise ``exception_to_pb`` of the raised
    exception. SENTINEL is always put on ``stream_queue`` (in ``finally``)
    so the consumer draining the queue learns the task has ended.
    """
    try:
        await arun_with_retry(task, retry_policy=None)
        return None
    except Exception as exc:
        if not isinstance(exc, GraphBubbleUp | GraphInterrupt):
            logger.exception(
                f"Exception running task {task.id}: {exc}\nTask: {task}\n\n",
                exc_info=True,
            )
        else:
            # Interrupts / bubble-ups are logged at info level, without a traceback.
            logger.info(f"Interrupt in task {task.id}: {exc}")
        return exception_to_pb(exc)
    finally:
        await stream_queue.put(SENTINEL)
361
+
362
+
363
def stream_callback(
    message: BaseMessageChunk,
    metadata: dict,
    *,
    logger: Logger,
    stream_queue: asyncio.Queue[executor_pb2.ExecuteTaskResponse],
):
    """Wrap a streamed message (chunk) in an ExecuteTaskResponse and enqueue it.

    ``metadata`` is not used. Any conversion or enqueue failure is logged as
    a warning and otherwise swallowed.
    """
    try:
        converted = _extract_output_message(message)
        response = executor_pb2.ExecuteTaskResponse(message_or_message_chunk=converted)
        stream_queue.put_nowait(response)
    except Exception as exc:
        logger.warning(f"Failed to create stream chunk: {exc}", exc_info=True)
379
+
380
+
381
def _create_custom_stream_writer(stream_queue: asyncio.Queue[Any], logger: Logger):
    """Return a ``stream_writer`` callable for "custom" stream mode.

    Each call wraps its argument in a CustomStreamEvent response and puts it
    on ``stream_queue`` without blocking.
    """

    def stream_writer(content):
        """Enqueue ``content`` as a CustomStreamEvent; failures are only logged.

        A dict becomes the Struct payload as-is; any other value is stored
        under the ``"content"`` key (non-strings via ``str()``).
        """
        try:
            if isinstance(content, dict):
                fields = content
            elif isinstance(content, str):
                fields = {"content": content}
            else:
                fields = {"content": str(content)}

            payload = Struct()
            payload.update(fields)

            event = executor_pb2.CustomStreamEvent(payload=payload)
            stream_queue.put_nowait(
                executor_pb2.ExecuteTaskResponse(custom_stream_event=event)
            )
        except Exception as exc:
            logger.warning(f"Failed to create custom stream event: {exc}", exc_info=True)

    return stream_writer
407
+
408
+
409
def _extract_output_messages(writes: Sequence[Any]) -> list[types_pb2.Message]:  # type: ignore[name-defined]
    """Convert every BaseMessage found in a task's writes to ``types_pb2.Message``.

    The value of each write is read as ``write[1]`` (presumably writes are
    ``(channel, value)`` pairs — TODO confirm). The value is handled as follows:

    - a BaseMessage is converted;
    - a non-string sequence contributes each of its BaseMessage elements;
    - anything else is ignored.
    """
    messages: list[types_pb2.Message] = []  # type: ignore[name-defined]
    for write in writes:
        value = write[1]
        if isinstance(value, BaseMessage):
            messages.append(_extract_output_message(value))
        elif isinstance(value, Sequence) and not isinstance(
            value, (str, bytes, bytearray)
        ):
            # str/bytes are Sequences too, but their elements can never be
            # BaseMessages; skip them rather than scanning them element-wise.
            messages.extend(
                _extract_output_message(item)
                for item in value
                if isinstance(item, BaseMessage)
            )

    return messages
425
+
426
+
427
def _extract_output_message(write: Any) -> types_pb2.Message:  # type: ignore[name-defined]
    """Serialize a message-like object into a ``types_pb2.Message`` Struct payload.

    Fields are read with ``getattr`` and fall back to defaults, so objects
    lacking some attributes are accepted. A missing or empty ``id`` is
    replaced by a random ``uuid4().hex``, and ``content`` is stringified.
    """
    body = {
        "id": getattr(write, "id", None) or uuid.uuid4().hex,
        "type": getattr(write, "type", None),
        "content": str(getattr(write, "content", "") or ""),
        "additional_kwargs": getattr(write, "additional_kwargs", {}),
        "usage_metadata": getattr(write, "usage_metadata", {}),
        "tool_calls": getattr(write, "tool_calls", []),
        "tool_call_id": getattr(write, "tool_call_id", ""),
        "tool_call_chunks": getattr(write, "tool_call_chunks", []),
        "response_metadata": getattr(write, "response_metadata", {}),
    }
    envelope = {
        "is_streaming_chunk": False,
        "message": body,
        "metadata": {},
    }

    payload = Struct()
    payload.update(envelope)
    return types_pb2.Message(payload=payload)
447
+
448
+
449
+ async def _get_init_request(request_iterator):
450
+ request = await anext(request_iterator)
451
+
452
+ if not hasattr(request, "init"):
453
+ raise ValueError("First message must be init")
454
+
455
+ return request.init
456
+
457
+
458
@functools.lru_cache(maxsize=1)
def _patch_base_message_with_ids() -> None:
    """Monkey-patch ``BaseMessage.__init__`` so new messages always get an id.

    After the original initializer runs, a message whose ``id`` is still None
    is given ``str(uuid.uuid4())``. Because of ``lru_cache`` on this
    zero-argument function, later calls return the cached result without
    patching again. Failures while patching are logged as a warning only.
    """
    try:
        from langchain_core.messages import BaseMessage as message_cls

        wrapped_init = message_cls.__init__

        def patched_init(self, content: Any, **kwargs: Any) -> None:
            wrapped_init(self, content, **kwargs)
            if self.id is not None:
                return
            self.id = str(uuid.uuid4())

        message_cls.__init__ = patched_init  # type: ignore[method-assign]
    except Exception as exc:
        LOGGER.warning("Failed to patch BaseMessage with IDs: %s", exc)
@@ -8,7 +8,7 @@ from pathlib import Path
8
8
  class ExecutorInfo:
9
9
  id: str
10
10
  pid: int
11
- port: int
11
+ address: str
12
12
  status: str
13
13
  start_time: float
14
14
  end_time: float | None = None
@@ -33,7 +33,7 @@ class ExecutorInfoLogger:
33
33
  data = {
34
34
  "id": executor_info.id,
35
35
  "pid": executor_info.pid,
36
- "port": executor_info.port,
36
+ "address": executor_info.address,
37
37
  "status": executor_info.status,
38
38
  "start_time": executor_info.start_time,
39
39
  "end_time": executor_info.end_time,
@@ -61,7 +61,7 @@ class ExecutorInfoLogger:
61
61
  return ExecutorInfo(
62
62
  id=data["id"],
63
63
  pid=data["pid"],
64
- port=data["port"],
64
+ address=data["address"],
65
65
  status=data["status"],
66
66
  start_time=data["start_time"],
67
67
  end_time=data.get("end_time"),