langgraph-executor 0.0.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_executor/__init__.py +1 -0
- langgraph_executor/common.py +395 -0
- langgraph_executor/example.py +29 -0
- langgraph_executor/execute_task.py +239 -0
- langgraph_executor/executor.py +341 -0
- langgraph_executor/extract_graph.py +178 -0
- langgraph_executor/info_logger.py +111 -0
- langgraph_executor/pb/__init__.py +0 -0
- langgraph_executor/pb/executor_pb2.py +79 -0
- langgraph_executor/pb/executor_pb2.pyi +415 -0
- langgraph_executor/pb/executor_pb2_grpc.py +321 -0
- langgraph_executor/pb/executor_pb2_grpc.pyi +150 -0
- langgraph_executor/pb/graph_pb2.py +55 -0
- langgraph_executor/pb/graph_pb2.pyi +230 -0
- langgraph_executor/pb/graph_pb2_grpc.py +24 -0
- langgraph_executor/pb/graph_pb2_grpc.pyi +17 -0
- langgraph_executor/pb/runtime_pb2.py +68 -0
- langgraph_executor/pb/runtime_pb2.pyi +364 -0
- langgraph_executor/pb/runtime_pb2_grpc.py +322 -0
- langgraph_executor/pb/runtime_pb2_grpc.pyi +151 -0
- langgraph_executor/pb/types_pb2.py +144 -0
- langgraph_executor/pb/types_pb2.pyi +1044 -0
- langgraph_executor/pb/types_pb2_grpc.py +24 -0
- langgraph_executor/pb/types_pb2_grpc.pyi +17 -0
- langgraph_executor/py.typed +0 -0
- langgraph_executor/server.py +186 -0
- langgraph_executor/setup.sh +29 -0
- langgraph_executor/stream_utils.py +96 -0
- langgraph_executor-0.0.1a0.dist-info/METADATA +14 -0
- langgraph_executor-0.0.1a0.dist-info/RECORD +31 -0
- langgraph_executor-0.0.1a0.dist-info/WHEEL +4 -0
@@ -0,0 +1,341 @@
|
|
1
|
+
import logging
|
2
|
+
import uuid
|
3
|
+
from collections.abc import Iterator, Sequence
|
4
|
+
from functools import lru_cache
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
import grpc
|
8
|
+
from google.protobuf.struct_pb2 import Struct # type: ignore[import-untyped]
|
9
|
+
from langchain_core.messages import BaseMessage, BaseMessageChunk
|
10
|
+
from langgraph.checkpoint.base import Checkpoint
|
11
|
+
from langgraph.errors import GraphBubbleUp, GraphInterrupt
|
12
|
+
from langgraph.pregel import Pregel
|
13
|
+
from langgraph.pregel._algo import apply_writes
|
14
|
+
from langgraph.pregel._checkpoint import channels_from_checkpoint
|
15
|
+
from langgraph.pregel._retry import run_with_retry
|
16
|
+
|
17
|
+
from langgraph_executor.common import (
|
18
|
+
checkpoint_to_proto,
|
19
|
+
exception_to_pb,
|
20
|
+
extract_channels,
|
21
|
+
get_graph,
|
22
|
+
pb_to_val,
|
23
|
+
reconstruct_channels,
|
24
|
+
reconstruct_checkpoint,
|
25
|
+
reconstruct_task_writes,
|
26
|
+
updates_to_proto,
|
27
|
+
)
|
28
|
+
from langgraph_executor.execute_task import (
|
29
|
+
extract_writes,
|
30
|
+
get_init_request,
|
31
|
+
reconstruct_task,
|
32
|
+
)
|
33
|
+
from langgraph_executor.extract_graph import extract_graph
|
34
|
+
from langgraph_executor.pb import executor_pb2, executor_pb2_grpc, types_pb2
|
35
|
+
from langgraph_executor.stream_utils import ExecutorStreamHandler
|
36
|
+
|
37
|
+
|
38
|
+
class LangGraphExecutorServicer(executor_pb2_grpc.LangGraphExecutorServicer):
    """gRPC servicer for LangGraph runtime execution operations.

    Every RPC handler follows the same error convention: any exception is
    logged and then reported to the client through
    ``context.abort(grpc.StatusCode.INTERNAL, str(e))``.
    """

    def __init__(self, graphs: dict[str, Pregel]):
        """Initialize the servicer with compiled graphs.

        Args:
            graphs: Dictionary mapping graph names to compiled graphs

        """
        self.graphs = graphs
        self.logger = logging.getLogger(__name__)

    def ListGraphs(self, request: Any, context: Any) -> executor_pb2.ListGraphsResponse:  # type: ignore[name-defined]
        """List available graphs.

        Returns:
            A response whose ``graph_names`` are the keys of ``self.graphs``.
        """
        return executor_pb2.ListGraphsResponse(
            graph_names=list(self.graphs.keys()),
        )

    def GetGraph(self, request: Any, context: Any) -> executor_pb2.GetGraphResponse:  # type: ignore[name-defined]
        """Get graph definition.

        Looks up ``request.graph_name`` in ``self.graphs`` and returns the
        definition produced by ``extract_graph``.
        """
        try:
            self.logger.debug("GetGraph called")

            graph = self.graphs[request.graph_name]

            # extract graph
            graph_definition = extract_graph(graph)

            return executor_pb2.GetGraphResponse(graph_definition=graph_definition)

        except Exception as e:
            self.logger.error(f"GetGraph Error: {e}", exc_info=True)
            context.abort(grpc.StatusCode.INTERNAL, str(e))

    def ChannelsFromCheckpoint(
        self, request: Any, context: Any
    ) -> executor_pb2.ChannelsFromCheckpointResponse:  # type: ignore[name-defined]
        """Build channels from the request's specs and checkpoint channel values.

        The channel specs are rebuilt from ``request.specs.channels``, a
        checkpoint holding only ``channel_values`` is assembled from
        ``request.checkpoint_channel_values``, and the resulting channels are
        converted back to protobuf with ``extract_channels``.
        """
        try:
            self.logger.debug("ChannelsFromCheckpoint called")

            graph = get_graph(request.graph_name, self.graphs)

            # reconstruct specs
            specs, _ = reconstruct_channels(
                request.specs.channels,
                graph,
                scratchpad=None,  # type: ignore[invalid-arg-type]
            )

            # initialize channels from specs and checkpoint channel values
            checkpoint_dummy = Checkpoint(  # type: ignore[typeddict-item]
                channel_values={
                    k: pb_to_val(v)
                    for k, v in request.checkpoint_channel_values.items()
                },
            )
            channels, _ = channels_from_checkpoint(specs, checkpoint_dummy)

            # channels to pb
            channels = extract_channels(channels)

            return executor_pb2.ChannelsFromCheckpointResponse(channels=channels)

        except Exception as e:
            self.logger.error(f"ChannelsFromCheckpoint Error: {e}", exc_info=True)
            context.abort(grpc.StatusCode.INTERNAL, str(e))

    def ExecuteTask(
        self,
        request_iterator: Iterator[executor_pb2.ExecuteTaskRequest],  # type: ignore[name-defined]
        context: Any,
    ) -> Iterator[executor_pb2.ExecuteTaskResponse]:  # type: ignore[name-defined]
        """Execute one task over a bidirectional stream.

        Sequence of events, as implemented below:

        1. The initial request is read with ``get_init_request`` and the task
           is rebuilt with ``reconstruct_task``.
        2. If ``"messages"`` is in ``request.stream_modes``, a stream handler
           is appended to the task's callbacks; it buffers message chunks in
           memory.
        3. If the task has a cache key, a ``cache_check_request`` is yielded
           and the next message from ``request_iterator`` is read as the
           answer. A reported cache hit skips execution; a missing or
           unexpected answer defaults to executing the task.
        4. The task is run with ``run_with_retry``; buffered chunks are
           yielded after a successful run. Exceptions raised by the task are
           converted with ``exception_to_pb`` rather than aborting the RPC.
        5. Messages found in ``task.writes`` are yielded, followed by a final
           ``task_result`` response carrying the writes (and the error, if
           any).

        Yields:
            ``ExecuteTaskResponse`` messages in the order described above.
        """
        self.logger.debug("ExecuteTask called")
        _patch_specific_base_message()

        # Right now, only handle task execution without interrupts, etc
        try:
            # Get request
            request = get_init_request(request_iterator)

            # Reconstruct PregelExecutableTask
            graph = get_graph(request.graph_name, self.graphs)
            task = reconstruct_task(request, graph)

            # Check if streaming is requested (for messages mode)
            stream_messages = "messages" in request.stream_modes

            # Set up streaming callback if needed
            stream_chunks: list[Any] = []
            if stream_messages:

                def stream_callback(message: BaseMessageChunk, metadata: dict):
                    """Callback to capture stream chunks and queue them."""
                    try:
                        stream_chunks.append(
                            executor_pb2.ExecuteTaskResponse(
                                message_or_message_chunk=extract_output_message(message)
                            )
                        )
                    except Exception as e:
                        # Best effort: a chunk that cannot be converted is
                        # dropped so it does not fail the task itself.
                        self.logger.warning(
                            f"Failed to create stream chunk: {e}", exc_info=True
                        )

                # Create and inject callback handler
                stream_handler = ExecutorStreamHandler(stream_callback, task.id)

                # Add handler to task config callbacks
                if "callbacks" not in task.config:
                    task.config["callbacks"] = []
                task.config["callbacks"].append(stream_handler)  # type: ignore[union-attr]

            # Execute task, catching interrupts
            # Check cache if task has cache key - send request to Go orchestrator
            should_execute = True
            if task.cache_key:
                self.logger.debug(
                    f"Task {task.id} has cache key, sending cache check request to Go",
                )

                # Send cache check request to Go runtime
                cache_check_request = executor_pb2.CacheCheckRequest(
                    cache_namespace=list(task.cache_key.ns),
                    cache_key=task.cache_key.key,
                    ttl=task.cache_key.ttl,
                )

                yield executor_pb2.ExecuteTaskResponse(
                    cache_check_request=cache_check_request,
                )

                # Wait for Go's response via the bidirectional stream
                try:
                    cache_response_request = next(request_iterator)
                    # hasattr() is always true for a declared protobuf field,
                    # so presence has to be checked with HasField().
                    if cache_response_request.HasField("cache_check_response"):
                        cache_response = cache_response_request.cache_check_response
                        should_execute = not cache_response.cache_hit
                        self.logger.debug(
                            f"Received cache response for task {task.id}: cache_hit={cache_response.cache_hit}",
                        )
                    else:
                        self.logger.warning(
                            f"Expected cache_check_response for task {task.id}, got unexpected message type",
                        )
                        # Default to execution if unexpected response
                        should_execute = True
                except StopIteration:
                    self.logger.warning(
                        f"No cache response received for task {task.id}, defaulting to execution",
                    )
                    should_execute = True  # Default to execution if no response

            # TODO patch retry policy
            # TODO configurable to deal with _call and the functional api

            exception_pb = None
            if should_execute:
                try:
                    run_with_retry(
                        task,
                        retry_policy=None,
                    )

                    # Yield any accumulated stream chunks
                    yield from stream_chunks

                except Exception as e:
                    if isinstance(e, GraphBubbleUp | GraphInterrupt):
                        self.logger.info(f"Interrupt in task {task.id}: {e}")
                    else:
                        self.logger.exception(
                            f"Exception running task {task.id}: {e}\nTask: {task}\n\n",
                            exc_info=True,
                        )
                    exception_pb = exception_to_pb(e)
            else:
                # Cache hit: skip execution but still send the final response.
                self.logger.debug(
                    f"Skipping execution of task {task.id} due to cache hit",
                )

            # Send final messages via message_chunk if they exist
            final_messages = extract_output_messages(task.writes)
            if final_messages:
                for message in final_messages:
                    yield executor_pb2.ExecuteTaskResponse(
                        message_or_message_chunk=message
                    )

            # Extract and yield channel writes
            writes_pb = extract_writes(task.writes)
            task_result_pb = (
                executor_pb2.TaskResult(error=exception_pb, writes=writes_pb)
                if exception_pb
                else executor_pb2.TaskResult(writes=writes_pb)
            )

            yield executor_pb2.ExecuteTaskResponse(task_result=task_result_pb)

            # Generate streaming chunks
            # for chunk in output_writes(task, request):
            #     yield executor_pb2.ExecuteTaskResponse(stream_chunk=chunk)

        except Exception as e:
            self.logger.exception(f"ExecuteTask error: {e}")
            context.abort(grpc.StatusCode.INTERNAL, str(e))

    def ApplyWrites(
        self, request: Any, context: Any
    ) -> executor_pb2.ApplyWritesResponse:  # type: ignore[name-defined]
        """Apply task writes to a checkpoint and its channels.

        Channels, checkpoint and tasks are rebuilt from the request, passed to
        ``apply_writes`` (every new version is ``request.next_version``), and
        the updated checkpoint, updated channel names and channels are
        returned as protobuf via ``updates_to_proto``.
        """
        # get graph
        self.logger.debug("ApplyWrites called")
        try:
            # Reconstruct python objects from proto
            graph = get_graph(request.graph_name, self.graphs)
            channels, _ = reconstruct_channels(
                request.channels.channels,
                graph,
                # TODO: figure this out
                scratchpad=None,  # type: ignore[invalid-arg-type]
            )
            checkpoint = reconstruct_checkpoint(request.checkpoint)
            tasks = reconstruct_task_writes(request.tasks)

            # apply writes
            updated_channel_names_set = apply_writes(
                checkpoint,
                channels,
                tasks,
                lambda *args: request.next_version,
                graph.trigger_to_nodes,
            )
            updated_channel_names = list(updated_channel_names_set)

            # Reconstruct protos
            updated_channels = extract_channels(channels)
            checkpoint_proto = checkpoint_to_proto(checkpoint)

            # Respond with updates
            return executor_pb2.ApplyWritesResponse(
                updates=updates_to_proto(
                    checkpoint_proto,
                    updated_channel_names,
                    updated_channels,
                ),
            )

        except Exception as e:
            self.logger.exception(f"ApplyWrites error: {e}")
            context.abort(grpc.StatusCode.INTERNAL, str(e))
