langchain 1.0.0a7__py3-none-any.whl → 1.0.0a8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain has been flagged as potentially problematic (the original page linked to further details).
- langchain/agents/middleware/__init__.py +0 -2
- langchain/agents/middleware/human_in_the_loop.py +23 -19
- langchain/agents/middleware/prompt_caching.py +22 -5
- langchain/agents/middleware/types.py +425 -1
- langchain/agents/middleware_agent.py +26 -22
- {langchain-1.0.0a7.dist-info → langchain-1.0.0a8.dist-info}/METADATA +1 -1
- {langchain-1.0.0a7.dist-info → langchain-1.0.0a8.dist-info}/RECORD +10 -11
- langchain/agents/middleware/dynamic_system_prompt.py +0 -105
- {langchain-1.0.0a7.dist-info → langchain-1.0.0a8.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a7.dist-info → langchain-1.0.0a8.dist-info}/entry_points.txt +0 -0
- {langchain-1.0.0a7.dist-info → langchain-1.0.0a8.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
"""Middleware plugins for agents."""
|
|
2
2
|
|
|
3
|
-
from .dynamic_system_prompt import DynamicSystemPromptMiddleware
|
|
4
3
|
from .human_in_the_loop import HumanInTheLoopMiddleware
|
|
5
4
|
from .prompt_caching import AnthropicPromptCachingMiddleware
|
|
6
5
|
from .summarization import SummarizationMiddleware
|
|
@@ -11,7 +10,6 @@ __all__ = [
|
|
|
11
10
|
"AgentState",
|
|
12
11
|
# should move to langchain-anthropic if we decide to keep it
|
|
13
12
|
"AnthropicPromptCachingMiddleware",
|
|
14
|
-
"DynamicSystemPromptMiddleware",
|
|
15
13
|
"HumanInTheLoopMiddleware",
|
|
16
14
|
"ModelRequest",
|
|
17
15
|
"SummarizationMiddleware",
|
|
@@ -112,14 +112,14 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
112
112
|
|
|
113
113
|
def __init__(
|
|
114
114
|
self,
|
|
115
|
-
|
|
115
|
+
interrupt_on: dict[str, bool | ToolConfig],
|
|
116
116
|
*,
|
|
117
117
|
description_prefix: str = "Tool execution requires approval",
|
|
118
118
|
) -> None:
|
|
119
119
|
"""Initialize the human in the loop middleware.
|
|
120
120
|
|
|
121
121
|
Args:
|
|
122
|
-
|
|
122
|
+
interrupt_on: Mapping of tool name to allowed actions.
|
|
123
123
|
If a tool doesn't have an entry, it's auto-approved by default.
|
|
124
124
|
* `True` indicates all actions are allowed: accept, edit, and respond.
|
|
125
125
|
* `False` indicates that the tool is auto-approved.
|
|
@@ -130,7 +130,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
130
130
|
"""
|
|
131
131
|
super().__init__()
|
|
132
132
|
resolved_tool_configs: dict[str, ToolConfig] = {}
|
|
133
|
-
for tool_name, tool_config in
|
|
133
|
+
for tool_name, tool_config in interrupt_on.items():
|
|
134
134
|
if isinstance(tool_config, bool):
|
|
135
135
|
if tool_config is True:
|
|
136
136
|
resolved_tool_configs[tool_name] = ToolConfig(
|
|
@@ -138,13 +138,15 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
138
138
|
allow_edit=True,
|
|
139
139
|
allow_respond=True,
|
|
140
140
|
)
|
|
141
|
-
|
|
141
|
+
elif any(
|
|
142
|
+
tool_config.get(x, False) for x in ["allow_accept", "allow_edit", "allow_respond"]
|
|
143
|
+
):
|
|
142
144
|
resolved_tool_configs[tool_name] = tool_config
|
|
143
|
-
self.
|
|
145
|
+
self.interrupt_on = resolved_tool_configs
|
|
144
146
|
self.description_prefix = description_prefix
|
|
145
147
|
|
|
146
148
|
def after_model(self, state: AgentState) -> dict[str, Any] | None: # type: ignore[override]
|
|
147
|
-
"""Trigger
|
|
149
|
+
"""Trigger interrupt flows for relevant tool calls after an AIMessage."""
|
|
148
150
|
messages = state["messages"]
|
|
149
151
|
if not messages:
|
|
150
152
|
return None
|
|
@@ -154,16 +156,16 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
154
156
|
return None
|
|
155
157
|
|
|
156
158
|
# Separate tool calls that need interrupts from those that don't
|
|
157
|
-
|
|
159
|
+
interrupt_tool_calls: list[ToolCall] = []
|
|
158
160
|
auto_approved_tool_calls = []
|
|
159
161
|
|
|
160
162
|
for tool_call in last_ai_msg.tool_calls:
|
|
161
|
-
|
|
163
|
+
interrupt_tool_calls.append(tool_call) if tool_call[
|
|
162
164
|
"name"
|
|
163
|
-
] in self.
|
|
165
|
+
] in self.interrupt_on else auto_approved_tool_calls.append(tool_call)
|
|
164
166
|
|
|
165
167
|
# If no interrupts needed, return early
|
|
166
|
-
if not
|
|
168
|
+
if not interrupt_tool_calls:
|
|
167
169
|
return None
|
|
168
170
|
|
|
169
171
|
# Process all tool calls that require interrupts
|
|
@@ -171,11 +173,11 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
171
173
|
artificial_tool_messages: list[ToolMessage] = []
|
|
172
174
|
|
|
173
175
|
# Create interrupt requests for all tools that need approval
|
|
174
|
-
|
|
175
|
-
for tool_call in
|
|
176
|
+
interrupt_requests: list[HumanInTheLoopRequest] = []
|
|
177
|
+
for tool_call in interrupt_tool_calls:
|
|
176
178
|
tool_name = tool_call["name"]
|
|
177
179
|
tool_args = tool_call["args"]
|
|
178
|
-
config = self.
|
|
180
|
+
config = self.interrupt_on[tool_name]
|
|
179
181
|
description = (
|
|
180
182
|
config.get("description")
|
|
181
183
|
or f"{self.description_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
|
|
@@ -189,21 +191,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
189
191
|
"config": config,
|
|
190
192
|
"description": description,
|
|
191
193
|
}
|
|
192
|
-
|
|
194
|
+
interrupt_requests.append(request)
|
|
193
195
|
|
|
194
|
-
responses: list[HumanInTheLoopResponse] = interrupt(
|
|
196
|
+
responses: list[HumanInTheLoopResponse] = interrupt(interrupt_requests)
|
|
195
197
|
|
|
196
198
|
# Validate that the number of responses matches the number of interrupt tool calls
|
|
197
|
-
if (responses_len := len(responses)) != (
|
|
199
|
+
if (responses_len := len(responses)) != (
|
|
200
|
+
interrupt_tool_calls_len := len(interrupt_tool_calls)
|
|
201
|
+
):
|
|
198
202
|
msg = (
|
|
199
203
|
f"Number of human responses ({responses_len}) does not match "
|
|
200
|
-
f"number of hanging tool calls ({
|
|
204
|
+
f"number of hanging tool calls ({interrupt_tool_calls_len})."
|
|
201
205
|
)
|
|
202
206
|
raise ValueError(msg)
|
|
203
207
|
|
|
204
208
|
for i, response in enumerate(responses):
|
|
205
|
-
tool_call =
|
|
206
|
-
config = self.
|
|
209
|
+
tool_call = interrupt_tool_calls[i]
|
|
210
|
+
config = self.interrupt_on[tool_call["name"]]
|
|
207
211
|
|
|
208
212
|
if response["type"] == "accept" and config.get("allow_accept"):
|
|
209
213
|
approved_tool_calls.append(tool_call)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Anthropic prompt caching middleware."""
|
|
2
2
|
|
|
3
3
|
from typing import Literal
|
|
4
|
+
from warnings import warn
|
|
4
5
|
|
|
5
6
|
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
|
|
6
7
|
|
|
@@ -19,6 +20,7 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
|
|
19
20
|
type: Literal["ephemeral"] = "ephemeral",
|
|
20
21
|
ttl: Literal["5m", "1h"] = "5m",
|
|
21
22
|
min_messages_to_cache: int = 0,
|
|
23
|
+
unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
|
|
22
24
|
) -> None:
|
|
23
25
|
"""Initialize the middleware with cache control settings.
|
|
24
26
|
|
|
@@ -27,10 +29,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
|
|
27
29
|
ttl: The time to live for the cache, only "5m" and "1h" are supported.
|
|
28
30
|
min_messages_to_cache: The minimum number of messages until the cache is used,
|
|
29
31
|
default is 0.
|
|
32
|
+
unsupported_model_behavior: The behavior to take when an unsupported model is used.
|
|
33
|
+
"ignore" will ignore the unsupported model and continue without caching.
|
|
34
|
+
"warn" will warn the user and continue without caching.
|
|
35
|
+
"raise" will raise an error and stop the agent.
|
|
30
36
|
"""
|
|
31
37
|
self.type = type
|
|
32
38
|
self.ttl = ttl
|
|
33
39
|
self.min_messages_to_cache = min_messages_to_cache
|
|
40
|
+
self.unsupported_model_behavior = unsupported_model_behavior
|
|
34
41
|
|
|
35
42
|
def modify_model_request( # type: ignore[override]
|
|
36
43
|
self,
|
|
@@ -40,19 +47,29 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
|
|
|
40
47
|
try:
|
|
41
48
|
from langchain_anthropic import ChatAnthropic
|
|
42
49
|
except ImportError:
|
|
50
|
+
ChatAnthropic = None # noqa: N806
|
|
51
|
+
|
|
52
|
+
msg: str | None = None
|
|
53
|
+
|
|
54
|
+
if ChatAnthropic is None:
|
|
43
55
|
msg = (
|
|
44
56
|
"AnthropicPromptCachingMiddleware caching middleware only supports "
|
|
45
|
-
"Anthropic models."
|
|
57
|
+
"Anthropic models. "
|
|
46
58
|
"Please install langchain-anthropic."
|
|
47
59
|
)
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
if not isinstance(request.model, ChatAnthropic):
|
|
60
|
+
elif not isinstance(request.model, ChatAnthropic):
|
|
51
61
|
msg = (
|
|
52
62
|
"AnthropicPromptCachingMiddleware caching middleware only supports "
|
|
53
63
|
f"Anthropic models, not instances of {type(request.model)}"
|
|
54
64
|
)
|
|
55
|
-
|
|
65
|
+
|
|
66
|
+
if msg is not None:
|
|
67
|
+
if self.unsupported_model_behavior == "raise":
|
|
68
|
+
raise ValueError(msg)
|
|
69
|
+
if self.unsupported_model_behavior == "warn":
|
|
70
|
+
warn(msg, stacklevel=3)
|
|
71
|
+
else:
|
|
72
|
+
return request
|
|
56
73
|
|
|
57
74
|
messages_count = (
|
|
58
75
|
len(request.messages) + 1 if request.system_prompt else len(request.messages)
|
|
@@ -3,7 +3,20 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
|
-
from
|
|
6
|
+
from inspect import signature
|
|
7
|
+
from typing import (
|
|
8
|
+
TYPE_CHECKING,
|
|
9
|
+
Annotated,
|
|
10
|
+
Any,
|
|
11
|
+
ClassVar,
|
|
12
|
+
Generic,
|
|
13
|
+
Literal,
|
|
14
|
+
Protocol,
|
|
15
|
+
TypeAlias,
|
|
16
|
+
TypeGuard,
|
|
17
|
+
cast,
|
|
18
|
+
overload,
|
|
19
|
+
)
|
|
7
20
|
|
|
8
21
|
# needed as top level import for pydantic schema generation on AgentState
|
|
9
22
|
from langchain_core.messages import AnyMessage # noqa: TC002
|
|
@@ -14,9 +27,12 @@ from langgraph.typing import ContextT
|
|
|
14
27
|
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
|
15
28
|
|
|
16
29
|
if TYPE_CHECKING:
|
|
30
|
+
from collections.abc import Callable
|
|
31
|
+
|
|
17
32
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
18
33
|
from langchain_core.tools import BaseTool
|
|
19
34
|
from langgraph.runtime import Runtime
|
|
35
|
+
from langgraph.types import Command
|
|
20
36
|
|
|
21
37
|
from langchain.agents.structured_output import ResponseFormat
|
|
22
38
|
|
|
@@ -88,6 +104,7 @@ class PublicAgentState(TypedDict, Generic[ResponseT]):
|
|
|
88
104
|
|
|
89
105
|
|
|
90
106
|
StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
|
|
107
|
+
StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
|
|
91
108
|
|
|
92
109
|
|
|
93
110
|
class AgentMiddleware(Generic[StateT, ContextT]):
|
|
@@ -103,6 +120,12 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
103
120
|
tools: list[BaseTool]
|
|
104
121
|
"""Additional tools registered by the middleware."""
|
|
105
122
|
|
|
123
|
+
before_model_jump_to: ClassVar[list[JumpTo]] = []
|
|
124
|
+
"""Valid jump destinations for before_model hook. Used to establish conditional edges."""
|
|
125
|
+
|
|
126
|
+
after_model_jump_to: ClassVar[list[JumpTo]] = []
|
|
127
|
+
"""Valid jump destinations for after_model hook. Used to establish conditional edges."""
|
|
128
|
+
|
|
106
129
|
def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
107
130
|
"""Logic to run before the model is called."""
|
|
108
131
|
|
|
@@ -117,3 +140,404 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
117
140
|
|
|
118
141
|
def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
119
142
|
"""Logic to run after the model is called."""
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class _CallableWithState(Protocol[StateT_contra]):
    """Structural type for a node hook that receives only the agent state.

    Matches plain functions shaped like `def hook(state): ...`.
    """

    def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
        """Run the hook for the given `state`.

        Returns:
            A state-update dict, a `Command`, or `None`.
        """
        ...
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
    """Structural type for a node hook that receives the agent state and the runtime.

    Matches plain functions shaped like `def hook(state, runtime): ...`.
    """

    def __call__(
        self,
        state: StateT_contra,
        runtime: Runtime[ContextT],
    ) -> dict[str, Any] | Command | None:
        """Run the hook for the given `state`, with access to `runtime`.

        Returns:
            A state-update dict, a `Command`, or `None`.
        """
        ...
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
    """Structural type for a request hook that receives the model request and the agent state.

    Matches plain functions shaped like `def hook(request, state) -> ModelRequest`.
    """

    def __call__(
        self,
        request: ModelRequest,
        state: StateT_contra,
    ) -> ModelRequest:
        """Return the (possibly modified) model request."""
        ...
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
    """Structural type for a request hook that receives the model request, state, and runtime.

    Matches plain functions shaped like `def hook(request, state, runtime) -> ModelRequest`.
    """

    def __call__(
        self,
        request: ModelRequest,
        state: StateT_contra,
        runtime: Runtime[ContextT],
    ) -> ModelRequest:
        """Return the (possibly modified) model request, with access to `runtime`."""
        ...
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
# Either accepted shape of a state hook: `(state)` or `(state, runtime)`.
_NodeSignature: TypeAlias = (
    _CallableWithState[StateT]
    | _CallableWithStateAndRuntime[StateT, ContextT]
)
# Either accepted shape of a request hook: `(request, state)` or `(request, state, runtime)`.
_ModelRequestSignature: TypeAlias = (
    _CallableWithModelRequestAndState[StateT]
    | _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]
)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def is_callable_with_runtime(
    func: _NodeSignature[StateT, ContextT],
) -> TypeGuard[_CallableWithStateAndRuntime[StateT, ContextT]]:
    """Tell whether a state hook wants the runtime passed to it.

    Detection is by parameter *name* only: the hook qualifies if it declares a
    parameter literally called `runtime`, regardless of its position or kind.

    Args:
        func: The user-supplied state hook.

    Returns:
        `True` when `func` has a `runtime` parameter, otherwise `False`.
    """
    declared = signature(func).parameters
    return "runtime" in declared
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def is_callable_with_runtime_and_request(
    func: _ModelRequestSignature[StateT, ContextT],
) -> TypeGuard[_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]]:
    """Tell whether a model-request hook wants the runtime passed to it.

    As with state hooks, detection is by parameter *name* only: the hook
    qualifies if it declares a parameter literally called `runtime`.

    Args:
        func: The user-supplied model-request hook.

    Returns:
        `True` when `func` has a `runtime` parameter, otherwise `False`.
    """
    declared = signature(func).parameters
    return "runtime" in declared
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@overload
def before_model(
    func: _NodeSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def before_model(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...


def before_model(
    func: _NodeSignature[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> (
    Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    """Decorator used to dynamically create a middleware with the before_model hook.

    Args:
        func: The function to be decorated. Can accept either:
            - `state: StateT` - Just the agent state
            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
            The runtime is passed only when `func` declares a parameter literally
            named `runtime`.
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
            The list is copied, so mutating it afterwards does not affect the middleware.
        jump_to: Optional list of valid jump destinations for conditional edges.
            Valid values are: "tools", "model", "__end__". The list is copied.
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if `func` is provided directly) or a decorator
        that takes the function to wrap and returns an AgentMiddleware instance.

    The decorated function should return:
        - `dict[str, Any]` - State updates to merge into the agent state
        - `Command` - A command to control flow (e.g., jump to different node)
        - `None` - No state updates or flow control

    Examples:
        Basic usage with state only:
        ```python
        @before_model
        def log_before_model(state: AgentState) -> None:
            print(f"About to call model with {len(state['messages'])} messages")
        ```

        Advanced usage with runtime and conditional jumping:
        ```python
        @before_model(jump_to=["__end__"])
        def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
            if some_condition(state):
                return {"jump_to": "__end__"}
            return None
        ```

        With custom state schema:
        ```python
        @before_model(
            state_schema=MyCustomState,
        )
        def custom_before_model(state: MyCustomState) -> dict[str, Any]:
            return {"custom_field": "updated_value"}
        ```
    """

    def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
        # Expose the hook with the same shape as the user's function (with or without
        # `runtime`). NOTE(review): presumably the graph inspects this signature to decide
        # whether to inject the runtime — confirm in create_agent.
        if is_callable_with_runtime(func):

            def wrapped_with_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> dict[str, Any] | Command | None:
                return func(state, runtime)

            wrapped = wrapped_with_runtime
        else:

            def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
            ) -> dict[str, Any] | Command | None:
                return func(state)  # type: ignore[call-arg]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]

        # Use function name as default if no name provided. The class name becomes part of
        # the graph node name (e.g. "<name>.before_model").
        middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                # Surface the hook's docstring on the generated class.
                "__doc__": getattr(func, "__doc__", None),
                "state_schema": state_schema or AgentState,
                # Copy caller-owned lists: these become class attributes, and sharing the
                # caller's list would let later mutations leak into the middleware.
                "tools": list(tools or []),
                "before_model_jump_to": list(jump_to or []),
                "before_model": wrapped,
            },
        )()

    if func is not None:
        return decorator(func)
    return decorator
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
@overload
def modify_model_request(
    func: _ModelRequestSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def modify_model_request(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    name: str | None = None,
) -> Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...


def modify_model_request(
    func: _ModelRequestSignature[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    name: str | None = None,
) -> (
    Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    r"""Decorator used to dynamically create a middleware with the modify_model_request hook.

    Args:
        func: The function to be decorated. Can accept either:
            - `request: ModelRequest, state: StateT` - Model request and agent state
            - `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
                Model request, state, and runtime context
            The runtime is passed only when `func` declares a parameter literally
            named `runtime`.
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
            The list is copied, so mutating it afterwards does not affect the middleware.
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if `func` is provided directly) or a decorator
        that takes the function to wrap and returns an AgentMiddleware instance.

    The decorated function should return:
        - `ModelRequest` - The modified model request to be sent to the language model

    Examples:
        Basic usage to modify system prompt:
        ```python
        @modify_model_request
        def add_context_to_prompt(request: ModelRequest, state: AgentState) -> ModelRequest:
            if request.system_prompt:
                request.system_prompt += "\n\nAdditional context: ..."
            else:
                request.system_prompt = "Additional context: ..."
            return request
        ```

        Advanced usage with runtime and custom model settings:
        ```python
        @modify_model_request
        def dynamic_model_settings(
            request: ModelRequest, state: AgentState, runtime: Runtime
        ) -> ModelRequest:
            # Use a different model based on user subscription tier
            if runtime.context.get("subscription_tier") == "premium":
                request.model = "gpt-4o"
            else:
                request.model = "gpt-4o-mini"

            return request
        ```
    """

    def decorator(
        func: _ModelRequestSignature[StateT, ContextT],
    ) -> AgentMiddleware[StateT, ContextT]:
        # Expose the hook with the same shape as the user's function (with or without
        # `runtime`). NOTE(review): presumably the agent inspects this signature to decide
        # whether to inject the runtime — confirm in create_agent.
        if is_callable_with_runtime_and_request(func):

            def wrapped_with_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                request: ModelRequest,
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> ModelRequest:
                return func(request, state, runtime)

            wrapped = wrapped_with_runtime
        else:

            def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                request: ModelRequest,
                state: StateT,
            ) -> ModelRequest:
                return func(request, state)  # type: ignore[call-arg]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]

        # Use function name as default if no name provided
        middleware_name = name or cast(
            "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
        )

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                # Surface the hook's docstring on the generated class.
                "__doc__": getattr(func, "__doc__", None),
                "state_schema": state_schema or AgentState,
                # Copy the caller-owned list: it becomes a class attribute, and sharing
                # the caller's list would let later mutations leak into the middleware.
                "tools": list(tools or []),
                "modify_model_request": wrapped,
            },
        )()

    if func is not None:
        return decorator(func)
    return decorator
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
@overload
def after_model(
    func: _NodeSignature[StateT, ContextT],
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def after_model(
    func: None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...


def after_model(
    func: _NodeSignature[StateT, ContextT] | None = None,
    *,
    state_schema: type[StateT] | None = None,
    tools: list[BaseTool] | None = None,
    jump_to: list[JumpTo] | None = None,
    name: str | None = None,
) -> (
    Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
    | AgentMiddleware[StateT, ContextT]
):
    """Decorator used to dynamically create a middleware with the after_model hook.

    Args:
        func: The function to be decorated. Can accept either:
            - `state: StateT` - Just the agent state (includes model response)
            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
            The runtime is passed only when `func` declares a parameter literally
            named `runtime`.
        state_schema: Optional custom state schema type. If not provided, uses the default
            AgentState schema.
        tools: Optional list of additional tools to register with this middleware.
            The list is copied, so mutating it afterwards does not affect the middleware.
        jump_to: Optional list of valid jump destinations for conditional edges.
            Valid values are: "tools", "model", "__end__". The list is copied.
        name: Optional name for the generated middleware class. If not provided,
            uses the decorated function's name.

    Returns:
        Either an AgentMiddleware instance (if `func` is provided directly) or a decorator
        that takes the function to wrap and returns an AgentMiddleware instance.

    The decorated function should return:
        - `dict[str, Any]` - State updates to merge into the agent state
        - `Command` - A command to control flow (e.g., jump to different node)
        - `None` - No state updates or flow control

    Examples:
        Basic usage for logging model responses:
        ```python
        @after_model
        def log_latest_message(state: AgentState) -> None:
            print(state["messages"][-1].content)
        ```

        With custom state schema:
        ```python
        @after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
        def custom_after_model(state: MyCustomState) -> dict[str, Any]:
            return {"custom_field": "updated_after_model"}
        ```
    """

    def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
        # Expose the hook with the same shape as the user's function (with or without
        # `runtime`). NOTE(review): presumably the graph inspects this signature to decide
        # whether to inject the runtime — confirm in create_agent.
        if is_callable_with_runtime(func):

            def wrapped_with_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
                runtime: Runtime[ContextT],
            ) -> dict[str, Any] | Command | None:
                return func(state, runtime)

            wrapped = wrapped_with_runtime
        else:

            def wrapped_without_runtime(
                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
                state: StateT,
            ) -> dict[str, Any] | Command | None:
                return func(state)  # type: ignore[call-arg]

            wrapped = wrapped_without_runtime  # type: ignore[assignment]

        # Use function name as default if no name provided. The class name becomes part of
        # the graph node name (e.g. "<name>.after_model").
        middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))

        return type(
            middleware_name,
            (AgentMiddleware,),
            {
                # Surface the hook's docstring on the generated class.
                "__doc__": getattr(func, "__doc__", None),
                "state_schema": state_schema or AgentState,
                # Copy caller-owned lists: these become class attributes, and sharing the
                # caller's list would let later mutations leak into the middleware.
                "tools": list(tools or []),
                "after_model_jump_to": list(jump_to or []),
                "after_model": wrapped,
            },
        )()

    if func is not None:
        return decorator(func)
    return decorator
|
|
@@ -464,7 +464,7 @@ def create_agent( # noqa: PLR0915
|
|
|
464
464
|
f"{middleware_w_after[0].__class__.__name__}.after_model",
|
|
465
465
|
END,
|
|
466
466
|
first_node,
|
|
467
|
-
|
|
467
|
+
jump_to=middleware_w_after[0].after_model_jump_to,
|
|
468
468
|
)
|
|
469
469
|
|
|
470
470
|
# Add middleware edges (same as before)
|
|
@@ -475,7 +475,7 @@ def create_agent( # noqa: PLR0915
|
|
|
475
475
|
f"{m1.__class__.__name__}.before_model",
|
|
476
476
|
f"{m2.__class__.__name__}.before_model",
|
|
477
477
|
first_node,
|
|
478
|
-
|
|
478
|
+
jump_to=m1.before_model_jump_to,
|
|
479
479
|
)
|
|
480
480
|
# Go directly to model_request after the last before_model
|
|
481
481
|
_add_middleware_edge(
|
|
@@ -483,7 +483,7 @@ def create_agent( # noqa: PLR0915
|
|
|
483
483
|
f"{middleware_w_before[-1].__class__.__name__}.before_model",
|
|
484
484
|
"model_request",
|
|
485
485
|
first_node,
|
|
486
|
-
|
|
486
|
+
jump_to=middleware_w_before[-1].before_model_jump_to,
|
|
487
487
|
)
|
|
488
488
|
|
|
489
489
|
if middleware_w_after:
|
|
@@ -496,7 +496,7 @@ def create_agent( # noqa: PLR0915
|
|
|
496
496
|
f"{m1.__class__.__name__}.after_model",
|
|
497
497
|
f"{m2.__class__.__name__}.after_model",
|
|
498
498
|
first_node,
|
|
499
|
-
|
|
499
|
+
jump_to=m1.after_model_jump_to,
|
|
500
500
|
)
|
|
501
501
|
|
|
502
502
|
return graph
|
|
@@ -528,8 +528,8 @@ def _fetch_last_ai_and_tool_messages(
|
|
|
528
528
|
|
|
529
529
|
def _make_model_to_tools_edge(
|
|
530
530
|
first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
|
|
531
|
-
) -> Callable[[
|
|
532
|
-
def model_to_tools(state:
|
|
531
|
+
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
|
|
532
|
+
def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
|
|
533
533
|
if jump_to := state.get("jump_to"):
|
|
534
534
|
return _resolve_jump(jump_to, first_node)
|
|
535
535
|
|
|
@@ -548,8 +548,7 @@ def _make_model_to_tools_edge(
|
|
|
548
548
|
# of using Send w/ tool calls directly which allows more intuitive interrupt behavior
|
|
549
549
|
# largely internal so can be fixed later
|
|
550
550
|
pending_tool_calls = [
|
|
551
|
-
tool_node.inject_tool_args(call, state, None)
|
|
552
|
-
for call in pending_tool_calls
|
|
551
|
+
tool_node.inject_tool_args(call, state, None) for call in pending_tool_calls
|
|
553
552
|
]
|
|
554
553
|
return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
|
|
555
554
|
|
|
@@ -560,8 +559,8 @@ def _make_model_to_tools_edge(
|
|
|
560
559
|
|
|
561
560
|
def _make_tools_to_model_edge(
|
|
562
561
|
tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
|
|
563
|
-
) -> Callable[[
|
|
564
|
-
def tools_to_model(state:
|
|
562
|
+
) -> Callable[[dict[str, Any]], str | None]:
|
|
563
|
+
def tools_to_model(state: dict[str, Any]) -> str | None:
|
|
565
564
|
last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
|
|
566
565
|
|
|
567
566
|
if all(
|
|
@@ -584,7 +583,7 @@ def _add_middleware_edge(
|
|
|
584
583
|
name: str,
|
|
585
584
|
default_destination: str,
|
|
586
585
|
model_destination: str,
|
|
587
|
-
|
|
586
|
+
jump_to: list[JumpTo] | None,
|
|
588
587
|
) -> None:
|
|
589
588
|
"""Add an edge to the graph for a middleware node.
|
|
590
589
|
|
|
@@ -594,18 +593,23 @@ def _add_middleware_edge(
|
|
|
594
593
|
name: The name of the middleware node.
|
|
595
594
|
default_destination: The default destination for the edge.
|
|
596
595
|
model_destination: The destination for the edge to the model.
|
|
597
|
-
|
|
596
|
+
jump_to: The conditionally jumpable destinations for the edge.
|
|
598
597
|
"""
|
|
598
|
+
if jump_to:
|
|
599
|
+
|
|
600
|
+
def jump_edge(state: dict[str, Any]) -> str:
|
|
601
|
+
return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
|
|
599
602
|
|
|
600
|
-
|
|
601
|
-
return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
|
|
603
|
+
destinations = [default_destination]
|
|
602
604
|
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
destinations.append(model_destination)
|
|
605
|
+
if "__end__" in jump_to:
|
|
606
|
+
destinations.append(END)
|
|
607
|
+
if "tools" in jump_to:
|
|
608
|
+
destinations.append("tools")
|
|
609
|
+
if "model" in jump_to and name != model_destination:
|
|
610
|
+
destinations.append(model_destination)
|
|
610
611
|
|
|
611
|
-
|
|
612
|
+
graph.add_conditional_edges(name, jump_edge, destinations)
|
|
613
|
+
|
|
614
|
+
else:
|
|
615
|
+
graph.add_edge(name, default_destination)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
langchain-1.0.
|
|
2
|
-
langchain-1.0.
|
|
3
|
-
langchain-1.0.
|
|
4
|
-
langchain-1.0.
|
|
1
|
+
langchain-1.0.0a8.dist-info/METADATA,sha256=slSRaiJXGZdoNrLCLCYpJqqaPdM0keHqtv2aCBHLC7w,6259
|
|
2
|
+
langchain-1.0.0a8.dist-info/WHEEL,sha256=9P2ygRxDrTJz3gsagc0Z96ukrxjr-LFBGOgv3AuKlCA,90
|
|
3
|
+
langchain-1.0.0a8.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
|
|
4
|
+
langchain-1.0.0a8.dist-info/licenses/LICENSE,sha256=TsZ-TKbmch26hJssqCJhWXyGph7iFLvyFBYAa3stBHg,1067
|
|
5
5
|
langchain/__init__.py,sha256=Z6r4MjNaC6DSyiMgFSRmi8EhizelVIOb_CVJJRAVjDc,604
|
|
6
6
|
langchain/_internal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
7
|
langchain/_internal/_documents.py,sha256=z9wAPukoASOMw4WTFFBKCCZYsvsKbo-Cq6CeHjdq9eE,1045
|
|
@@ -12,13 +12,12 @@ langchain/_internal/_utils.py,sha256=lG8X9muiRAWtQjRPudq-1x-wHbk0J3spu_rYZckVdYs
|
|
|
12
12
|
langchain/agents/__init__.py,sha256=NG2S3dic9L3i4sAD9mpgaTv6Dl4L3u45xxK6jn-I4W8,281
|
|
13
13
|
langchain/agents/_internal/__init__.py,sha256=5nNBeaeQIvv9IOQjY4_aNW8pffWzMXQgi0b6Nx-WghM,37
|
|
14
14
|
langchain/agents/_internal/_typing.py,sha256=JoWa-KL5uLNeq6yrm56wnIvhDeFnCt2fTzgUcj5zWy4,270
|
|
15
|
-
langchain/agents/middleware/__init__.py,sha256
|
|
16
|
-
langchain/agents/middleware/
|
|
17
|
-
langchain/agents/middleware/
|
|
18
|
-
langchain/agents/middleware/prompt_caching.py,sha256=du_qrBr0_kwWhdO_xggtfrEN5FTcGLKu3oYQDnSS0Do,2263
|
|
15
|
+
langchain/agents/middleware/__init__.py,sha256=-NzMTmD5ogpzlsqHGjv6SnTrfXqU3vTahGUoGDk299U,511
|
|
16
|
+
langchain/agents/middleware/human_in_the_loop.py,sha256=_6THKNzp1dvYBwBLdnZ9PXsHJP3uedn4A60ZON4xlvI,10301
|
|
17
|
+
langchain/agents/middleware/prompt_caching.py,sha256=QLoWdd9jUiXAytGqbXW0I_Mg8WgrgTBO8gOZ-s8Bx8g,3081
|
|
19
18
|
langchain/agents/middleware/summarization.py,sha256=qqEqAuJXQ5rfewhFHftHLnrX8jhdMu9dPfz0akhzfuc,10281
|
|
20
|
-
langchain/agents/middleware/types.py,sha256=
|
|
21
|
-
langchain/agents/middleware_agent.py,sha256=
|
|
19
|
+
langchain/agents/middleware/types.py,sha256=a9B6Ihx12mNTroopL1SqHxsO51ZfSkxdtkPeZXw8EJc,18606
|
|
20
|
+
langchain/agents/middleware_agent.py,sha256=fncjAFNsqZqEkDYSsBfq-goxN4GqNlQixirWRBguXhs,23847
|
|
22
21
|
langchain/agents/react_agent.py,sha256=6ZNI2dp0hTL7hTm7ao-HkQ3hmVvBQuFu9pJz0PSK_eg,49712
|
|
23
22
|
langchain/agents/structured_output.py,sha256=QWNafJx7au_jJawJgIfovnDoP8Z9mLxDZNvDX_1RRJ0,13327
|
|
24
23
|
langchain/agents/tool_node.py,sha256=QabTfIi8nGrwfzaSOeWfyHos6sgXjFTdRXexQG7u2HE,46596
|
|
@@ -36,4 +35,4 @@ langchain/storage/exceptions.py,sha256=Fl_8tON3KmByBKwXtno5WSj0-c2RiZxnhw3gv5aS2
|
|
|
36
35
|
langchain/storage/in_memory.py,sha256=ozrmu0EtaJJVSAzK_u7nzxWpr9OOscWkANHSg-qIVYQ,369
|
|
37
36
|
langchain/text_splitter.py,sha256=yxWs4secpnkfK6VZDiNJNdlYOrRZ18RQZj1S3xNQ73A,1554
|
|
38
37
|
langchain/tools/__init__.py,sha256=NYQzLxW2iI5Twu3voefVC-dJEI4Wgh7jC311CQEpvZs,252
|
|
39
|
-
langchain-1.0.
|
|
38
|
+
langchain-1.0.0a8.dist-info/RECORD,,
|
|
@@ -1,105 +0,0 @@
|
|
|
1
|
-
"""Dynamic System Prompt Middleware.
|
|
2
|
-
|
|
3
|
-
Allows setting the system prompt dynamically right before each model invocation.
|
|
4
|
-
Useful when the prompt depends on the current agent state or per-invocation context.
|
|
5
|
-
"""
|
|
6
|
-
|
|
7
|
-
from __future__ import annotations
|
|
8
|
-
|
|
9
|
-
from inspect import signature
|
|
10
|
-
from typing import TYPE_CHECKING, Protocol, TypeAlias, cast
|
|
11
|
-
|
|
12
|
-
from langgraph.typing import ContextT
|
|
13
|
-
|
|
14
|
-
from langchain.agents.middleware.types import (
|
|
15
|
-
AgentMiddleware,
|
|
16
|
-
AgentState,
|
|
17
|
-
ModelRequest,
|
|
18
|
-
)
|
|
19
|
-
|
|
20
|
-
if TYPE_CHECKING:
|
|
21
|
-
from langgraph.runtime import Runtime
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class DynamicSystemPromptWithoutRuntime(Protocol):
|
|
25
|
-
"""Dynamic system prompt without runtime in call signature."""
|
|
26
|
-
|
|
27
|
-
def __call__(self, state: AgentState) -> str:
|
|
28
|
-
"""Return the system prompt for the next model call."""
|
|
29
|
-
...
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class DynamicSystemPromptWithRuntime(Protocol[ContextT]):
|
|
33
|
-
"""Dynamic system prompt with runtime in call signature."""
|
|
34
|
-
|
|
35
|
-
def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str:
|
|
36
|
-
"""Return the system prompt for the next model call."""
|
|
37
|
-
...
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
DynamicSystemPrompt: TypeAlias = (
|
|
41
|
-
DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT]
|
|
42
|
-
)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
class DynamicSystemPromptMiddleware(AgentMiddleware):
|
|
46
|
-
"""Dynamic System Prompt Middleware.
|
|
47
|
-
|
|
48
|
-
Allows setting the system prompt dynamically right before each model invocation.
|
|
49
|
-
Useful when the prompt depends on the current agent state or per-invocation context.
|
|
50
|
-
|
|
51
|
-
Example:
|
|
52
|
-
```python
|
|
53
|
-
from langchain.agents.middleware import DynamicSystemPromptMiddleware
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
class Context(TypedDict):
|
|
57
|
-
user_name: str
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
|
|
61
|
-
user_name = runtime.context.get("user_name", "n/a")
|
|
62
|
-
return (
|
|
63
|
-
f"You are a helpful assistant. Always address the user by their name: {user_name}"
|
|
64
|
-
)
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
middleware = DynamicSystemPromptMiddleware(system_prompt)
|
|
68
|
-
```
|
|
69
|
-
"""
|
|
70
|
-
|
|
71
|
-
_accepts_runtime: bool
|
|
72
|
-
|
|
73
|
-
def __init__(
|
|
74
|
-
self,
|
|
75
|
-
dynamic_system_prompt: DynamicSystemPrompt[ContextT],
|
|
76
|
-
) -> None:
|
|
77
|
-
"""Initialize the dynamic system prompt middleware.
|
|
78
|
-
|
|
79
|
-
Args:
|
|
80
|
-
dynamic_system_prompt: Function that receives the current agent state
|
|
81
|
-
and optionally runtime with context, and returns the system prompt for
|
|
82
|
-
the next model call. Returns a string.
|
|
83
|
-
"""
|
|
84
|
-
super().__init__()
|
|
85
|
-
self.dynamic_system_prompt = dynamic_system_prompt
|
|
86
|
-
self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters
|
|
87
|
-
|
|
88
|
-
def modify_model_request(
|
|
89
|
-
self,
|
|
90
|
-
request: ModelRequest,
|
|
91
|
-
state: AgentState,
|
|
92
|
-
runtime: Runtime[ContextT],
|
|
93
|
-
) -> ModelRequest:
|
|
94
|
-
"""Modify the model request to include the dynamic system prompt."""
|
|
95
|
-
if self._accepts_runtime:
|
|
96
|
-
system_prompt = cast(
|
|
97
|
-
"DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt
|
|
98
|
-
)(state, runtime)
|
|
99
|
-
else:
|
|
100
|
-
system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)(
|
|
101
|
-
state
|
|
102
|
-
)
|
|
103
|
-
|
|
104
|
-
request.system_prompt = system_prompt
|
|
105
|
-
return request
|
|
File without changes
|
|
File without changes
|
|
File without changes
|