langgraph_executor-0.0.1a0-py3-none-any.whl → langgraph_executor-0.0.1a2-py3-none-any.whl

This diff shows the content of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
@@ -1 +1 @@
-__version__ = "0.0.1a0"
+__version__ = "0.0.1a2"
@@ -43,21 +43,13 @@ from langgraph_executor.common import (
 from langgraph_executor.pb import types_pb2
 
 
-def get_init_request(request_iterator):
-    request = next(request_iterator)
-
-    if not hasattr(request, "init"):
-        raise ValueError("First message must be init")
-
-    return request.init
-
-
 def reconstruct_task(
     request,
     graph: Pregel,
     *,
     store: BaseStore | None = None,
     config: RunnableConfig | None = None,
+    custom_stream_writer=None,
 ) -> PregelExecutableTask:
     pb_task = request.task
 
@@ -91,7 +83,9 @@ def reconstruct_task(
     val = pb_to_val(pb_task.input["PUSH_INPUT"])
 
     writes = deque()
-    runtime = ensure_runtime(configurable, store, graph)
+    runtime = ensure_runtime(
+        configurable, store, graph, custom_stream_writer=custom_stream_writer
+    )
 
     # Generate cache key if cache policy exists
     cache_policy = getattr(proc, "cache_policy", None)
@@ -184,15 +178,31 @@ def create_scratchpad(
     return scratchpad
 
 
-def ensure_runtime(configurable, store, graph):
+def ensure_runtime(configurable, store, graph, custom_stream_writer=None):
     runtime = configurable.get(CONFIG_KEY_RUNTIME)
+
+    # Prepare runtime overrides
+    overrides = {"store": store}
+    if custom_stream_writer is not None:
+        overrides["stream_writer"] = custom_stream_writer
+
     if runtime is None:
-        return DEFAULT_RUNTIME.override(store=store)
+        return DEFAULT_RUNTIME.override(**overrides)
     if isinstance(runtime, Runtime):
-        return runtime.override(store=store)
+        return runtime.override(**overrides)
     if isinstance(runtime, dict):
         context = _coerce_context(graph, runtime.get("context"))
-        return Runtime(**(runtime | {"store": store, "context": context}))
+        return Runtime(
+            **(
+                runtime
+                | {"store": store, "context": context}
+                | (
+                    {"stream_writer": custom_stream_writer}
+                    if custom_stream_writer
+                    else {}
+                )
+            )
+        )
     raise ValueError("Invalid runtime")
 
 
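Note on the `ensure_runtime` change: `override` previously replaced only `store`; it now builds an `overrides` dict so a custom stream writer can be threaded through as well. A minimal sketch of those override semantics, using a hypothetical frozen dataclass in place of langgraph's real `Runtime` (all names below are illustrative, not part of the package):

```python
from dataclasses import dataclass, replace
from typing import Any, Callable


def _default_writer(_: Any) -> None:
    """No-op writer standing in for the default stream writer."""


@dataclass(frozen=True)
class FakeRuntime:
    store: Any = None
    stream_writer: Callable[[Any], None] = _default_writer

    def override(self, **overrides: Any) -> "FakeRuntime":
        # Mirrors the override pattern above: copy self, replacing only given fields.
        return replace(self, **overrides)


base = FakeRuntime()
# No custom writer: only the store is overridden, the writer is untouched.
assert base.override(store="my-store").stream_writer is _default_writer
# With a custom writer, both fields are replaced, as in the diff's `overrides` dict.
patched = base.override(store="my-store", stream_writer=print)
patched.stream_writer("chunk")  # prints: chunk
```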
@@ -1,341 +1,162 @@
+import contextlib
+import functools
 import logging
-import uuid
-from collections.abc import Iterator, Sequence
-from functools import lru_cache
 from typing import Any
 
 import grpc
-from google.protobuf.struct_pb2 import Struct  # type: ignore[import-untyped]
-from langchain_core.messages import BaseMessage, BaseMessageChunk
-from langgraph.checkpoint.base import Checkpoint
-from langgraph.errors import GraphBubbleUp, GraphInterrupt
+import grpc.aio
+from langgraph._internal._constants import NS_SEP
 from langgraph.pregel import Pregel
-from langgraph.pregel._algo import apply_writes
-from langgraph.pregel._checkpoint import channels_from_checkpoint
-from langgraph.pregel._retry import run_with_retry
-
-from langgraph_executor.common import (
-    checkpoint_to_proto,
-    exception_to_pb,
-    extract_channels,
-    get_graph,
-    pb_to_val,
-    reconstruct_channels,
-    reconstruct_checkpoint,
-    reconstruct_task_writes,
-    updates_to_proto,
-)
-from langgraph_executor.execute_task import (
-    extract_writes,
-    get_init_request,
-    reconstruct_task,
+
+from langgraph_executor.executor_base import LangGraphExecutorServicer
+from langgraph_executor.pb.executor_pb2_grpc import (
+    add_LangGraphExecutorServicer_to_server,
 )
-from langgraph_executor.extract_graph import extract_graph
-from langgraph_executor.pb import executor_pb2, executor_pb2_grpc, types_pb2
-from langgraph_executor.stream_utils import ExecutorStreamHandler
 
+# Internal helpers
+LOGGER = logging.getLogger(__name__)
+
+
+def create_server(graphs: dict[str, Pregel], address: str) -> grpc.aio.Server:
+    graphs, subgraph_map = _load_graphs(graphs)
+    server = grpc.aio.server(
+        # Be permissive: allow client pings without active RPCs and accept intervals
+        # as low as 50s. Our clients still default to ~5m, but this avoids penalizing
+        # other, more frequent clients.
+        options=[
+            ("grpc.keepalive_permit_without_calls", 1),
+            ("grpc.http2.min_recv_ping_interval_without_data_ms", 50000),  # 50s
+            ("grpc.http2.max_ping_strikes", 2),
+        ]
+    )
+    getter = functools.partial(get_graph, graphs=graphs)
+    add_LangGraphExecutorServicer_to_server(
+        LangGraphExecutorServicer(graphs, subgraph_map=subgraph_map, get_graph=getter),
+        server,
+    )
+    server.add_insecure_port(address)
+    return server
 
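For orientation, `create_server` returns a configured but unstarted `grpc.aio.Server`, so startup and shutdown stay with the caller. A usage sketch under assumptions: the import path `langgraph_executor.server` and the graph module are hypothetical, and any compiled `Pregel` graph would do:

```python
import asyncio

from langgraph_executor.server import create_server  # import path assumed

from my_app.graphs import my_compiled_graph  # hypothetical compiled Pregel graph


async def main() -> None:
    server = create_server({"my_graph": my_compiled_graph}, "[::]:50051")
    await server.start()  # binds the insecure port added by create_server
    await server.wait_for_termination()


if __name__ == "__main__":
    asyncio.run(main())
```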
-class LangGraphExecutorServicer(executor_pb2_grpc.LangGraphExecutorServicer):
-    """gRPC servicer for LangGraph runtime execution operations."""
 
-    def __init__(self, graphs: dict[str, Pregel]):
-        """Initialize the servicer with compiled graphs.
+@contextlib.asynccontextmanager
+async def get_graph(graph_name: str, config: Any, *, graphs: dict[str, Pregel]):
+    yield graphs[graph_name]
 
-        Args:
-            graphs: Dictionary mapping graph names to compiled graphs
 
-        """
-        self.graphs = graphs
-        self.logger = logging.getLogger(__name__)
+def _load_graphs(graphs: dict[str, Pregel]) -> tuple[dict[str, Pregel], dict[str, str]]:
+    """Load graphs and their subgraphs recursively in hierarchical order.
 
-    def ListGraphs(self, request: Any, context: Any) -> executor_pb2.ListGraphsResponse:  # type: ignore[name-defined]
-        """List available graphs."""
-        return executor_pb2.ListGraphsResponse(
-            graph_names=list(self.graphs.keys()),
-        )
+    Args:
+        graphs: Dictionary of root graphs to load
+    """
+    # First, ensure all root graphs have unique names
+    _ensure_unique_root_names(graphs)
+    subgraph_map: dict[str, str] = {}
 
-    def GetGraph(self, request: Any, context: Any) -> executor_pb2.GetGraphResponse:  # type: ignore[name-defined]
-        """Get graph definition."""
-        try:
-            self.logger.debug("GetGraph called")
+    # Then, collect all subgraphs and mappings
+    all_subgraphs: dict[str, Pregel] = {}
+    subgraph_to_parent: dict[str, str] = {}
 
-            graph = self.graphs[request.graph_name]
+    for root_graph in graphs.values():
+        subgraphs, mappings = _collect_subgraphs(root_graph, root_graph.name)
+        all_subgraphs.update(subgraphs)
+        subgraph_to_parent.update(mappings)
 
-            # extract graph
-            graph_definition = extract_graph(graph)
+    subgraph_map.update(subgraph_to_parent)
 
-            return executor_pb2.GetGraphResponse(graph_definition=graph_definition)
+    # Now build self.graphs in hierarchical order (parents before children)
+    for root_name in sorted(graphs.keys()):
+        _load_graph_and_children(
+            root_name, graphs, {**graphs, **all_subgraphs}, subgraph_map
+        )
 
-        except Exception as e:
-            self.logger.error(f"GetGraph Error: {e}", exc_info=True)
-            context.abort(grpc.StatusCode.INTERNAL, str(e))
+    _log_supported_graphs(graphs, subgraph_map)
+    return graphs, subgraph_map
 
-    def ChannelsFromCheckpoint(
-        self, request: Any, context: Any
-    ) -> executor_pb2.ChannelsFromCheckpointResponse:  # type: ignore[name-defined]
-        try:
-            self.logger.debug("ChannelsFromCheckpoint called")
 
-            graph = get_graph(request.graph_name, self.graphs)
+def _ensure_unique_root_names(graphs: dict[str, Pregel]) -> None:
+    """Ensure all root graphs have unique names"""
+    seen_names = set()
 
-            # reconstruct specs
-            specs, _ = reconstruct_channels(
-                request.specs.channels,
-                graph,
-                scratchpad=None,  # type: ignore[invalid-arg-type]
+    for name in graphs:
+        if name in seen_names:
+            raise ValueError(
+                f"Root graph name conflict detected: {name}. Root graphs must have unique names"
             )
+        seen_names.add(name)
 
-            # initialize channels from specs and checkpoint channel values
-            checkpoint_dummy = Checkpoint(  # type: ignore[typeddict-item]
-                channel_values={
-                    k: pb_to_val(v)
-                    for k, v in request.checkpoint_channel_values.items()
-                },
-            )
-            channels, _ = channels_from_checkpoint(specs, checkpoint_dummy)
-
-            # channels to pb
-            channels = extract_channels(channels)
-
-            return executor_pb2.ChannelsFromCheckpointResponse(channels=channels)
-
-        except Exception as e:
-            self.logger.error(f"ChannelsFromCheckpoint Error: {e}", exc_info=True)
-            context.abort(grpc.StatusCode.INTERNAL, str(e))
-
-    def ExecuteTask(
-        self,
-        request_iterator: Iterator[executor_pb2.ExecuteTaskRequest],  # type: ignore[name-defined]
-        context: Any,
-    ) -> Iterator[executor_pb2.ExecuteTaskResponse]:  # type: ignore[name-defined]
-        self.logger.debug("ExecuteTask called")
-        _patch_specific_base_message()
-
-        # Right now, only handle task execution without interrupts, etc
-        try:
-            # Get request
-            request = get_init_request(request_iterator)
-
-            # Reconstruct PregelExecutableTask
-            graph = get_graph(request.graph_name, self.graphs)
-            task = reconstruct_task(request, graph)
-
-            # Check if streaming is requested (for messages mode)
-            stream_messages = "messages" in request.stream_modes
-
-            # Set up streaming callback if needed
-            stream_chunks = []
-            if stream_messages:
-
-                def stream_callback(message: BaseMessageChunk, metadata: dict):
-                    """Callback to capture stream chunks and queue them."""
-                    try:
-                        stream_chunks.append(
-                            executor_pb2.ExecuteTaskResponse(
-                                message_or_message_chunk=extract_output_message(message)
-                            )
-                        )
-                    except Exception as e:
-                        self.logger.warning(
-                            f"Failed to create stream chunk: {e}", exc_info=True
-                        )
-
-                # Create and inject callback handler
-                stream_handler = ExecutorStreamHandler(stream_callback, task.id)
-
-                # Add handler to task config callbacks
-                if "callbacks" not in task.config:
-                    task.config["callbacks"] = []
-                task.config["callbacks"].append(stream_handler)  # type: ignore[union-attr]
-
-            # Execute task, catching interrupts
-            # Check cache if task has cache key - send request to Go orchestrator
-            should_execute = True
-            if task.cache_key:
-                self.logger.debug(
-                    f"Task {task.id} has cache key, sending cache check request to Go",
-                )
-
-                # Send cache check request to Go runtime
-                cache_check_request = executor_pb2.CacheCheckRequest(
-                    cache_namespace=list(task.cache_key.ns),
-                    cache_key=task.cache_key.key,
-                    ttl=task.cache_key.ttl,
-                )
-
-                yield executor_pb2.ExecuteTaskResponse(
-                    cache_check_request=cache_check_request,
-                )
-
-                # Wait for Go's response via the bidirectional stream
-                try:
-                    cache_response_request = next(request_iterator)
-                    if hasattr(cache_response_request, "cache_check_response"):
-                        cache_response = cache_response_request.cache_check_response
-                        should_execute = not cache_response.cache_hit
-                        self.logger.debug(
-                            f"Received cache response for task {task.id}: cache_hit={cache_response.cache_hit}",
-                        )
-                    else:
-                        self.logger.warning(
-                            f"Expected cache_check_response for task {task.id}, got unexpected message type",
-                        )
-                        should_execute = (
-                            True  # Default to execution if unexpected response
-                        )
-                except StopIteration:
-                    self.logger.warning(
-                        f"No cache response received for task {task.id}, defaulting to execution",
-                    )
-                    should_execute = True  # Default to execution if no response
-
-            # TODO patch retry policy
-            # TODO configurable to deal with _call and the functional api
-
-            exception_pb = None
-            if not should_execute:
-                # Skip execution but still send response
-                pass
-            try:
-                run_with_retry(
-                    task,
-                    retry_policy=None,
-                )
-
-                # Yield any accumulated stream chunks
-                yield from stream_chunks
-
-            except Exception as e:
-                if isinstance(e, GraphBubbleUp | GraphInterrupt):
-                    self.logger.info(f"Interrupt in task {task.id}: {e}")
-                else:
-                    self.logger.exception(
-                        f"Exception running task {task.id}: {e}\nTask: {task}\n\n",
-                        exc_info=True,
-                    )
-                    exception_pb = exception_to_pb(e)
-
-            # Send final messages via message_chunk if they exist
-            final_messages = extract_output_messages(task.writes)
-            if final_messages:
-                for message in final_messages:
-                    yield executor_pb2.ExecuteTaskResponse(
-                        message_or_message_chunk=message
-                    )
-
-            # Extract and yield channel writes
-            writes_pb = extract_writes(task.writes)
-            task_result_pb = (
-                executor_pb2.TaskResult(error=exception_pb, writes=writes_pb)
-                if exception_pb
-                else executor_pb2.TaskResult(writes=writes_pb)
-            )
 
-            yield executor_pb2.ExecuteTaskResponse(task_result=task_result_pb)
-
-            # Generate streaming chunks
-            # for chunk in output_writes(task, request):
-            #     yield executor_pb2.ExecuteTaskResponse(stream_chunk=chunk)
-
-        except Exception as e:
-            self.logger.exception(f"ExecuteTask error: {e}")
-            context.abort(grpc.StatusCode.INTERNAL, str(e))
-
-    def ApplyWrites(
-        self, request: Any, context: Any
-    ) -> executor_pb2.ApplyWritesResponse:  # type: ignore[name-defined]
-        # get graph
-        self.logger.debug("ApplyWrites called")
-        try:
-            # Reconstruct python objects from proto
-            graph = get_graph(request.graph_name, self.graphs)
-            channels, _ = reconstruct_channels(
-                request.channels.channels,
-                graph,
-                # TODO: figure this out
-                scratchpad=None,  # type: ignore[invalid-arg-type]
-            )
-            checkpoint = reconstruct_checkpoint(request.checkpoint)
-            tasks = reconstruct_task_writes(request.tasks)
-
-            # apply writes
-            updated_channel_names_set = apply_writes(
-                checkpoint,
-                channels,
-                tasks,
-                lambda *args: request.next_version,
-                graph.trigger_to_nodes,
-            )
-            updated_channel_names = list(updated_channel_names_set)
-
-            # Reconstruct protos
-            updated_channels = extract_channels(channels)
-            checkpoint_proto = checkpoint_to_proto(checkpoint)
-
-            # Respond with updates
-            return executor_pb2.ApplyWritesResponse(
-                updates=updates_to_proto(
-                    checkpoint_proto,
-                    updated_channel_names,
-                    updated_channels,
-                ),
-            )
+def _collect_subgraphs(
+    graph: Pregel, namespace: str
+) -> tuple[dict[str, Pregel], dict[str, str]]:
+    """Recursively collect all subgraphs from a root graph"""
+    subgraphs = {}
+    mappings = {}
 
-        except Exception as e:
-            self.logger.exception(f"ApplyWrites error: {e}")
-            context.abort(grpc.StatusCode.INTERNAL, str(e))
-
-
-def extract_output_messages(writes: Sequence[Any]) -> list[types_pb2.Message]:  # type: ignore[name-defined]
-    messages = []
-    for write in writes:
-        # Not sure this check is right
-        if isinstance(write[1], BaseMessage):
-            messages.append(extract_output_message(write[1]))
-        elif isinstance(write[1], Sequence):
-            messages.extend(
-                [
-                    extract_output_message(w)
-                    for w in write[1]
-                    if isinstance(w, BaseMessage)
-                ]
-            )
+    for idx, (node_name, subgraph) in enumerate(graph.get_subgraphs(recurse=False)):
+        # Generate subgraph name
+        subgraph.name = f"{namespace}{NS_SEP}{node_name}{NS_SEP}{idx}"
 
-    return messages
-
-
-def extract_output_message(write: Any) -> types_pb2.Message:  # type: ignore[name-defined]
-    message = Struct()
-    message.update(
-        {
-            "is_streaming_chunk": False,
-            "message": {
-                "id": getattr(write, "id", None) or uuid.uuid4().hex,
-                "type": getattr(write, "type", None),
-                "content": str(getattr(write, "content", "") or ""),
-                "additional_kwargs": getattr(write, "additional_kwargs", {}),
-                "usage_metadata": getattr(write, "usage_metadata", {}),
-                "tool_calls": getattr(write, "tool_calls", []),
-                "tool_call_id": getattr(write, "tool_call_id", ""),
-                "tool_call_chunks": getattr(write, "tool_call_chunks", []),
-                "response_metadata": getattr(write, "response_metadata", {}),
-            },
-            "metadata": {},
-        }
-    )
-    return types_pb2.Message(payload=message)
+        # Add this subgraph
+        subgraphs[subgraph.name] = subgraph
+        mappings[subgraph.name] = graph.name
+
+        # Recursively process this subgraph's children
+        nested_subgraphs, nested_mappings = _collect_subgraphs(subgraph, namespace)
+
+        subgraphs.update(nested_subgraphs)
+        mappings.update(nested_mappings)
+
+    return subgraphs, mappings
+
+
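Because `_collect_subgraphs` recurses with the root `namespace` rather than the child's full name, every subgraph key has the shape `{root}{NS_SEP}{node}{NS_SEP}{idx}`, while `mappings` still records the true parent. Assuming `NS_SEP` is `"|"` (langgraph's internal namespace separator; an assumption here), a root graph `parent` whose node `child` contains a node `grandchild` would yield roughly:

```python
# Hypothetical hierarchy; idx is the position among each graph's direct subgraphs.
subgraphs = {
    "parent|child|0": "<Pregel for child>",
    "parent|grandchild|0": "<Pregel for grandchild>",  # name stays rooted at "parent"
}
mappings = {
    "parent|child|0": "parent",               # child's parent is the root
    "parent|grandchild|0": "parent|child|0",  # grandchild's parent is child
}
```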
+def _load_graph_and_children(
+    graph_name: str,
+    graphs: dict[str, Pregel],
+    all_graphs: dict[str, Pregel],
+    subgraph_map: dict[str, str],
+) -> None:
+    """Recursively add a graph and its children to self.graphs in order"""
+
+    # Add this graph to self.graphs (maintaining insertion order)
+    graphs[graph_name] = all_graphs[graph_name]
+
+    # Get direct children of this graph
+    children = [
+        child_name
+        for child_name, parent_name in subgraph_map.items()
+        if parent_name == graph_name
+    ]
+
+    # Add children in sorted order (for deterministic output)
+    for child_name in sorted(children):
+        _load_graph_and_children(child_name, graphs, all_graphs, subgraph_map)
+
+
+def _log_supported_graphs(
+    graphs: dict[str, Pregel], subgraph_map: dict[str, str]
+) -> None:
+    """Log the complete graph hierarchy in a tree-like format."""
+    LOGGER.info("Loaded graphs:")
 
+    # Get root graphs
+    root_graphs = {name for name in graphs if name not in subgraph_map}
 
-@lru_cache(maxsize=1)
-def _patch_specific_base_message() -> None:
-    """Patch the specific BaseMessage class used in your system."""
-    from langchain_core.messages import BaseMessage
+    for root_name in sorted(root_graphs):
+        LOGGER.info(f" {root_name}")
+        _log_graph_children(root_name, subgraph_map, indent=2)
 
-    original_init = BaseMessage.__init__
 
-    def patched_init(self, content: Any, **kwargs: Any) -> None:
-        original_init(self, content, **kwargs)
-        if self.id is None:
-            self.id = str(uuid.uuid4())
+def _log_graph_children(
+    parent_name: str, subgraph_map: dict[str, str], *, indent: int = 0
+) -> None:
+    """Recursively log children of a graph with proper indentation."""
+    children = [
+        child for child, parent in subgraph_map.items() if parent == parent_name
+    ]
 
-    BaseMessage.__init__ = patched_init  # type: ignore[method-assign]
+    for child in sorted(children):
+        prefix = " " * indent + "└─ "
+        LOGGER.info(f"{prefix}{child}")
+        # Recursively log this child's children
+        _log_graph_children(child, subgraph_map, indent=indent + 1)
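Taken together, `_log_supported_graphs` and `_log_graph_children` render the hierarchy as an indented tree, with each recursion level adding one space before the `└─` marker. For the hypothetical two-level example above (again assuming `NS_SEP` is `"|"`), the log output would look roughly like:

```
Loaded graphs:
 parent
  └─ parent|child|0
   └─ parent|grandchild|0
```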