|
287
|
+
|
288
|
+
|
289
|
+
def extract_output_messages(writes: Sequence[Any]) -> list[types_pb2.Message]:  # type: ignore[name-defined]
    """Collect every ``BaseMessage`` found in the given writes as protobuf.

    Each write is indexed at position 1 to obtain its value. A value that is
    itself a ``BaseMessage`` contributes one message; a value that is a
    ``Sequence`` contributes one message per ``BaseMessage`` element (other
    elements are ignored). Any other value is skipped.

    Args:
        writes: Sequence of indexable writes whose element ``[1]`` is the value.

    Returns:
        The converted messages, in the order they were encountered.
    """
    collected = []
    for write in writes:
        value = write[1]
        if isinstance(value, BaseMessage):
            found = [value]
        elif isinstance(value, Sequence):
            # Strings are Sequences too, but their characters are never
            # BaseMessage instances, so they contribute nothing.
            found = [item for item in value if isinstance(item, BaseMessage)]
        else:
            continue
        for item in found:
            collected.append(extract_output_message(item))

    return collected
|
305
|
+
|
306
|
+
|
307
|
+
def extract_output_message(write: Any) -> types_pb2.Message:  # type: ignore[name-defined]
    """Convert a message-like object into a ``types_pb2.Message``.

    Attributes are read with ``getattr`` so missing ones fall back to a
    default. A missing or falsy ``id`` is replaced by a fresh UUID hex string,
    and ``content`` is always stringified (falsy content becomes ``""``).

    Args:
        write: Object exposing (some of) the message attributes read below.

    Returns:
        A ``types_pb2.Message`` whose ``payload`` is a ``Struct`` with the keys
        ``is_streaming_chunk``, ``message`` and ``metadata``.
    """
    message_fields = {
        "id": getattr(write, "id", None) or uuid.uuid4().hex,
        "type": getattr(write, "type", None),
        "content": str(getattr(write, "content", "") or ""),
    }
    # Remaining attributes are copied as-is, each with its own fallback.
    passthrough_defaults = (
        ("additional_kwargs", {}),
        ("usage_metadata", {}),
        ("tool_calls", []),
        ("tool_call_id", ""),
        ("tool_call_chunks", []),
        ("response_metadata", {}),
    )
    for attr_name, fallback in passthrough_defaults:
        message_fields[attr_name] = getattr(write, attr_name, fallback)

    payload = Struct()
    payload.update(
        {
            "is_streaming_chunk": False,
            "message": message_fields,
            "metadata": {},
        }
    )
    return types_pb2.Message(payload=payload)
