langgraph-executor 0.0.1a1__py3-none-any.whl → 0.0.1a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1 +1 @@
1
- __version__ = "0.0.1a1"
1
+ __version__ = "0.0.1a2"
@@ -43,15 +43,6 @@ from langgraph_executor.common import (
43
43
  from langgraph_executor.pb import types_pb2
44
44
 
45
45
 
46
- def get_init_request(request_iterator):
47
- request = next(request_iterator)
48
-
49
- if not hasattr(request, "init"):
50
- raise ValueError("First message must be init")
51
-
52
- return request.init
53
-
54
-
55
46
  def reconstruct_task(
56
47
  request,
57
48
  graph: Pregel,
@@ -1,376 +1,162 @@
1
+ import contextlib
2
+ import functools
1
3
  import logging
2
- import uuid
3
- from collections.abc import Iterator, Sequence
4
- from functools import lru_cache
5
4
  from typing import Any
6
5
 
7
6
  import grpc
8
- from google.protobuf.struct_pb2 import Struct # type: ignore[import-untyped]
9
- from langchain_core.messages import BaseMessage, BaseMessageChunk
10
- from langgraph.checkpoint.base import Checkpoint
11
- from langgraph.errors import GraphBubbleUp, GraphInterrupt
7
+ import grpc.aio
8
+ from langgraph._internal._constants import NS_SEP
12
9
  from langgraph.pregel import Pregel
13
- from langgraph.pregel._algo import apply_writes
14
- from langgraph.pregel._checkpoint import channels_from_checkpoint
15
- from langgraph.pregel._retry import run_with_retry
16
-
17
- from langgraph_executor.common import (
18
- checkpoint_to_proto,
19
- exception_to_pb,
20
- extract_channels,
21
- get_graph,
22
- pb_to_val,
23
- reconstruct_channels,
24
- reconstruct_checkpoint,
25
- reconstruct_task_writes,
26
- updates_to_proto,
27
- )
28
- from langgraph_executor.execute_task import (
29
- extract_writes,
30
- get_init_request,
31
- reconstruct_task,
10
+
11
+ from langgraph_executor.executor_base import LangGraphExecutorServicer
12
+ from langgraph_executor.pb.executor_pb2_grpc import (
13
+ add_LangGraphExecutorServicer_to_server,
32
14
  )
33
- from langgraph_executor.extract_graph import extract_graph
34
- from langgraph_executor.pb import executor_pb2, executor_pb2_grpc, types_pb2
35
- from langgraph_executor.stream_utils import ExecutorStreamHandler
36
15
 
16
+ # Internal helpers
17
+ LOGGER = logging.getLogger(__name__)
18
+
19
+
20
+ def create_server(graphs: dict[str, Pregel], address: str) -> grpc.aio.Server:
21
+ graphs, subgraph_map = _load_graphs(graphs)
22
+ server = grpc.aio.server(
23
+ # Be permissive: allow client pings without active RPCs and accept intervals
24
+ # as low as 50s. Our clients still default to ~5m, but this avoids penalizing
25
+ # other, more frequent clients.
26
+ options=[
27
+ ("grpc.keepalive_permit_without_calls", 1),
28
+ ("grpc.http2.min_recv_ping_interval_without_data_ms", 50000), # 50s
29
+ ("grpc.http2.max_ping_strikes", 2),
30
+ ]
31
+ )
32
+ getter = functools.partial(get_graph, graphs=graphs)
33
+ add_LangGraphExecutorServicer_to_server(
34
+ LangGraphExecutorServicer(graphs, subgraph_map=subgraph_map, get_graph=getter),
35
+ server,
36
+ )
37
+ server.add_insecure_port(address)
38
+ return server
37
39
 
38
- class LangGraphExecutorServicer(executor_pb2_grpc.LangGraphExecutorServicer):
39
- """gRPC servicer for LangGraph runtime execution operations."""
40
40
 
41
- def __init__(self, graphs: dict[str, Pregel]):
42
- """Initialize the servicer with compiled graphs.
41
+ @contextlib.asynccontextmanager
42
+ async def get_graph(graph_name: str, config: Any, *, graphs: dict[str, Pregel]):
43
+ yield graphs[graph_name]
43
44
 
44
- Args:
45
- graphs: Dictionary mapping graph names to compiled graphs
46
45
 
47
- """
48
- self.graphs = graphs
49
- self.logger = logging.getLogger(__name__)
46
+ def _load_graphs(graphs: dict[str, Pregel]) -> tuple[dict[str, Pregel], dict[str, str]]:
47
+ """Load graphs and their subgraphs recursively in hierarchical order.
50
48
 
51
- def ListGraphs(self, request: Any, context: Any) -> executor_pb2.ListGraphsResponse: # type: ignore[name-defined]
52
- """List available graphs."""
53
- return executor_pb2.ListGraphsResponse(
54
- graph_names=list(self.graphs.keys()),
55
- )
49
+ Args:
50
+ graphs: Dictionary of root graphs to load
51
+ """
52
+ # First, ensure all root graphs have unique names
53
+ _ensure_unique_root_names(graphs)
54
+ subgraph_map: dict[str, str] = {}
56
55
 
57
- def GetGraph(self, request: Any, context: Any) -> executor_pb2.GetGraphResponse: # type: ignore[name-defined]
58
- """Get graph definition."""
59
- try:
60
- self.logger.debug("GetGraph called")
56
+ # Then, collect all subgraphs and mappings
57
+ all_subgraphs: dict[str, Pregel] = {}
58
+ subgraph_to_parent: dict[str, str] = {}
61
59
 
62
- graph = self.graphs[request.graph_name]
60
+ for root_graph in graphs.values():
61
+ subgraphs, mappings = _collect_subgraphs(root_graph, root_graph.name)
62
+ all_subgraphs.update(subgraphs)
63
+ subgraph_to_parent.update(mappings)
63
64
 
64
- # extract graph
65
- graph_definition = extract_graph(graph)
65
+ subgraph_map.update(subgraph_to_parent)
66
66
 
67
- return executor_pb2.GetGraphResponse(graph_definition=graph_definition)
67
+ # Now build self.graphs in hierarchical order (parents before children)
68
+ for root_name in sorted(graphs.keys()):
69
+ _load_graph_and_children(
70
+ root_name, graphs, {**graphs, **all_subgraphs}, subgraph_map
71
+ )
68
72
 
69
- except Exception as e:
70
- self.logger.error(f"GetGraph Error: {e}", exc_info=True)
71
- context.abort(grpc.StatusCode.INTERNAL, str(e))
73
+ _log_supported_graphs(graphs, subgraph_map)
74
+ return graphs, subgraph_map
72
75
 
73
- def ChannelsFromCheckpoint(
74
- self, request: Any, context: Any
75
- ) -> executor_pb2.ChannelsFromCheckpointResponse: # type: ignore[name-defined]
76
- try:
77
- self.logger.debug("ChannelsFromCheckpoint called")
78
76
 
79
- graph = get_graph(request.graph_name, self.graphs)
77
+ def _ensure_unique_root_names(graphs: dict[str, Pregel]) -> None:
78
+ """Ensure all root graphs have unique names"""
79
+ seen_names = set()
80
80
 
81
- # reconstruct specs
82
- specs, _ = reconstruct_channels(
83
- request.specs.channels,
84
- graph,
85
- scratchpad=None, # type: ignore[invalid-arg-type]
81
+ for name in graphs:
82
+ if name in seen_names:
83
+ raise ValueError(
84
+ f"Root graph name conflict detected: {name}. Root graphs must have unique names"
86
85
  )
86
+ seen_names.add(name)
87
87
 
88
- # initialize channels from specs and checkpoint channel values
89
- checkpoint_dummy = Checkpoint( # type: ignore[typeddict-item]
90
- channel_values={
91
- k: pb_to_val(v)
92
- for k, v in request.checkpoint_channel_values.items()
93
- },
94
- )
95
- channels, _ = channels_from_checkpoint(specs, checkpoint_dummy)
96
88
 
97
- # channels to pb
98
- channels = extract_channels(channels)
89
+ def _collect_subgraphs(
90
+ graph: Pregel, namespace: str
91
+ ) -> tuple[dict[str, Pregel], dict[str, str]]:
92
+ """Recursively collect all subgraphs from a root graph"""
93
+ subgraphs = {}
94
+ mappings = {}
99
95
 
100
- return executor_pb2.ChannelsFromCheckpointResponse(channels=channels)
96
+ for idx, (node_name, subgraph) in enumerate(graph.get_subgraphs(recurse=False)):
97
+ # Generate subgraph name
98
+ subgraph.name = f"{namespace}{NS_SEP}{node_name}{NS_SEP}{idx}"
101
99
 
102
- except Exception as e:
103
- self.logger.error(f"ChannelsFromCheckpoint Error: {e}", exc_info=True)
104
- context.abort(grpc.StatusCode.INTERNAL, str(e))
100
+ # Add this subgraph
101
+ subgraphs[subgraph.name] = subgraph
102
+ mappings[subgraph.name] = graph.name
105
103
 
106
- def ExecuteTask(
107
- self,
108
- request_iterator: Iterator[executor_pb2.ExecuteTaskRequest], # type: ignore[name-defined]
109
- context: Any,
110
- ) -> Iterator[executor_pb2.ExecuteTaskResponse]: # type: ignore[name-defined]
111
- self.logger.debug("ExecuteTask called")
112
- _patch_specific_base_message()
104
+ # Recursively process this subgraph's children
105
+ nested_subgraphs, nested_mappings = _collect_subgraphs(subgraph, namespace)
113
106
 
114
- # Right now, only handle task execution without interrupts, etc
115
- try:
116
- request = get_init_request(request_iterator)
107
+ subgraphs.update(nested_subgraphs)
108
+ mappings.update(nested_mappings)
117
109
 
118
- # Reconstruct PregelExecutableTask
119
- graph = get_graph(request.graph_name, self.graphs)
120
- stream_messages = "messages" in request.stream_modes
121
- stream_custom = "custom" in request.stream_modes
110
+ return subgraphs, mappings
122
111
 
123
- stream_chunks = []
124
112
 
125
- custom_stream_writer = (
126
- self._create_custom_stream_writer(stream_chunks)
127
- if stream_custom
128
- else None
129
- )
113
+ def _load_graph_and_children(
114
+ graph_name: str,
115
+ graphs: dict[str, Pregel],
116
+ all_graphs: dict[str, Pregel],
117
+ subgraph_map: dict[str, str],
118
+ ) -> None:
119
+ """Recursively add a graph and its children to self.graphs in order"""
130
120
 
131
- task = reconstruct_task(
132
- request, graph, custom_stream_writer=custom_stream_writer
133
- )
134
- if stream_messages:
135
-
136
- def stream_callback(message: BaseMessageChunk, metadata: dict):
137
- """Callback to capture stream chunks and queue them."""
138
- try:
139
- stream_chunks.append(
140
- executor_pb2.ExecuteTaskResponse(
141
- message_or_message_chunk=extract_output_message(message)
142
- )
143
- )
144
- except Exception as e:
145
- self.logger.warning(
146
- f"Failed to create stream chunk: {e}", exc_info=True
147
- )
148
-
149
- # Create and inject callback handler
150
- stream_handler = ExecutorStreamHandler(stream_callback, task.id)
151
-
152
- # Add handler to task config callbacks
153
- if "callbacks" not in task.config:
154
- task.config["callbacks"] = []
155
- task.config["callbacks"].append(stream_handler) # type: ignore[union-attr]
156
-
157
- # Execute task, catching interrupts
158
- # Check cache if task has cache key - send request to Go orchestrator
159
- should_execute = True
160
- if task.cache_key:
161
- self.logger.debug(
162
- f"Task {task.id} has cache key, sending cache check request to Go",
163
- )
164
-
165
- # Send cache check request to Go runtime
166
- cache_check_request = executor_pb2.CacheCheckRequest(
167
- cache_namespace=list(task.cache_key.ns),
168
- cache_key=task.cache_key.key,
169
- ttl=task.cache_key.ttl,
170
- )
171
-
172
- yield executor_pb2.ExecuteTaskResponse(
173
- cache_check_request=cache_check_request,
174
- )
175
-
176
- # Wait for Go's response via the bidirectional stream
177
- try:
178
- cache_response_request = next(request_iterator)
179
- if hasattr(cache_response_request, "cache_check_response"):
180
- cache_response = cache_response_request.cache_check_response
181
- should_execute = not cache_response.cache_hit
182
- self.logger.debug(
183
- f"Received cache response for task {task.id}: cache_hit={cache_response.cache_hit}",
184
- )
185
- else:
186
- self.logger.warning(
187
- f"Expected cache_check_response for task {task.id}, got unexpected message type",
188
- )
189
- should_execute = (
190
- True # Default to execution if unexpected response
191
- )
192
- except StopIteration:
193
- self.logger.warning(
194
- f"No cache response received for task {task.id}, defaulting to execution",
195
- )
196
- should_execute = True # Default to execution if no response
197
-
198
- # TODO patch retry policy
199
- # TODO configurable to deal with _call and the functional api
200
-
201
- exception_pb = None
202
- if not should_execute:
203
- # Skip execution but still send response
204
- pass
205
- try:
206
- run_with_retry(
207
- task,
208
- retry_policy=None,
209
- )
210
- # Yield any accumulated stream chunks
211
- yield from stream_chunks
212
-
213
- except Exception as e:
214
- if isinstance(e, GraphBubbleUp | GraphInterrupt):
215
- self.logger.info(f"Interrupt in task {task.id}: {e}")
216
- else:
217
- self.logger.exception(
218
- f"Exception running task {task.id}: {e}\nTask: {task}\n\n",
219
- exc_info=True,
220
- )
221
- exception_pb = exception_to_pb(e)
222
-
223
- # Send final messages via message_chunk if they exist
224
- final_messages = extract_output_messages(task.writes)
225
- if final_messages:
226
- for message in final_messages:
227
- yield executor_pb2.ExecuteTaskResponse(
228
- message_or_message_chunk=message
229
- )
230
-
231
- # Extract and yield channel writes
232
- writes_pb = extract_writes(task.writes)
233
- task_result_pb = (
234
- executor_pb2.TaskResult(error=exception_pb, writes=writes_pb)
235
- if exception_pb
236
- else executor_pb2.TaskResult(writes=writes_pb)
237
- )
121
+ # Add this graph to self.graphs (maintaining insertion order)
122
+ graphs[graph_name] = all_graphs[graph_name]
238
123
 
239
- yield executor_pb2.ExecuteTaskResponse(task_result=task_result_pb)
240
-
241
- # Generate streaming chunks
242
- # for chunk in output_writes(task, request):
243
- # yield executor_pb2.ExecuteTaskResponse(stream_chunk=chunk)
244
-
245
- except Exception as e:
246
- self.logger.exception(f"ExecuteTask error: {e}")
247
- context.abort(grpc.StatusCode.INTERNAL, str(e))
248
-
249
- def ApplyWrites(
250
- self, request: Any, context: Any
251
- ) -> executor_pb2.ApplyWritesResponse: # type: ignore[name-defined]
252
- # get graph
253
- self.logger.debug("ApplyWrites called")
254
- try:
255
- # Reconstruct python objects from proto
256
- graph = get_graph(request.graph_name, self.graphs)
257
- channels, _ = reconstruct_channels(
258
- request.channels.channels,
259
- graph,
260
- # TODO: figure this out
261
- scratchpad=None, # type: ignore[invalid-arg-type]
262
- )
263
- checkpoint = reconstruct_checkpoint(request.checkpoint)
264
- tasks = reconstruct_task_writes(request.tasks)
265
-
266
- # apply writes
267
- updated_channel_names_set = apply_writes(
268
- checkpoint,
269
- channels,
270
- tasks,
271
- lambda *args: request.next_version,
272
- graph.trigger_to_nodes,
273
- )
274
- updated_channel_names = list(updated_channel_names_set)
275
-
276
- # Reconstruct protos
277
- updated_channels = extract_channels(channels)
278
- checkpoint_proto = checkpoint_to_proto(checkpoint)
279
-
280
- # Respond with updates
281
- return executor_pb2.ApplyWritesResponse(
282
- updates=updates_to_proto(
283
- checkpoint_proto,
284
- updated_channel_names,
285
- updated_channels,
286
- ),
287
- )
124
+ # Get direct children of this graph
125
+ children = [
126
+ child_name
127
+ for child_name, parent_name in subgraph_map.items()
128
+ if parent_name == graph_name
129
+ ]
288
130
 
289
- except Exception as e:
290
- self.logger.exception(f"ApplyWrites error: {e}")
291
- context.abort(grpc.StatusCode.INTERNAL, str(e))
292
-
293
- def _create_custom_stream_writer(self, stream_chunks):
294
- """Create a proper stream_writer function for custom mode (like langgraph does)."""
295
- from google.protobuf.struct_pb2 import Struct # type: ignore[unresolved-import]
296
-
297
- def stream_writer(content):
298
- """Custom stream writer that creates CustomStreamEvent messages."""
299
- try:
300
- # Create payload struct (like langgraph does)
301
- payload = Struct()
302
- if isinstance(content, str):
303
- payload.update({"content": content})
304
- elif isinstance(content, dict):
305
- payload.update(content)
306
- else:
307
- payload.update({"content": str(content)})
308
-
309
- # Create CustomStreamEvent
310
- custom_event = executor_pb2.CustomStreamEvent(payload=payload)
311
- custom_event_response = executor_pb2.ExecuteTaskResponse(
312
- custom_stream_event=custom_event
313
- )
314
- stream_chunks.append(custom_event_response)
315
-
316
- except Exception as e:
317
- self.logger.warning(
318
- f"Failed to create custom stream event: {e}", exc_info=True
319
- )
320
-
321
- return stream_writer
322
-
323
-
324
- def extract_output_messages(writes: Sequence[Any]) -> list[types_pb2.Message]: # type: ignore[name-defined]
325
- messages = []
326
- for write in writes:
327
- # Not sure this check is right
328
- if isinstance(write[1], BaseMessage):
329
- messages.append(extract_output_message(write[1]))
330
- elif isinstance(write[1], Sequence):
331
- messages.extend(
332
- [
333
- extract_output_message(w)
334
- for w in write[1]
335
- if isinstance(w, BaseMessage)
336
- ]
337
- )
131
+ # Add children in sorted order (for deterministic output)
132
+ for child_name in sorted(children):
133
+ _load_graph_and_children(child_name, graphs, all_graphs, subgraph_map)
338
134
 
339
- return messages
340
-
341
-
342
- def extract_output_message(write: Any) -> types_pb2.Message: # type: ignore[name-defined]
343
- message = Struct()
344
- message.update(
345
- {
346
- "is_streaming_chunk": False,
347
- "message": {
348
- "id": getattr(write, "id", None) or uuid.uuid4().hex,
349
- "type": getattr(write, "type", None),
350
- "content": str(getattr(write, "content", "") or ""),
351
- "additional_kwargs": getattr(write, "additional_kwargs", {}),
352
- "usage_metadata": getattr(write, "usage_metadata", {}),
353
- "tool_calls": getattr(write, "tool_calls", []),
354
- "tool_call_id": getattr(write, "tool_call_id", ""),
355
- "tool_call_chunks": getattr(write, "tool_call_chunks", []),
356
- "response_metadata": getattr(write, "response_metadata", {}),
357
- },
358
- "metadata": {},
359
- }
360
- )
361
- return types_pb2.Message(payload=message)
362
135
 
136
+ def _log_supported_graphs(
137
+ graphs: dict[str, Pregel], subgraph_map: dict[str, str]
138
+ ) -> None:
139
+ """Log the complete graph hierarchy in a tree-like format."""
140
+ LOGGER.info("Loaded graphs:")
141
+
142
+ # Get root graphs
143
+ root_graphs = {name for name in graphs if name not in subgraph_map}
363
144
 
364
- @lru_cache(maxsize=1)
365
- def _patch_specific_base_message() -> None:
366
- """Patch the specific BaseMessage class used in your system."""
367
- from langchain_core.messages import BaseMessage
145
+ for root_name in sorted(root_graphs):
146
+ LOGGER.info(f" {root_name}")
147
+ _log_graph_children(root_name, subgraph_map, indent=2)
368
148
 
369
- original_init = BaseMessage.__init__
370
149
 
371
- def patched_init(self, content: Any, **kwargs: Any) -> None:
372
- original_init(self, content, **kwargs)
373
- if self.id is None:
374
- self.id = str(uuid.uuid4())
150
+ def _log_graph_children(
151
+ parent_name: str, subgraph_map: dict[str, str], *, indent: int = 0
152
+ ) -> None:
153
+ """Recursively log children of a graph with proper indentation."""
154
+ children = [
155
+ child for child, parent in subgraph_map.items() if parent == parent_name
156
+ ]
375
157
 
376
- BaseMessage.__init__ = patched_init # type: ignore[method-assign]
158
+ for child in sorted(children):
159
+ prefix = " " * indent + "└─ "
160
+ LOGGER.info(f"{prefix}{child}")
161
+ # Recursively log this child's children
162
+ _log_graph_children(child, subgraph_map, indent=indent + 1)