langgraph-executor 0.0.1a5__py3-none-any.whl → 0.0.1a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_executor/__init__.py +1 -1
- langgraph_executor/client/__init__.py +0 -0
- langgraph_executor/client/patch.py +305 -0
- langgraph_executor/client/utils.py +340 -0
- langgraph_executor/common.py +65 -16
- langgraph_executor/executor.py +22 -8
- langgraph_executor/executor_base.py +91 -3
- langgraph_executor/extract_graph.py +19 -26
- langgraph_executor/pb/executor_pb2.py +45 -43
- langgraph_executor/pb/executor_pb2.pyi +39 -0
- langgraph_executor/pb/executor_pb2_grpc.py +44 -0
- langgraph_executor/pb/executor_pb2_grpc.pyi +20 -0
- langgraph_executor/pb/graph_pb2.py +12 -12
- langgraph_executor/pb/graph_pb2.pyi +13 -14
- langgraph_executor/pb/runtime_pb2.py +49 -31
- langgraph_executor/pb/runtime_pb2.pyi +197 -7
- langgraph_executor/pb/runtime_pb2_grpc.py +229 -1
- langgraph_executor/pb/runtime_pb2_grpc.pyi +133 -6
- langgraph_executor/pb/types_pb2.py +17 -9
- langgraph_executor/pb/types_pb2.pyi +95 -0
- langgraph_executor/serde.py +13 -0
- {langgraph_executor-0.0.1a5.dist-info → langgraph_executor-0.0.1a7.dist-info}/METADATA +1 -1
- langgraph_executor-0.0.1a7.dist-info/RECORD +36 -0
- langgraph_executor-0.0.1a5.dist-info/RECORD +0 -32
- {langgraph_executor-0.0.1a5.dist-info → langgraph_executor-0.0.1a7.dist-info}/WHEEL +0 -0
langgraph_executor/__init__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.0.
|
1
|
+
__version__ = "0.0.1a7"
|
File without changes
|
@@ -0,0 +1,305 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
import grpc
|
5
|
+
from langgraph._internal._config import ensure_config
|
6
|
+
from langgraph.errors import GraphInterrupt, GraphRecursionError
|
7
|
+
from langgraph.pregel import Pregel
|
8
|
+
from langgraph.runtime import get_runtime
|
9
|
+
from langgraph.types import Interrupt
|
10
|
+
from pydantic import ValidationError
|
11
|
+
|
12
|
+
from langgraph_executor import serde
|
13
|
+
from langgraph_executor.client.utils import (
|
14
|
+
config_to_pb,
|
15
|
+
context_to_pb,
|
16
|
+
create_runopts_pb,
|
17
|
+
decode_response,
|
18
|
+
input_to_pb,
|
19
|
+
)
|
20
|
+
from langgraph_executor.common import var_child_runnable_config
|
21
|
+
from langgraph_executor.pb import runtime_pb2
|
22
|
+
from langgraph_executor.pb.runtime_pb2 import OutputChunk
|
23
|
+
from langgraph_executor.pb.runtime_pb2_grpc import LangGraphRuntimeStub
|
24
|
+
|
25
|
+
|
26
|
+
def _patch_pregel(runtime_client: LangGraphRuntimeStub, logger: logging.Logger):
    """Replace ``Pregel.invoke`` / ``Pregel.ainvoke`` with runtime-backed versions.

    The replacements are assigned on the ``Pregel`` class itself, so every
    instance that does not override these methods will forward its
    invocations through ``runtime_client`` (via ``_invoke_wrapper`` /
    ``_ainvoke_wrapper``) instead of executing locally. Nothing here restores
    the original methods.

    Args:
        runtime_client: gRPC stub used to reach the LangGraph runtime.
        logger: Logger handed to the wrappers for diagnostics.
    """

    def patched_invoke(pregel_self, input, config=None, **kwargs):
        # Sync entry point: forward straight to the blocking wrapper.
        return _invoke_wrapper(
            runtime_client, logger, pregel_self, input, config, **kwargs
        )

    async def patched_ainvoke(pregel_self, input, config=None, **kwargs):
        # Async entry point: await the coroutine wrapper.
        return await _ainvoke_wrapper(
            runtime_client, logger, pregel_self, input, config, **kwargs
        )

    Pregel.invoke = patched_invoke  # type: ignore[invalid-assignment]
    Pregel.ainvoke = patched_ainvoke  # type: ignore[invalid-assignment]