|
327
|
+
|
328
|
+
|
329
|
+
@lru_cache(maxsize=1)
def _patch_specific_base_message() -> None:
    """Monkey-patch ``BaseMessage.__init__`` to assign an id when none is set.

    After the original initializer runs, a message whose ``id`` is ``None``
    receives ``str(uuid.uuid4())``. Because the function takes no arguments,
    ``lru_cache`` turns repeated calls into cache hits, so the patch is not
    stacked on top of itself.
    """
    from langchain_core.messages import BaseMessage

    wrapped_init = BaseMessage.__init__

    def patched_init(self, content: Any, **kwargs: Any) -> None:
        wrapped_init(self, content, **kwargs)
        if self.id is not None:
            return
        self.id = str(uuid.uuid4())

    BaseMessage.__init__ = patched_init  # type: ignore[method-assign]
|
@@ -0,0 +1,178 @@
|
|
1
|
+
"""Shared module for extracting graph information from LangGraph graphs."""
|
2
|
+
|
3
|
+
from collections.abc import Sequence
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
from google.protobuf.json_format import MessageToJson
|
7
|
+
from google.protobuf.struct_pb2 import Struct # type: ignore[import-not-found]
|
8
|
+
from langchain_core.runnables import RunnableConfig
|
9
|
+
from langgraph._internal._constants import ( # CONFIG_KEY_PREVIOUS,
|
10
|
+
CONFIG_KEY_CHECKPOINT_ID,
|
11
|
+
CONFIG_KEY_CHECKPOINT_MAP,
|
12
|
+
CONFIG_KEY_CHECKPOINT_NS,
|
13
|
+
CONFIG_KEY_DURABILITY,
|
14
|
+
CONFIG_KEY_RESUMING,
|
15
|
+
CONFIG_KEY_TASK_ID,
|
16
|
+
CONFIG_KEY_THREAD_ID,
|
17
|
+
RESERVED,
|
18
|
+
)
|
19
|
+
from langgraph.cache.memory import InMemoryCache
|
20
|
+
from langgraph.pregel import Pregel
|
21
|
+
from langgraph.pregel._read import PregelNode
|
22
|
+
from langgraph.utils.config import ensure_config
|
23
|
+
|
24
|
+
from langgraph_executor.common import extract_channels
|
25
|
+
from langgraph_executor.pb import graph_pb2, types_pb2
|
26
|
+
|
27
|
+
# Fallback `max_concurrency` reported by `extract_config` when the ensured
# config has no "max_concurrency" entry.
DEFAULT_MAX_CONCURRENCY = 1
|
28
|
+
|
29
|
+
|
30
|
+
def extract_cache_type(cache: Any) -> str:
    """Return the cache type label for a cache object.

    Returns:
        ``"inMemory"`` for an ``InMemoryCache`` instance; ``"unsupported"``
        for ``None`` and for every other object.
    """
    if cache is not None and isinstance(cache, InMemoryCache):
        return "inMemory"
    return "unsupported"
