langgraph-executor 0.0.1a4__tar.gz → 0.0.1a6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/PKG-INFO +1 -1
  2. langgraph_executor-0.0.1a6/langgraph_executor/__init__.py +1 -0
  3. langgraph_executor-0.0.1a6/langgraph_executor/client/patch.py +305 -0
  4. langgraph_executor-0.0.1a6/langgraph_executor/client/utils.py +340 -0
  5. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/common.py +57 -4
  6. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/executor.py +22 -8
  7. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/executor_base.py +92 -4
  8. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/extract_graph.py +18 -25
  9. langgraph_executor-0.0.1a6/langgraph_executor/pb/executor_pb2.py +86 -0
  10. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/executor_pb2.pyi +39 -0
  11. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/executor_pb2_grpc.py +44 -0
  12. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/executor_pb2_grpc.pyi +20 -0
  13. langgraph_executor-0.0.1a6/langgraph_executor/pb/graph_pb2.py +51 -0
  14. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/graph_pb2.pyi +13 -14
  15. langgraph_executor-0.0.1a6/langgraph_executor/pb/runtime_pb2.py +86 -0
  16. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/runtime_pb2.pyi +197 -7
  17. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/runtime_pb2_grpc.py +229 -1
  18. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/runtime_pb2_grpc.pyi +133 -6
  19. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/types_pb2.py +17 -9
  20. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/types_pb2.pyi +95 -0
  21. langgraph_executor-0.0.1a6/langgraph_executor/py.typed +0 -0
  22. langgraph_executor-0.0.1a4/langgraph_executor/__init__.py +0 -1
  23. langgraph_executor-0.0.1a4/langgraph_executor/pb/executor_pb2.py +0 -84
  24. langgraph_executor-0.0.1a4/langgraph_executor/pb/graph_pb2.py +0 -51
  25. langgraph_executor-0.0.1a4/langgraph_executor/pb/runtime_pb2.py +0 -68
  26. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/.gitignore +0 -0
  27. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/README.md +0 -0
  28. {langgraph_executor-0.0.1a4/langgraph_executor/pb → langgraph_executor-0.0.1a6/langgraph_executor/client}/__init__.py +0 -0
  29. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/example.py +0 -0
  30. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/execute_task.py +0 -0
  31. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/info_logger.py +0 -0
  32. /langgraph_executor-0.0.1a4/langgraph_executor/py.typed → /langgraph_executor-0.0.1a6/langgraph_executor/pb/__init__.py +0 -0
  33. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/graph_pb2_grpc.py +0 -0
  34. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/graph_pb2_grpc.pyi +0 -0
  35. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/types_pb2_grpc.py +0 -0
  36. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/pb/types_pb2_grpc.pyi +0 -0
  37. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/server.py +0 -0
  38. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/setup.sh +0 -0
  39. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/langgraph_executor/stream_utils.py +0 -0
  40. {langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/pyproject.toml +0 -0
{langgraph_executor-0.0.1a4 → langgraph_executor-0.0.1a6}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: langgraph-executor
- Version: 0.0.1a4
+ Version: 0.0.1a6
  Summary: LangGraph python RPC server executable by the langgraph-go orchestrator.
  Requires-Python: >=3.11
  Requires-Dist: grpcio>=1.73.1
langgraph_executor-0.0.1a6/langgraph_executor/__init__.py
@@ -0,0 +1 @@
+ __version__ = "0.0.1a6"
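
The new top-level __init__.py makes the installed version introspectable at runtime. A quick sanity check (assuming the wheel is installed) might look like:

    import langgraph_executor

    print(langgraph_executor.__version__)  # expected: 0.0.1a6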
langgraph_executor-0.0.1a6/langgraph_executor/client/patch.py
@@ -0,0 +1,305 @@
+ import logging
+ from typing import Any
+
+ import grpc
+ from langgraph._internal._config import ensure_config
+ from langgraph.errors import GraphInterrupt, GraphRecursionError
+ from langgraph.pregel import Pregel
+ from langgraph.runtime import get_runtime
+ from langgraph.types import Interrupt
+ from pydantic import ValidationError
+
+ from langgraph_executor.client.utils import (
+     SERDE,
+     config_to_pb,
+     context_to_pb,
+     create_runopts_pb,
+     decode_response,
+     input_to_pb,
+ )
+ from langgraph_executor.common import var_child_runnable_config
+ from langgraph_executor.pb import runtime_pb2
+ from langgraph_executor.pb.runtime_pb2 import OutputChunk
+ from langgraph_executor.pb.runtime_pb2_grpc import LangGraphRuntimeStub
+
+
+ def _patch_pregel(runtime_client: LangGraphRuntimeStub, logger: logging.Logger):
+     async def patched_ainvoke(pregel_self, input, config=None, **kwargs):
+         return await _ainvoke_wrapper(
+             runtime_client, logger, pregel_self, input, config, **kwargs
+         )
+
+     def patched_invoke(pregel_self, input, config=None, **kwargs):
+         return _invoke_wrapper(
+             runtime_client, logger, pregel_self, input, config, **kwargs
+         )
+
+     Pregel.ainvoke = patched_ainvoke  # type: ignore[invalid-assignment]
+     Pregel.invoke = patched_invoke  # type: ignore[invalid-assignment]
+
+
+ async def _ainvoke_wrapper(
+     runtime_client: LangGraphRuntimeStub,
+     logger: logging.Logger,
+     pregel_self: Pregel,  # This is the actual Pregel instance
+     input,
+     config=None,
+     context=None,
+     stream_mode=["values"],
+     output_keys=None,
+     interrupt_before=None,
+     interrupt_after=None,
+     durability=None,
+     debug=None,
+     subgraphs=False,
+ ) -> dict[str, Any] | Any:
+     """Wrapper that handles the actual invoke logic."""
+
+     # subgraph names coerced when initializing executor
+     graph_name = pregel_self.name
+
+     logger.info(f"SUBGRAPH INVOKE ENCOUNTERED: {graph_name}")
+
+     # TODO: Hacky way of retrieving runtime from runnable context
+     if not context:
+         try:
+             runtime = get_runtime()
+             if runtime.context:
+                 context = runtime.context
+         except Exception as e:
+ logger.error(f"failed to retrive parent runtime for subgraph: {e}")
+
+     if parent_config := var_child_runnable_config.get({}):
+         config = ensure_config(config, parent_config)
+
+     try:
+         # create request
+         invoke_request = runtime_pb2.InvokeRequest(
+             graph_name=graph_name,
+             input=input_to_pb(input),
+             config=config_to_pb(config),
+             context=context_to_pb(context),
+             run_opts=create_runopts_pb(
+                 stream_mode,
+                 output_keys,
+                 interrupt_before,
+                 interrupt_after,
+                 durability,
+                 debug,
+                 subgraphs,
+             ),
+         )
+
+         # get response - if this blocks, you might need to make it async
+         try:
+             # Option 1: If runtime_client.Invoke is synchronous and might block:
+             import asyncio
+
+             loop = asyncio.get_event_loop()
+             response = await loop.run_in_executor(
+                 None, runtime_client.Invoke, invoke_request
+             )
+
+             if response.WhichOneof("message") == "error":
+                 error = response.error.error
+
+                 if error.WhichOneof("error_type") == "graph_interrupt":
+                     graph_interrupt = error.graph_interrupt
+
+                     interrupts = []
+
+                     for interrupt in graph_interrupt.interrupts:
+                         interrupts.append(
+                             Interrupt(
+                                 value=SERDE.loads_typed(
+                                     (
+                                         interrupt.value.base_value.method,
+                                         interrupt.value.base_value.value,
+                                     )
+                                 ),
+                                 id=interrupt.id,
+                             )
+                         )
+
+                     raise GraphInterrupt(interrupts)
+
+                 else:
+                     raise ValueError(
+                         f"Unknown subgraph error from orchestrator: {error!s}"
+                     )
+
+         except grpc.RpcError as e:
+             # grpc_message is inside str(e)
+             details = str(e)
+             if details and "recursion limit exceeded" in details.lower():
+                 raise GraphRecursionError
+             if details and "invalid context format" in details.lower():
+                 raise TypeError
+             if details and "invalid pydantic context format" in details.lower():
+                 import json
+
+                 # Extract the JSON error data from the error message
+                 error_msg = str(e)
+                 if ": {" in error_msg:
+                     json_part = "{" + error_msg.split(": {")[1]
+                     try:
+                         error_data = json.loads(json_part)
+                         raise ValidationError.from_exception_data(
+                             error_data["title"], error_data["errors"]
+                         )
+                     except (json.JSONDecodeError, KeyError) as e:
+                         logger.error(f"JSONDecodeError: {e}")
+                         # Fallback if parsing fails
+                         raise ValidationError.from_exception_data(
+                             "ValidationError",
+                             [
+                                 {
+                                     "type": "value_error",
+                                     "loc": ("context",),
+                                     "msg": "invalid pydantic context format",
+                                     "input": None,
+                                 }
+                             ],
+                         )
+             raise
+
+         # decode response
+         return decode_response(response, stream_mode)
+
+     except Exception as e:
+         if isinstance(e, grpc.RpcError):
+             logger.error(f"gRPC client/runtime error: {e!s}")
+         raise e
+
+
+ def _invoke_wrapper(
+     runtime_client: LangGraphRuntimeStub,
+     logger: logging.Logger,
+     pregel_self: Pregel,  # This is the actual Pregel instance
+     input,
+     config=None,
+     context=None,
+     stream_mode=["values"],
+     output_keys=None,
+     interrupt_before=None,
+     interrupt_after=None,
+     durability=None,
+     debug=None,
+     subgraphs=False,
+ ) -> dict[str, Any] | Any:
+     """Wrapper that handles the actual invoke logic."""
+
+     # subgraph names coerced when initializing executor
+     graph_name = pregel_self.name
+
+     logger.info(f"SUBGRAPH INVOKE ENCOUNTERED: {graph_name}")
+
+     # TODO: Hacky way of retrieving runtime from runnable context
+     if not context:
+         try:
+             runtime = get_runtime()
+             if runtime.context:
+                 context = runtime.context
+         except Exception as e:
+ logger.error(f"failed to retrive parent runtime for subgraph: {e}")
+
+     # need to get the parent config because it won't be available in the orchestrator
+     if parent_config := var_child_runnable_config.get({}):
+         config = ensure_config(config, parent_config)
+
+     try:
+         # create request
+         invoke_request = runtime_pb2.InvokeRequest(
+             graph_name=graph_name,
+             input=input_to_pb(input),
+             config=config_to_pb(config),
+             context=context_to_pb(context),
+             run_opts=create_runopts_pb(
+                 stream_mode,
+                 output_keys,
+                 interrupt_before,
+                 interrupt_after,
+                 durability,
+                 debug,
+                 subgraphs,
+             ),
+         )
+
+         try:
+             response: OutputChunk = runtime_client.Invoke(invoke_request)
+
+             if response.WhichOneof("message") == "error":
+                 error = response.error.error
+
+                 if error.WhichOneof("error_type") == "graph_interrupt":
+                     graph_interrupt = error.graph_interrupt
+
+                     interrupts = []
+
+                     for interrupt in graph_interrupt.interrupts:
+                         interrupts.append(
+                             Interrupt(
+                                 value=SERDE.loads_typed(
+                                     (
+                                         interrupt.value.base_value.method,
+                                         interrupt.value.base_value.value,
+                                     )
+                                 ),
+                                 id=interrupt.id,
+                             )
+                         )
+
+                     raise GraphInterrupt(interrupts)
+
+                 else:
+                     raise ValueError(
+                         f"Unknown subgraph error from orchestrator: {error!s}"
+                     )
+
+         except grpc.RpcError as e:
+             # grpc_message is inside str(e)
+             details = str(e)
+             if details and "recursion limit exceeded" in details.lower():
+                 raise GraphRecursionError
+             if details and "invalid context format" in details.lower():
+                 raise TypeError
+             if details and "invalid pydantic context format" in details.lower():
+                 import json
+
+                 # Extract the JSON error data from the error message
+                 error_msg = str(e)
+                 if ": {" in error_msg:
+                     json_part = "{" + error_msg.split(": {")[1]
+                     try:
+                         error_data = json.loads(json_part)
+                         raise ValidationError.from_exception_data(
+                             error_data["title"], error_data["errors"]
+                         )
+                     except (json.JSONDecodeError, KeyError) as e:
+                         logger.error(f"JSONDecodeError: {e}")
+                         # Fallback if parsing fails
+                         raise ValidationError.from_exception_data(
+                             "ValidationError",
+                             [
+                                 {
+                                     "type": "value_error",
+                                     "loc": ("context",),
+                                     "msg": "invalid pydantic context format",
+                                     "input": None,
+                                 }
+                             ],
+                         )
+             raise
+
+         # decode response
+         return decode_response(response, stream_mode)
+
+     except Exception as e:
+         if isinstance(e, grpc.RpcError):
+             logger.error(f"gRPC client/runtime error: {e!s}")
+         raise e
+
+
+ __all__ = [
+     "_patch_pregel",
+ ]
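
Taken together, patch.py replaces Pregel.invoke and Pregel.ainvoke process-wide, so any subgraph invoked inside the executor is forwarded to the orchestrator over gRPC rather than run locally. A minimal sketch of the wiring (the endpoint address and logger name are illustrative assumptions; the executor owns the real setup):

    import grpc
    import logging

    from langgraph_executor.client.patch import _patch_pregel
    from langgraph_executor.pb.runtime_pb2_grpc import LangGraphRuntimeStub

    # Hypothetical orchestrator endpoint.
    channel = grpc.insecure_channel("localhost:50051")
    stub = LangGraphRuntimeStub(channel)
    _patch_pregel(stub, logging.getLogger("langgraph_executor.client"))
    # From here on, every Pregel.invoke/ainvoke call routes through the stub.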
langgraph_executor-0.0.1a6/langgraph_executor/client/utils.py
@@ -0,0 +1,340 @@
+ import base64
+ import copy
+ import re
+ from collections.abc import Sequence
+ from typing import Any, cast
+
+ from google.protobuf.json_format import MessageToDict
+ from langchain_core.messages import AIMessageChunk, BaseMessage
+ from langchain_core.messages.utils import convert_to_messages
+ from langchain_core.runnables import RunnableConfig
+ from langgraph._internal._config import _is_not_empty
+ from langgraph._internal._constants import (
+     CONFIG_KEY_CHECKPOINT_ID,
+     CONFIG_KEY_CHECKPOINT_MAP,
+     CONFIG_KEY_CHECKPOINT_NS,
+     CONFIG_KEY_DURABILITY,
+     CONFIG_KEY_RESUMING,
+     CONFIG_KEY_TASK_ID,
+     CONFIG_KEY_THREAD_ID,
+ )
+ from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
+ from langgraph.pregel.debug import CheckpointMetadata
+ from langgraph.types import StateSnapshot
+
+ from langgraph_executor.common import reconstruct_config, val_to_pb
+ from langgraph_executor.pb import runtime_pb2, types_pb2
+
+ SERDE = JsonPlusSerializer()
+
+
+ def input_to_pb(input):
+     return val_to_pb(None, input)
+
+
+ def _is_present_and_not_empty(config: RunnableConfig, key: Any) -> bool:
+     return key in config and _is_not_empty(config[key])
+
+
+ def maybe_update_reserved_configurable(
+     key: str, value: Any, reserved_configurable: types_pb2.ReservedConfigurable
+ ) -> bool:
+     if key == CONFIG_KEY_RESUMING:
+         reserved_configurable.resuming = bool(value)
+     elif key == CONFIG_KEY_TASK_ID:
+         reserved_configurable.task_id = str(value)
+     elif key == CONFIG_KEY_THREAD_ID:
+         reserved_configurable.thread_id = str(value)
+     elif key == CONFIG_KEY_CHECKPOINT_MAP:
+         reserved_configurable.checkpoint_map.update(cast(dict[str, str], value))
+     elif key == CONFIG_KEY_CHECKPOINT_ID:
+         reserved_configurable.checkpoint_id = str(value)
+     elif key == CONFIG_KEY_CHECKPOINT_NS:
+         reserved_configurable.checkpoint_ns = str(value)
+     # elif key == CONFIG_KEY_PREVIOUS:
+     #     serde = JsonPlusSerializer()
+     #     meth, ser = serde.dumps_typed(value)
+     #     reserved_configurable.previous = types_pb2.SerializedValue(method=meth, value=bytes(ser))
+     elif key == CONFIG_KEY_DURABILITY:
+         reserved_configurable.durability = str(value)
+     else:
+         return False
+
+     return True
+
+
+ def merge_configurables(
+     config: RunnableConfig, pb_config: types_pb2.RunnableConfig
+ ) -> None:
+     if not _is_present_and_not_empty(config, "configurable"):
+         return
+
+     configurable = pb_config.configurable
+     reserved_configurable = pb_config.reserved_configurable
+
+     for k, v in config["configurable"].items():
+         if not maybe_update_reserved_configurable(k, v, reserved_configurable):
+             try:
+                 configurable.update({k: v})
+             except ValueError:  # TODO handle this
+                 print(f"could not pass config field {k}:{v} to proto")
+
+
+ def config_to_pb(config: RunnableConfig) -> types_pb2.RunnableConfig:
+     if not config:
+         return types_pb2.RunnableConfig()
+
+     # Prepare kwargs for construction
+     kwargs = {}
+
+     if _is_present_and_not_empty(config, "run_name"):
+         kwargs["run_name"] = config["run_name"]
+
+     if _is_present_and_not_empty(config, "run_id"):
+         kwargs["run_id"] = str(config["run_id"]) if config["run_id"] else ""
+
+     if _is_present_and_not_empty(config, "max_concurrency"):
+         kwargs["max_concurrency"] = int(config["max_concurrency"])
+
+     if _is_present_and_not_empty(config, "recursion_limit"):
+         kwargs["recursion_limit"] = config["recursion_limit"]
+
+     # Create the config with initial values
+     pb_config = types_pb2.RunnableConfig(**kwargs)
+
+     # Handle collections after construction
+     if _is_present_and_not_empty(config, "tags"):
+         if isinstance(config["tags"], list):
+             pb_config.tags.extend(config["tags"])
+         elif isinstance(config["tags"], str):
+             pb_config.tags.append(config["tags"])
+
+     if _is_present_and_not_empty(config, "metadata"):
+         pb_config.metadata.update(config["metadata"])
+
+     merge_configurables(config, pb_config)
+
+     return pb_config
+
+
+ def context_to_pb(context: dict[str, Any] | Any) -> types_pb2.Context | None:
+     if context is None:
+         return None
+
+     # Convert dataclass or other objects to dict if needed
+     if hasattr(context, "__dict__") and not hasattr(context, "items"):
+         # Convert dataclass to dict
+         context_dict = context.__dict__
+     elif hasattr(context, "items"):
+         # Already a dict-like object
+         context_dict = context
+     else:
+         # Try to convert to dict using vars()
+         context_dict = vars(context) if hasattr(context, "__dict__") else {}
+
+     return types_pb2.Context(context=context_dict)
+
+
+ # TODO
+ def create_runopts_pb(
+     stream_mode,
+     output_keys,
+     interrupt_before,
+     interrupt_after,
+     durability,
+     debug,
+     subgraphs,
+ ) -> runtime_pb2.RunOpts:
+     # Prepare kwargs for construction
+     kwargs = {}
+
+     if durability is not None:
+         kwargs["durability"] = durability
+
+     if debug is not None:
+         kwargs["debug"] = debug
+
+     if subgraphs is not None:
+         kwargs["subgraphs"] = subgraphs
+
+     if output_keys is not None:
+         string_or_slice_pb = None
+         if isinstance(output_keys, str):
+             string_or_slice_pb = types_pb2.StringOrSlice(
+                 is_string=True, values=[output_keys]
+             )
+         elif isinstance(output_keys, list):
+             string_or_slice_pb = types_pb2.StringOrSlice(
+                 is_string=False, values=output_keys
+             )
+
+         if string_or_slice_pb is not None:
+             kwargs["output_keys"] = string_or_slice_pb
+
+     # Create the RunOpts with initial values
+     run_opts = runtime_pb2.RunOpts(**kwargs)
+
+     # Handle repeated fields after construction
+     if stream_mode is not None:
+         if isinstance(stream_mode, str):
+             run_opts.stream_mode.append(stream_mode)
+         elif isinstance(stream_mode, list):
+             run_opts.stream_mode.extend(stream_mode)
+
+     if interrupt_before is not None:
+         run_opts.interrupt_before.extend(interrupt_before)
+
+     if interrupt_after is not None:
+         run_opts.interrupt_after.extend(interrupt_after)
+
+     # Note: checkpoint_during field doesn't exist in RunOpts proto
+     # Ignoring it as it's not in the proto definition
+
+     return run_opts
+
+
+ def decode_response(response, stream_mode):
+     which = response.WhichOneof("message")
+     if which == "error":
+         raise ValueError(response.error)
+     if which == "chunk":
+         return decode_chunk(response.chunk, stream_mode)
+     if which == "chunk_list":
+         return [
+             decode_chunk(chunk.chunk, stream_mode)
+             for chunk in response.chunk_list.chunks
+         ]
+
+     raise ValueError("No stream response")
+
+
+ VAL_KEYS = {"method", "value"}
+
+
+ def deser_vals(chunk: dict[str, Any]):
+     return _deser_vals(copy.deepcopy(chunk))
+
+
+ def _deser_vals(current_chunk):
+     if isinstance(current_chunk, list):
+         return [_deser_vals(v) for v in current_chunk]
+     if not isinstance(current_chunk, dict):
+         return current_chunk
+     if set(current_chunk.keys()) == VAL_KEYS:
+         return SERDE.loads_typed(
+             (current_chunk["method"], base64.b64decode(current_chunk["value"]))
+         )
+     for k, v in current_chunk.items():
+         if isinstance(v, dict | Sequence):
+             current_chunk[k] = _deser_vals(v)
+     return current_chunk
+
+
+ def decode_state_history_response(response):
+     if not response:
+         return
+
+     return [reconstruct_state_snapshot(state_pb) for state_pb in response.history]
+
+
+ def decode_state_response(response):
+     if not response:
+         return
+
+     return reconstruct_state_snapshot(response.state)
+
+
+ # TODO finish reconstructing these
+ def reconstruct_state_snapshot(state_pb: types_pb2.StateSnapshot) -> StateSnapshot:
+     return StateSnapshot(
+         values=deser_vals(MessageToDict(state_pb.values)),
+         next=tuple(state_pb.next),
+         config=reconstruct_config(state_pb.config),
+         metadata=CheckpointMetadata(**MessageToDict(state_pb.metadata)),
+         created_at=state_pb.created_at,
+         parent_config=reconstruct_config(state_pb.parent_config),
+         tasks=tuple(),
+         interrupts=tuple(),
+     )
+
+
+ def decode_chunk(chunk, stream_mode):
+     d = cast(dict[str, Any], deser_vals(MessageToDict(chunk)))
+     stream_mode = stream_mode or ()
+     mode = d.get("mode")
+     ns = d.get("ns")
+     # Handle messages mode specifically - we don't always send the stream mode in the chunk
+     # Because if user only has 1 mode, we exclude it since it is implied
+     if mode == "messages" or (mode is None and "messages" in stream_mode):
+         return (ns, extract_message_chunk(d["payload"]))
+
+     # Handle custom mode primitive extraction
+     payload = d.get("payload")
+
+     # For custom mode, unwrap primitives from "data" wrapper
+     if mode == "custom" or (mode is None and "custom" in stream_mode):
+         if isinstance(payload, dict) and len(payload) == 1 and "data" in payload:
+             payload = payload["data"]
+
+     # Regular logic for all modes
+     if ns:
+         if mode:
+             return (ns, mode, payload)
+         return (ns, payload)
+     if mode:
+         return (mode, payload)
+
+     return payload
+
+
+ class AnyStr(str):
+     def __init__(self, prefix: str | re.Pattern = "") -> None:
+         super().__init__()
+         self.prefix = prefix
+
+     def __eq__(self, other: object) -> bool:
+         return isinstance(other, str) and (
+             other.startswith(self.prefix)
+             if isinstance(self.prefix, str)
+             else self.prefix.match(other)
+         )
+
+     def __hash__(self) -> int:
+         return hash((str(self), self.prefix))
+
+
+ def extract_message_chunk(
+     payload: dict[str, Any],
+ ) -> tuple[BaseMessage, dict[str, Any]]:
+     """Extract (BaseMessage, metadata) tuple from messages mode payload"""
+
+     # Extract writes from payload and deserialize the message data
+     message_data = payload.get("message", {}).get("message", {})
+     metadata = payload.get("metadata", {})
+     message_type = message_data.get("type", "ai")
+     if message_type.endswith("Chunk"):
+         message_id = message_data.get("id")
+         content = message_data.get("content", "")
+         additional_kwargs = message_data.get("additional_kwargs", {})
+         usage_metadata = message_data.get("usage_metadata", None)
+         tool_calls = message_data.get("tool_calls", [])
+         name = message_data.get("name")
+         tool_call_chunks = message_data.get("tool_call_chunks", [])
+         response_metadata = message_data.get("response_metadata", {})
+         if message_type == "AIMessageChunk":
+             message = AIMessageChunk(
+                 content=content,
+                 id=message_id,
+                 additional_kwargs=additional_kwargs,
+                 tool_calls=tool_calls,
+                 name=name,
+                 usage_metadata=usage_metadata,
+                 tool_call_chunks=tool_call_chunks,
+                 response_metadata=response_metadata,
+             )
+             return (message, metadata)
+         else:
+             raise ValueError(f"Unknown message type: {message_type}")
+
+     else:
+         return convert_to_messages([message_data])[0], metadata
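
For orientation, _deser_vals reverses the (method, value) envelopes that JsonPlusSerializer produces: any dict whose keys are exactly VAL_KEYS is base64-decoded and handed to loads_typed. A small round-trip sketch (the envelope shape is inferred from VAL_KEYS above, not documented by the package):

    import base64

    from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

    serde = JsonPlusSerializer()
    method, raw = serde.dumps_typed({"answer": 42})
    # Mirrors the wire shape that _deser_vals detects via VAL_KEYS.
    envelope = {"method": method, "value": base64.b64encode(raw).decode()}
    restored = serde.loads_typed((method, base64.b64decode(envelope["value"])))
    assert restored == {"answer": 42}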