|
39
|
+
|
40
|
+
|
41
|
+
def _resolve_context_and_config(logger: logging.Logger, context, config):
    """Inherit runtime context and runnable config from the enclosing run.

    The orchestrator does not have the parent's Python-side config, so before
    a subgraph call is forwarded, the caller's runtime context (only when no
    context was passed explicitly) and the parent runnable config are read
    from the current execution context.

    Returns:
        The ``(context, config)`` pair to send with the request.
    """
    # TODO: Hacky way of retrieving runtime from runnable context
    if not context:
        try:
            runtime = get_runtime()
            if runtime.context:
                context = runtime.context
        except Exception as e:
            # Best effort: without a parent runtime we simply send no context.
            logger.error(f"failed to retrive parent runtime for subgraph: {e}")

    # need to get config of parent because wont be available in orchestrator
    if parent_config := var_child_runnable_config.get({}):
        config = ensure_config(config, parent_config)

    return context, config


def _build_invoke_request(
    graph_name: str,
    input,
    config,
    context,
    stream_mode,
    output_keys,
    interrupt_before,
    interrupt_after,
    durability,
    debug,
    subgraphs,
) -> runtime_pb2.InvokeRequest:
    """Assemble the ``InvokeRequest`` sent to the runtime for one invocation."""
    return runtime_pb2.InvokeRequest(
        graph_name=graph_name,
        input=input_to_pb(input),
        config=config_to_pb(config),
        context=context_to_pb(context),
        run_opts=create_runopts_pb(
            stream_mode,
            output_keys,
            interrupt_before,
            interrupt_after,
            durability,
            debug,
            subgraphs,
        ),
    )


def _raise_for_error_response(response: OutputChunk) -> None:
    """Raise the Python exception encoded in an error response, if any.

    Returns normally when ``response`` is not an error.

    Raises:
        GraphInterrupt: the runtime reported a graph interrupt; its
            interrupt values are deserialized with the configured serde.
        ValueError: the runtime reported any other error type.
    """
    if response.WhichOneof("message") != "error":
        return

    error = response.error.error
    if error.WhichOneof("error_type") != "graph_interrupt":
        raise ValueError(f"Unknown subgraph error from orchestrator: {error!s}")

    interrupts = [
        Interrupt(
            value=serde.get_serializer().loads_typed(
                (
                    interrupt.value.base_value.method,
                    interrupt.value.base_value.value,
                )
            ),
            id=interrupt.id,
        )
        for interrupt in error.graph_interrupt.interrupts
    ]
    raise GraphInterrupt(interrupts)


def _pydantic_context_error(
    logger: logging.Logger, error_msg: str
) -> ValidationError | None:
    """Rebuild the pydantic ``ValidationError`` embedded in an RPC error message.

    The JSON payload (``{"title": ..., "errors": [...]}``) starts at the first
    ``": {"`` in ``error_msg``.

    Returns:
        ``None`` when the message has no JSON payload; otherwise the decoded
        ``ValidationError``, or a generic context ``ValidationError`` when the
        payload cannot be parsed.
    """
    import json

    _, sep, tail = error_msg.partition(": {")
    if not sep:
        return None

    try:
        # raw_decode stops at the end of the JSON object, so nested objects
        # and any text that follows the JSON (e.g. the remainder of
        # str(RpcError)) do not break parsing.
        error_data, _ = json.JSONDecoder().raw_decode("{" + tail)
        return ValidationError.from_exception_data(
            error_data["title"], error_data["errors"]
        )
    except (json.JSONDecodeError, KeyError) as parse_err:
        logger.error(f"JSONDecodeError: {parse_err}")
        # Fallback if parsing fails
        return ValidationError.from_exception_data(
            "ValidationError",
            [
                {
                    "type": "value_error",
                    "loc": ("context",),
                    "msg": "invalid pydantic context format",
                    "input": None,
                }
            ],
        )