|
37
|
+
|
38
|
+
|
39
|
+
def extract_config(config: RunnableConfig) -> types_pb2.RunnableConfig:
    """Convert a ``RunnableConfig`` into its protobuf representation.

    The config is first passed through ``ensure_config``. Metadata is copied
    in full; ``configurable`` entries whose key is in ``RESERVED`` are left
    out of the generic struct and reported separately through
    ``extract_reserved_configurable``, which reads the ``configurable``
    mapping of the *original* ``config`` argument.

    Args:
        config: The runnable config to convert.

    Returns:
        The populated ``types_pb2.RunnableConfig``.
    """
    full_config = ensure_config(config)

    # metadata
    metadata_struct = Struct()
    metadata_struct.update(dict(full_config["metadata"].items()))

    # configurable (reserved keys excluded)
    configurable_struct = Struct()
    configurable_struct.update(
        {
            key: value
            for key, value in full_config["configurable"].items()
            if key not in RESERVED
        }
    )

    reserved = extract_reserved_configurable(config.get("configurable", {}))
    max_concurrency = full_config.get("max_concurrency", DEFAULT_MAX_CONCURRENCY)

    return types_pb2.RunnableConfig(
        tags=list(full_config["tags"]),
        recursion_limit=full_config["recursion_limit"],
        run_name=full_config.get("run_name", ""),
        max_concurrency=max_concurrency,
        metadata=metadata_struct,
        configurable=configurable_struct,
        reserved_configurable=reserved,
    )
|
67
|
+
|
68
|
+
|
69
|
+
def extract_reserved_configurable(
    configurable: dict[str, Any],
) -> types_pb2.ReservedConfigurable:
    """Build the ``ReservedConfigurable`` message from a ``configurable`` dict.

    Each reserved key is looked up with a default, so missing keys yield
    ``False`` (resuming), ``""`` (the id and namespace strings), ``{}``
    (checkpoint map) or ``"async"`` (durability).

    Args:
        configurable: The ``configurable`` mapping of a runnable config.

    Returns:
        The populated ``types_pb2.ReservedConfigurable``.
    """
    lookup = configurable.get

    resuming = bool(lookup(CONFIG_KEY_RESUMING, False))
    task_id = str(lookup(CONFIG_KEY_TASK_ID, ""))
    thread_id = str(lookup(CONFIG_KEY_THREAD_ID, ""))
    checkpoint_map = dict(lookup(CONFIG_KEY_CHECKPOINT_MAP, {}))
    checkpoint_id = str(lookup(CONFIG_KEY_CHECKPOINT_ID, ""))
    checkpoint_ns = str(lookup(CONFIG_KEY_CHECKPOINT_NS, ""))
    durability = lookup(CONFIG_KEY_DURABILITY, "async")

    return types_pb2.ReservedConfigurable(
        resuming=resuming,
        task_id=task_id,
        thread_id=thread_id,
        checkpoint_map=checkpoint_map,
        checkpoint_id=checkpoint_id,
        checkpoint_ns=checkpoint_ns,
        durability=durability,
    )