def _translate_rpc_error(logger: logging.Logger, e: grpc.RpcError) -> None:
    """Raise the Python exception that matches a known runtime failure.

    Returns normally when the error is not recognised, so the caller can
    re-raise the original ``RpcError``.

    Raises:
        GraphRecursionError: "recursion limit exceeded".
        TypeError: "invalid context format".
        ValidationError: "invalid pydantic context format" with a JSON payload.
    """
    # grpc_message is inside str(e)
    details = str(e)
    lowered = details.lower()
    if "recursion limit exceeded" in lowered:
        raise GraphRecursionError from e
    if "invalid context format" in lowered:
        raise TypeError from e
    if "invalid pydantic context format" in lowered:
        validation_error = _pydantic_context_error(logger, details)
        if validation_error is not None:
            raise validation_error from e


async def _ainvoke_wrapper(
    runtime_client: LangGraphRuntimeStub,
    logger: logging.Logger,
    pregel_self: Pregel,  # This is the actual Pregel instance
    input,
    config=None,
    context=None,
    stream_mode=None,
    output_keys=None,
    interrupt_before=None,
    interrupt_after=None,
    durability=None,
    debug=None,
    subgraphs=False,
) -> dict[str, Any] | Any:
    """Async replacement for ``Pregel.ainvoke`` that runs the graph remotely.

    Sends an ``InvokeRequest`` for ``pregel_self.name`` to ``runtime_client``
    and decodes the response with ``decode_response``.

    Args:
        stream_mode: Stream mode(s) to request; ``None`` means ``["values"]``.
        Other arguments mirror ``Pregel.ainvoke`` and are forwarded in
        ``RunOpts``.

    Raises:
        GraphInterrupt: the graph was interrupted.
        GraphRecursionError, TypeError, ValidationError: translated runtime
            failures (see ``_translate_rpc_error``).
        ValueError: an unknown error response.
        grpc.RpcError: any other transport/runtime error (logged first).
    """
    import asyncio

    if stream_mode is None:
        stream_mode = ["values"]

    # subgraph names coerced when initializing executor
    graph_name = pregel_self.name
    logger.info(f"SUBGRAPH AINVOKE ENCOUNTERED: {graph_name}")

    context, config = _resolve_context_and_config(logger, context, config)

    try:
        invoke_request = _build_invoke_request(
            graph_name,
            input,
            config,
            context,
            stream_mode,
            output_keys,
            interrupt_before,
            interrupt_after,
            durability,
            debug,
            subgraphs,
        )
        try:
            # The stub's Invoke is a blocking call; run it in the default
            # executor so the event loop is not blocked.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None, runtime_client.Invoke, invoke_request
            )
            _raise_for_error_response(response)
        except grpc.RpcError as e:
            _translate_rpc_error(logger, e)
            raise
        return decode_response(response, stream_mode)
    except grpc.RpcError as e:
        logger.error(f"gRPC client/runtime error: {e!s}")
        raise