|
81
|
+
|
82
|
+
|
83
|
+
def extract_nodes(nodes: dict[str, PregelNode]) -> dict[str, graph_pb2.NodeDefinition]:
    """Convert every node to a ``NodeDefinition``, keyed by node name."""
    return {
        node_name: extract_node(node_name, node)
        for node_name, node in nodes.items()
    }
|
88
|
+
|
89
|
+
|
90
|
+
def extract_node(name: str, node: PregelNode) -> graph_pb2.NodeDefinition:
    """Convert a single ``PregelNode`` into a ``NodeDefinition`` message.

    ``node.channels`` is normalized to a list of channel names: a string
    becomes a one-element list, a list is used as-is, a dict contributes its
    keys, and anything else yields an empty list.

    Args:
        name: Name under which the node is registered in the graph.
        node: The node to convert.

    Returns:
        The populated ``graph_pb2.NodeDefinition``.
    """
    if isinstance(node.channels, str):
        channels = [node.channels]
    elif isinstance(node.channels, list):
        channels = node.channels
    elif isinstance(node.channels, dict):
        channels = list(node.channels)
    else:
        channels = []

    # ``Struct.fields`` maps names to protobuf ``Value`` messages, so plain
    # Python metadata values cannot be passed to the constructor directly.
    # ``Struct.update`` performs the conversion, as ``extract_config`` does.
    metadata_proto = Struct()
    metadata_proto.update(node.metadata or {})

    # TODO cache policy
    return graph_pb2.NodeDefinition(
        metadata=metadata_proto,
        name=name,
        triggers=node.triggers,
        tags=node.tags or [],
        channels=channels,
    )
|
107
|
+
|
108
|
+
|
109
|
+
def extract_trigger_to_nodes(
    trigger_to_nodes: dict[str, Sequence[str]] | Any,  # Allow Mapping type from graph
) -> dict[str, graph_pb2.TriggerMapping]:
    """Convert a trigger -> nodes mapping into ``TriggerMapping`` messages.

    For each trigger, the node names are taken from ``value["nodes"]`` when
    the value is a dict containing that key, from the value itself when it is
    a list, and are empty otherwise.

    Args:
        trigger_to_nodes: Mapping from trigger name to its nodes.

    Returns:
        A dict with one ``TriggerMapping`` per trigger.
    """

    def _node_names(value: Any) -> Any:
        # Resolve the node list for one mapping value.
        if isinstance(value, dict) and "nodes" in value:
            return value["nodes"]
        if isinstance(value, list):
            return value
        return []

    return {
        trigger: graph_pb2.TriggerMapping(nodes=_node_names(value))
        for trigger, value in trigger_to_nodes.items()
    }
|
121
|
+
|
122
|
+
|
123
|
+
def extract_graph(graph: Pregel) -> graph_pb2.GraphDefinition:
    """Extract graph information from a compiled LangGraph graph.

    Returns a protobuf message that contains all relevant orchestration
    information about the graph. Attributes that may be either a single
    string or a collection (``stream_mode``, ``stream_channels``,
    ``input_channels``, ``output_channels``) are always emitted as lists; for
    the input and output channels an ``*_is_list`` flag records whether the
    original value was a collection.
    """
    # Normalize "string or collection" attributes to lists.
    if isinstance(graph.stream_mode, str):
        stream_mode = [graph.stream_mode]
    else:
        stream_mode = graph.stream_mode

    if isinstance(graph.stream_channels, str):
        stream_channels = [graph.stream_channels]
    elif graph.stream_channels:
        stream_channels = list(graph.stream_channels)
    else:
        stream_channels = []

    # A falsy timeout (None or 0) is reported as 0.0.
    step_timeout = float(graph.step_timeout) if graph.step_timeout else 0.0
    config_proto = extract_config(graph.config) if graph.config else None
    cache_proto = graph_pb2.Cache(
        cache_type=extract_cache_type(getattr(graph, "cache", None)),
    )

    graph_def = graph_pb2.GraphDefinition(
        name=str(graph.name),
        channels=extract_channels(graph.channels),
        interrupt_before_nodes=list(graph.interrupt_before_nodes),
        interrupt_after_nodes=list(graph.interrupt_after_nodes),
        stream_mode=stream_mode,
        stream_eager=bool(graph.stream_eager),
        stream_channels=stream_channels,
        step_timeout=step_timeout,
        debug=bool(graph.debug),
        # TODO retry policy
        cache=cache_proto,
        config=config_proto,
        nodes=extract_nodes(graph.nodes),
        trigger_to_nodes=extract_trigger_to_nodes(graph.trigger_to_nodes),
    )

    # Set input_channels and input_channels_is_list flag based on type
    single_input = isinstance(graph.input_channels, str)
    graph_def.input_channels.extend(
        [graph.input_channels] if single_input else graph.input_channels
    )
    graph_def.input_channels_is_list = not single_input

    # Set output_channels and output_channels_is_list flag based on type
    single_output = isinstance(graph.output_channels, str)
    graph_def.output_channels.extend(
        [graph.output_channels] if single_output else graph.output_channels
    )
    graph_def.output_channels_is_list = not single_output

    return graph_def
|
175
|
+
|
176
|
+
|
177
|
+
def convert_to_json(proto):
    """Return the JSON string produced by ``MessageToJson`` for ``proto``."""
    json_text = MessageToJson(proto)
    return json_text