def _invoke_wrapper(
    runtime_client: LangGraphRuntimeStub,
    logger: logging.Logger,
    pregel_self: Pregel,  # This is the actual Pregel instance
    input,
    config=None,
    context=None,
    stream_mode=None,
    output_keys=None,
    interrupt_before=None,
    interrupt_after=None,
    durability=None,
    debug=None,
    subgraphs=False,
) -> dict[str, Any] | Any:
    """Sync replacement for ``Pregel.invoke`` that runs the graph remotely.

    Same contract as ``_ainvoke_wrapper`` but calls ``runtime_client.Invoke``
    directly on the calling thread.
    """
    if stream_mode is None:
        stream_mode = ["values"]

    # subgraph names coerced when initializing executor
    graph_name = pregel_self.name
    logger.info(f"SUBGRAPH INVOKE ENCOUNTERED: {graph_name}")

    context, config = _resolve_context_and_config(logger, context, config)

    try:
        invoke_request = _build_invoke_request(
            graph_name,
            input,
            config,
            context,
            stream_mode,
            output_keys,
            interrupt_before,
            interrupt_after,
            durability,
            debug,
            subgraphs,
        )
        try:
            response: OutputChunk = runtime_client.Invoke(invoke_request)
            _raise_for_error_response(response)
        except grpc.RpcError as e:
            _translate_rpc_error(logger, e)
            raise
        return decode_response(response, stream_mode)
    except grpc.RpcError as e:
        logger.error(f"gRPC client/runtime error: {e!s}")
        raise
|
301
|
+
|
302
|
+
|
303
|
+
# Only the patch installer is exported; the invoke wrappers it installs are
# internal helpers of this module.
__all__ = [
    "_patch_pregel",
]
|
@@ -0,0 +1,340 @@
|
|
1
|
+
import base64
|
2
|
+
import copy
|
3
|
+
import re
|
4
|
+
from collections.abc import Sequence
|
5
|
+
from typing import Any, cast
|
6
|
+
|
7
|
+
from google.protobuf.json_format import MessageToDict
|
8
|
+
from langchain_core.messages import AIMessageChunk, BaseMessage
|
9
|
+
from langchain_core.messages.utils import convert_to_messages
|
10
|
+
from langchain_core.runnables import RunnableConfig
|
11
|
+
from langgraph._internal._config import _is_not_empty
|
12
|
+
from langgraph._internal._constants import (
|
13
|
+
CONFIG_KEY_CHECKPOINT_ID,
|
14
|
+
CONFIG_KEY_CHECKPOINT_MAP,
|
15
|
+
CONFIG_KEY_CHECKPOINT_NS,
|
16
|
+
CONFIG_KEY_DURABILITY,
|
17
|
+
CONFIG_KEY_RESUME_MAP,
|
18
|
+
CONFIG_KEY_RESUMING,
|
19
|
+
CONFIG_KEY_TASK_ID,
|
20
|
+
CONFIG_KEY_THREAD_ID,
|
21
|
+
)
|
22
|
+
from langgraph.pregel.debug import CheckpointMetadata
|
23
|
+
from langgraph.types import StateSnapshot
|
24
|
+
|
25
|
+
from langgraph_executor import serde
|
26
|
+
from langgraph_executor.common import reconstruct_config, val_to_pb
|
27
|
+
from langgraph_executor.pb import runtime_pb2, types_pb2
|
28
|
+
|
29
|
+
|
30
|
+
def input_to_pb(input):
    """Serialize a graph ``input`` value for an ``InvokeRequest``.

    Thin wrapper around ``val_to_pb`` that passes ``None`` as its first
    argument.
    """
    serialized_input = val_to_pb(None, input)
    return serialized_input
|
32
|
+
|
33
|
+
|
34
|
+
def _is_present_and_not_empty(config: RunnableConfig, key: Any) -> bool:
    """Return True when ``key`` exists in ``config`` and ``_is_not_empty`` accepts its value."""
    if key not in config:
        return False
    return _is_not_empty(config[key])