|
@@ -0,0 +1,111 @@
|
|
1
|
+
import json
|
2
|
+
import shutil
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
|
8
|
+
class ExecutorInfo:
|
9
|
+
id: str
|
10
|
+
pid: int
|
11
|
+
port: int
|
12
|
+
status: str
|
13
|
+
start_time: float
|
14
|
+
end_time: float | None = None
|
15
|
+
error_message: str | None = None
|
16
|
+
|
17
|
+
|
18
|
+
class ExecutorInfoLogger:
|
19
|
+
def __init__(self, log_dir: Path):
|
20
|
+
self.logs_dir = log_dir / "executor"
|
21
|
+
self.info_file_name = "info.json"
|
22
|
+
|
23
|
+
def create_executor_logs_dir(self, executor_id: str) -> Path:
|
24
|
+
executor_dir = self.logs_dir / executor_id
|
25
|
+
executor_dir.mkdir(parents=True, exist_ok=True)
|
26
|
+
|
27
|
+
return executor_dir
|
28
|
+
|
29
|
+
def write_executor_info(self, executor_info: ExecutorInfo):
|
30
|
+
executor_dir = self.create_executor_logs_dir(executor_info.id)
|
31
|
+
info_file = executor_dir / self.info_file_name
|
32
|
+
|
33
|
+
data = {
|
34
|
+
"id": executor_info.id,
|
35
|
+
"pid": executor_info.pid,
|
36
|
+
"port": executor_info.port,
|
37
|
+
"status": executor_info.status,
|
38
|
+
"start_time": executor_info.start_time,
|
39
|
+
"end_time": executor_info.end_time,
|
40
|
+
"error_message": executor_info.error_message,
|
41
|
+
}
|
42
|
+
|
43
|
+
with info_file.open("w") as f:
|
44
|
+
json.dump(data, f, indent=2)
|
45
|
+
f.flush()
|
46
|
+
|
47
|
+
def read_executor_info(self, executor_id: str) -> ExecutorInfo | None:
|
48
|
+
try:
|
49
|
+
executor_dir = self.logs_dir / executor_id
|
50
|
+
info_file = executor_dir / self.info_file_name
|
51
|
+
|
52
|
+
if not info_file.exists():
|
53
|
+
return None
|
54
|
+
|
55
|
+
with open(info_file) as f:
|
56
|
+
data = json.load(f)
|
57
|
+
|
58
|
+
if not data:
|
59
|
+
return None
|
60
|
+
|
61
|
+
return ExecutorInfo(
|
62
|
+
id=data["id"],
|
63
|
+
pid=data["pid"],
|
64
|
+
port=data["port"],
|
65
|
+
status=data["status"],
|
66
|
+
start_time=data["start_time"],
|
67
|
+
end_time=data.get("end_time"),
|
68
|
+
error_message=data.get("error_message"),
|
69
|
+
)
|
70
|
+
|
71
|
+
except Exception as e:
|
72
|
+
print(f"Failed to read executor info: {e}")
|
73
|
+
return None
|
74
|
+
|
75
|
+
def update_executor_info(
|
76
|
+
self,
|
77
|
+
executor_id: str,
|
78
|
+
status: str,
|
79
|
+
error_message: str | None = None,
|
80
|
+
end_time: float | None = None,
|
81
|
+
) -> None:
|
82
|
+
info_file = self.logs_dir / executor_id / self.info_file_name
|
83
|
+
|
84
|
+
if not info_file.exists():
|
85
|
+
return
|
86
|
+
|
87
|
+
# Read current info
|
88
|
+
with open(info_file) as f:
|
89
|
+
current_info = json.load(f)
|
90
|
+
|
91
|
+
if not current_info:
|
92
|
+
return
|
93
|
+
|
94
|
+
# Update fields
|
95
|
+
current_info["status"] = status
|
96
|
+
if error_message:
|
97
|
+
current_info["error_message"] = error_message
|
98
|
+
if end_time:
|
99
|
+
current_info["end_time"] = end_time
|
100
|
+
|
101
|
+
# Write
|
102
|
+
with open(info_file, "w") as f:
|
103
|
+
json.dump(current_info, f)
|
104
|
+
|
105
|
+
def cleanup_logs(self) -> None:
|
106
|
+
if not self.logs_dir.exists():
|
107
|
+
return
|
108
|
+
|
109
|
+
for dir in self.logs_dir.iterdir():
|
110
|
+
if dir.is_dir():
|
111
|
+
shutil.rmtree(dir)
|
File without changes
|