|
36
|
+
|
37
|
+
|
38
|
+
def maybe_update_reserved_configurable(
    key: str, value: Any, reserved_configurable: types_pb2.ReservedConfigurable
) -> bool:
    """Store a reserved LangGraph configurable entry on ``reserved_configurable``.

    Args:
        key: A configurable key, possibly one of the reserved ``CONFIG_KEY_*``.
        value: Its value from ``config["configurable"]``.
        reserved_configurable: Proto message the reserved value is written to.

    Returns:
        True when ``key`` is reserved and was consumed here; False otherwise,
        so the caller can store it as an ordinary configurable entry. A
        ``None`` resume map is not consumed (returns False).

    A ``None`` value for a string or map field leaves the proto field at its
    default instead of writing the literal string ``"None"`` (or failing on
    ``.update(None)``).
    """
    # Reserved keys whose proto counterpart is a plain string field.
    string_fields = {
        CONFIG_KEY_TASK_ID: "task_id",
        CONFIG_KEY_THREAD_ID: "thread_id",
        CONFIG_KEY_CHECKPOINT_ID: "checkpoint_id",
        CONFIG_KEY_CHECKPOINT_NS: "checkpoint_ns",
        CONFIG_KEY_DURABILITY: "durability",
    }

    if key in string_fields:
        if value is not None:
            setattr(reserved_configurable, string_fields[key], str(value))
        return True

    if key == CONFIG_KEY_RESUMING:
        reserved_configurable.resuming = bool(value)
        return True

    if key == CONFIG_KEY_CHECKPOINT_MAP:
        if value is not None:
            reserved_configurable.checkpoint_map.update(cast(dict[str, str], value))
        return True

    if key == CONFIG_KEY_RESUME_MAP and value is not None:
        resume_map = cast(dict[str, Any], value)
        for k, v in resume_map.items():
            # Each resume value is serialized the same way as graph inputs.
            reserved_configurable.resume_map[k].CopyFrom(val_to_pb(None, v))
        return True

    return False
|
64
|
+
|
65
|
+
|
66
|
+
def merge_configurables(
    config: RunnableConfig, pb_config: types_pb2.RunnableConfig
) -> None:
    """Copy ``config["configurable"]`` into ``pb_config``.

    Keys accepted by ``maybe_update_reserved_configurable`` go to
    ``pb_config.reserved_configurable``; all others are written to the generic
    ``pb_config.configurable`` struct. An entry whose struct update raises
    ``ValueError`` is skipped with a warning rather than aborting the whole
    conversion.
    """
    import logging

    if not _is_present_and_not_empty(config, "configurable"):
        return

    configurable = pb_config.configurable
    reserved_configurable = pb_config.reserved_configurable

    for k, v in config["configurable"].items():
        if maybe_update_reserved_configurable(k, v, reserved_configurable):
            continue
        try:
            configurable.update({k: v})
        except ValueError:  # TODO handle this
            # Struct.update creates the map entry before converting the value,
            # so a failed conversion would leave an empty Value behind.
            if k in configurable.fields:
                del configurable.fields[k]
            # Log key and type only: configurable values may hold credentials
            # or large objects that should not be echoed.
            logging.getLogger(__name__).warning(
                "could not pass config field %s (type %s) to proto",
                k,
                type(v).__name__,
            )
|
81
|
+
|
82
|
+
|
83
|
+
def config_to_pb(config: RunnableConfig) -> types_pb2.RunnableConfig:
    """Convert a LangChain ``RunnableConfig`` into its protobuf message.

    Missing or empty fields keep their proto defaults. ``tags`` may be a
    single string or a list/tuple of strings. ``configurable`` entries are
    delegated to ``merge_configurables``.
    """
    if not config:
        return types_pb2.RunnableConfig()

    # Scalar fields are passed to the constructor.
    kwargs = {}

    if _is_present_and_not_empty(config, "run_name"):
        kwargs["run_name"] = config["run_name"]

    if _is_present_and_not_empty(config, "run_id"):
        # Stringified for the proto; a falsy id is sent as "".
        kwargs["run_id"] = str(config["run_id"]) if config["run_id"] else ""

    if _is_present_and_not_empty(config, "max_concurrency"):
        kwargs["max_concurrency"] = int(config["max_concurrency"])

    if _is_present_and_not_empty(config, "recursion_limit"):
        kwargs["recursion_limit"] = config["recursion_limit"]

    pb_config = types_pb2.RunnableConfig(**kwargs)

    # Repeated / map fields are filled after construction.
    if _is_present_and_not_empty(config, "tags"):
        tags = config["tags"]
        if isinstance(tags, str):
            pb_config.tags.append(tags)
        elif isinstance(tags, (list, tuple)):
            pb_config.tags.extend(tags)

    if _is_present_and_not_empty(config, "metadata"):
        pb_config.metadata.update(config["metadata"])

    merge_configurables(config, pb_config)

    return pb_config
|
118
|
+
|
119
|
+
|
120
|
+
def context_to_pb(context: dict[str, Any] | Any) -> types_pb2.Context | None:
    """Convert a run context object into a ``types_pb2.Context`` message.

    Accepted inputs:
      * ``None`` -> ``None`` (no context is sent);
      * mapping-like objects (anything with ``.items``) are used as-is;
      * objects with a ``__dict__`` (e.g. regular dataclasses) contribute
        their instance attributes;
      * dataclass instances without ``__dict__`` (``slots=True``) contribute
        their declared fields;
      * anything else becomes an empty context.
    """
    import dataclasses

    if context is None:
        return None

    if hasattr(context, "items"):
        # Already a dict-like object
        context_dict = context
    elif hasattr(context, "__dict__"):
        context_dict = vars(context)
    elif dataclasses.is_dataclass(context):
        # Slotted dataclasses have no __dict__; read their fields directly.
        context_dict = {
            field.name: getattr(context, field.name)
            for field in dataclasses.fields(context)
        }
    else:
        context_dict = {}

    return types_pb2.Context(context=context_dict)
|
136
|
+
|
137
|
+
|
138
|
+
# TODO
|
139
|
+
def create_runopts_pb(
    stream_mode,
    output_keys,
    interrupt_before,
    interrupt_after,
    durability,
    debug,
    subgraphs,
) -> runtime_pb2.RunOpts:
    """Build the ``RunOpts`` message for an invocation.

    Arguments that are ``None`` are omitted so the message keeps its proto
    defaults.

    Args:
        stream_mode: One mode string, or a list/tuple of mode strings.
        output_keys: One key string, or a sequence of key strings; the
            ``StringOrSlice.is_string`` flag records which form was given.
        interrupt_before: Iterable of node names.
        interrupt_after: Iterable of node names.
        durability: Copied as-is.
        debug: Copied as-is.
        subgraphs: Copied as-is.

    Returns:
        The populated ``runtime_pb2.RunOpts``.
    """
    # Scalar / message fields go through the constructor.
    kwargs: dict[str, Any] = {}

    if durability is not None:
        kwargs["durability"] = durability

    if debug is not None:
        kwargs["debug"] = debug

    if subgraphs is not None:
        kwargs["subgraphs"] = subgraphs

    if output_keys is not None:
        if isinstance(output_keys, str):
            kwargs["output_keys"] = types_pb2.StringOrSlice(
                is_string=True, values=[output_keys]
            )
        elif isinstance(output_keys, Sequence):
            # Check the runtime class: isinstance() rejects parameterized
            # generics such as list[str] with a TypeError.
            kwargs["output_keys"] = types_pb2.StringOrSlice(
                is_string=False, values=list(output_keys)
            )

    run_opts = runtime_pb2.RunOpts(**kwargs)

    # Repeated fields are filled after construction.
    if stream_mode is not None:
        if isinstance(stream_mode, str):
            run_opts.stream_mode.append(stream_mode)
        elif isinstance(stream_mode, (list, tuple)):
            run_opts.stream_mode.extend(stream_mode)

    if interrupt_before is not None:
        run_opts.interrupt_before.extend(interrupt_before)

    if interrupt_after is not None:
        run_opts.interrupt_after.extend(interrupt_after)

    return run_opts
|
194
|
+
|
195
|
+
|
196
|
+
def decode_response(response, stream_mode):
|
197
|
+
which = response.WhichOneof("message")
|
198
|
+
if which == "error":
|
199
|
+
raise ValueError(response.error)
|
200
|
+
if which == "chunk":
|
201
|
+
return decode_chunk(response.chunk, stream_mode)
|
202
|
+
if which == "chunk_list":
|
203
|
+
return [
|
204
|
+
decode_chunk(chunk.chunk, stream_mode)
|
205
|
+
for chunk in response.chunk_list.chunks
|
206
|
+
]
|
207
|
+
|
208
|
+
raise ValueError("No stream response")
|
209
|
+
|
210
|
+
|
211
|
+
# Exact key set of a dict that holds a serialized value: a serde "method"
# plus a base64-encoded "value". Frozen so the module constant cannot be
# mutated by accident.
VAL_KEYS = frozenset({"method", "value"})


def deser_vals(chunk: dict[str, Any]):
    """Return a deep copy of ``chunk`` with every serialized value decoded.

    The input object is never modified.
    """
    return _deser_vals(copy.deepcopy(chunk))


def _deser_vals(current_chunk):
    """Recursively decode serialized values inside ``current_chunk``.

    - Lists are rebuilt element by element.
    - A dict whose keys are exactly ``VAL_KEYS`` is replaced by
      ``serde.get_serializer().loads_typed((method, base64-decoded value))``.
    - Any other dict has its values decoded in place and is returned.
    - Every other object is returned unchanged.
    """
    if isinstance(current_chunk, list):
        return [_deser_vals(item) for item in current_chunk]
    if not isinstance(current_chunk, dict):
        return current_chunk
    if current_chunk.keys() == VAL_KEYS:
        return serde.get_serializer().loads_typed(
            (current_chunk["method"], base64.b64decode(current_chunk["value"]))
        )
    for key, value in current_chunk.items():
        # Non-container values come back unchanged, so reassigning every
        # entry (existing keys only, no resize) is equivalent to touching
        # only nested dicts and lists.
        current_chunk[key] = _deser_vals(value)
    return current_chunk
|
231
|
+
|
232
|
+
|
233
|
+
def decode_state_history_response(response):
    """Convert a state-history response into a list of ``StateSnapshot`` objects.

    Returns ``None`` for a falsy ``response``; otherwise one snapshot per
    entry of ``response.history``, in order.
    """
    if response:
        history = response.history
        return [reconstruct_state_snapshot(snapshot_pb) for snapshot_pb in history]
    return None
|
238
|
+
|
239
|
+
|
240
|
+
def decode_state_response(response):
    """Convert a single-state response into a ``StateSnapshot``.

    Returns ``None`` for a falsy ``response``; otherwise the snapshot rebuilt
    from ``response.state``.
    """
    if response:
        return reconstruct_state_snapshot(response.state)
    return None
|
245
|
+
|
246
|
+
|
247
|
+
# TODO finish reconstructing these
|
248
|
+
def reconstruct_state_snapshot(state_pb: types_pb2.StateSnapshot) -> StateSnapshot:
    """Rebuild a LangGraph ``StateSnapshot`` from its protobuf message.

    How each field is restored:
    - ``values``: ``MessageToDict`` followed by ``deser_vals``.
    - ``metadata``: ``MessageToDict``, wrapped in ``CheckpointMetadata``.
    - ``config`` and ``parent_config``: ``reconstruct_config``.
    - ``tasks`` and ``interrupts``: not reconstructed yet; always empty tuples.

    NOTE(review): ``created_at`` is passed through as-is, so an unset proto
    field may surface as its default rather than ``None`` — confirm.
    """
    snapshot_values = deser_vals(MessageToDict(state_pb.values))
    return StateSnapshot(
        values=snapshot_values,
        next=tuple(state_pb.next),
        config=reconstruct_config(state_pb.config),
        metadata=CheckpointMetadata(**MessageToDict(state_pb.metadata)),
        created_at=state_pb.created_at,
        parent_config=reconstruct_config(state_pb.parent_config),
        tasks=(),
        interrupts=(),
    )
|
259
|
+
|
260
|
+
|
261
|
+
def decode_chunk(chunk, stream_mode):
    """Turn one streamed chunk message into the value a caller receives.

    The chunk is converted with ``MessageToDict`` and its serialized values
    are decoded with ``deser_vals``. The result shape follows which fields are
    present: ``(ns, mode, payload)``, ``(ns, payload)``, ``(mode, payload)``
    or bare ``payload``. "messages" chunks are special-cased and always come
    back as ``(ns, (message, metadata))``. For "custom" chunks, a payload of
    the form ``{"data": x}`` is unwrapped to ``x``.
    """
    decoded = cast(dict[str, Any], deser_vals(MessageToDict(chunk)))
    requested_modes = stream_mode or ()
    mode = decoded.get("mode")
    ns = decoded.get("ns")

    # The chunk omits "mode" when the caller requested a single mode, so fall
    # back to the requested modes to recognise it.
    def _matches(mode_name):
        return mode == mode_name or (mode is None and mode_name in requested_modes)

    if _matches("messages"):
        # NOTE(review): unlike other modes this never includes ``mode`` and
        # always includes ``ns`` (even None) — confirm callers expect this.
        return (ns, extract_message_chunk(decoded["payload"]))

    payload = decoded.get("payload")

    is_wrapped = isinstance(payload, dict) and len(payload) == 1 and "data" in payload
    if _matches("custom") and is_wrapped:
        payload = payload["data"]

    if ns:
        return (ns, mode, payload) if mode else (ns, payload)
    return (mode, payload) if mode else payload
|
288
|
+
|
289
|
+
|
290
|
+
class AnyStr(str):
|
291
|
+
def __init__(self, prefix: str | re.Pattern = "") -> None:
|
292
|
+
super().__init__()
|
293
|
+
self.prefix = prefix
|
294
|
+
|
295
|
+
def __eq__(self, other: object) -> bool:
|
296
|
+
return isinstance(other, str) and (
|
297
|
+
other.startswith(self.prefix)
|
298
|
+
if isinstance(self.prefix, str)
|
299
|
+
else self.prefix.match(other)
|
300
|
+
)
|
301
|
+
|
302
|
+
def __hash__(self) -> int:
|
303
|
+
return hash((str(self), self.prefix))
|
304
|
+
|
305
|
+
|
306
|
+
def extract_message_chunk(
    payload: dict[str, Any],
) -> tuple[BaseMessage, dict[str, Any]]:
    """Split a "messages"-mode payload into ``(message, metadata)``.

    Where the data comes from:
    - The message dict is read from ``payload["message"]["message"]``; its
      ``type`` defaults to ``"ai"``.
    - The metadata is ``payload["metadata"]``.

    Non-chunk message types are converted with ``convert_to_messages``.
    ``AIMessageChunk`` is rebuilt field by field.

    Raises:
        ValueError: for any ``*Chunk`` type other than ``AIMessageChunk``.
    """
    raw_message = payload.get("message", {}).get("message", {})
    metadata = payload.get("metadata", {})
    kind = raw_message.get("type", "ai")

    if not kind.endswith("Chunk"):
        return convert_to_messages([raw_message])[0], metadata

    if kind != "AIMessageChunk":
        raise ValueError(f"Unknown message type: {kind}")

    rebuilt = AIMessageChunk(
        content=raw_message.get("content", ""),
        id=raw_message.get("id"),
        additional_kwargs=raw_message.get("additional_kwargs", {}),
        tool_calls=raw_message.get("tool_calls", []),
        name=raw_message.get("name"),
        usage_metadata=raw_message.get("usage_metadata", None),
        tool_call_chunks=raw_message.get("tool_call_chunks", []),
        response_metadata=raw_message.get("response_metadata", {}),
    )
    return (rebuilt, metadata)
|