langchain 1.0.0a13__py3-none-any.whl → 1.0.0a15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain might be problematic. Click here for more details.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +115 -29
- langchain/agents/middleware/__init__.py +6 -5
- langchain/agents/middleware/context_editing.py +29 -1
- langchain/agents/middleware/human_in_the_loop.py +13 -13
- langchain/agents/middleware/model_call_limit.py +38 -4
- langchain/agents/middleware/model_fallback.py +36 -1
- langchain/agents/middleware/pii.py +6 -8
- langchain/agents/middleware/{planning.py → todo.py} +18 -5
- langchain/agents/middleware/tool_call_limit.py +88 -15
- langchain/agents/middleware/types.py +196 -18
- langchain/embeddings/__init__.py +0 -2
- langchain/messages/__init__.py +32 -0
- langchain/tools/__init__.py +1 -6
- langchain/tools/tool_node.py +62 -11
- langchain-1.0.0a15.dist-info/METADATA +85 -0
- langchain-1.0.0a15.dist-info/RECORD +29 -0
- langchain/agents/middleware/prompt_caching.py +0 -89
- langchain/documents/__init__.py +0 -7
- langchain/embeddings/cache.py +0 -361
- langchain/storage/__init__.py +0 -22
- langchain/storage/encoder_backed.py +0 -122
- langchain/storage/exceptions.py +0 -5
- langchain/storage/in_memory.py +0 -13
- langchain-1.0.0a13.dist-info/METADATA +0 -125
- langchain-1.0.0a13.dist-info/RECORD +0 -36
- {langchain-1.0.0a13.dist-info → langchain-1.0.0a15.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a13.dist-info → langchain-1.0.0a15.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,16 +2,37 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Literal
|
|
5
|
+
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
|
6
6
|
|
|
7
7
|
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
|
|
8
|
+
from langgraph.channels.untracked_value import UntrackedValue
|
|
9
|
+
from typing_extensions import NotRequired
|
|
8
10
|
|
|
9
|
-
from langchain.agents.middleware.types import
|
|
11
|
+
from langchain.agents.middleware.types import (
|
|
12
|
+
AgentMiddleware,
|
|
13
|
+
AgentState,
|
|
14
|
+
PrivateStateAttr,
|
|
15
|
+
hook_config,
|
|
16
|
+
)
|
|
10
17
|
|
|
11
18
|
if TYPE_CHECKING:
|
|
12
19
|
from langgraph.runtime import Runtime
|
|
13
20
|
|
|
14
21
|
|
|
22
|
+
class ToolCallLimitState(AgentState):
|
|
23
|
+
"""State schema for ToolCallLimitMiddleware.
|
|
24
|
+
|
|
25
|
+
Extends AgentState with tool call tracking fields.
|
|
26
|
+
|
|
27
|
+
The count fields are dictionaries mapping tool names to execution counts.
|
|
28
|
+
This allows multiple middleware instances to track different tools independently.
|
|
29
|
+
The special key "__all__" is used for tracking all tool calls globally.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
|
|
33
|
+
run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
|
|
34
|
+
|
|
35
|
+
|
|
15
36
|
def _count_tool_calls_in_messages(messages: list[AnyMessage], tool_name: str | None = None) -> int:
|
|
16
37
|
"""Count tool calls in a list of messages.
|
|
17
38
|
|
|
@@ -124,18 +145,18 @@ class ToolCallLimitExceededError(Exception):
|
|
|
124
145
|
super().__init__(msg)
|
|
125
146
|
|
|
126
147
|
|
|
127
|
-
class ToolCallLimitMiddleware(AgentMiddleware):
|
|
148
|
+
class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
|
|
128
149
|
"""Middleware that tracks tool call counts and enforces limits.
|
|
129
150
|
|
|
130
151
|
This middleware monitors the number of tool calls made during agent execution
|
|
131
152
|
and can terminate the agent when specified limits are reached. It supports
|
|
132
153
|
both thread-level and run-level call counting with configurable exit behaviors.
|
|
133
154
|
|
|
134
|
-
Thread-level: The middleware
|
|
135
|
-
|
|
155
|
+
Thread-level: The middleware tracks the total number of tool calls and persists
|
|
156
|
+
call count across multiple runs (invocations) of the agent.
|
|
136
157
|
|
|
137
|
-
Run-level: The middleware
|
|
138
|
-
|
|
158
|
+
Run-level: The middleware tracks the number of tool calls made during a single
|
|
159
|
+
run (invocation) of the agent.
|
|
139
160
|
|
|
140
161
|
Example:
|
|
141
162
|
```python
|
|
@@ -157,6 +178,8 @@ class ToolCallLimitMiddleware(AgentMiddleware):
|
|
|
157
178
|
```
|
|
158
179
|
"""
|
|
159
180
|
|
|
181
|
+
state_schema = ToolCallLimitState
|
|
182
|
+
|
|
160
183
|
def __init__(
|
|
161
184
|
self,
|
|
162
185
|
*,
|
|
@@ -211,11 +234,11 @@ class ToolCallLimitMiddleware(AgentMiddleware):
|
|
|
211
234
|
return base_name
|
|
212
235
|
|
|
213
236
|
@hook_config(can_jump_to=["end"])
|
|
214
|
-
def before_model(self, state:
|
|
237
|
+
def before_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
|
215
238
|
"""Check tool call limits before making a model call.
|
|
216
239
|
|
|
217
240
|
Args:
|
|
218
|
-
state: The current agent state containing
|
|
241
|
+
state: The current agent state containing tool call counts.
|
|
219
242
|
runtime: The langgraph runtime.
|
|
220
243
|
|
|
221
244
|
Returns:
|
|
@@ -226,14 +249,14 @@ class ToolCallLimitMiddleware(AgentMiddleware):
|
|
|
226
249
|
ToolCallLimitExceededError: If limits are exceeded and exit_behavior
|
|
227
250
|
is "error".
|
|
228
251
|
"""
|
|
229
|
-
|
|
252
|
+
# Get the count key for this middleware instance
|
|
253
|
+
count_key = self.tool_name if self.tool_name else "__all__"
|
|
230
254
|
|
|
231
|
-
|
|
232
|
-
|
|
255
|
+
thread_counts = state.get("thread_tool_call_count", {})
|
|
256
|
+
run_counts = state.get("run_tool_call_count", {})
|
|
233
257
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
run_count = _count_tool_calls_in_messages(run_messages, self.tool_name)
|
|
258
|
+
thread_count = thread_counts.get(count_key, 0)
|
|
259
|
+
run_count = run_counts.get(count_key, 0)
|
|
237
260
|
|
|
238
261
|
# Check if any limits are exceeded
|
|
239
262
|
thread_limit_exceeded = self.thread_limit is not None and thread_count >= self.thread_limit
|
|
@@ -258,3 +281,53 @@ class ToolCallLimitMiddleware(AgentMiddleware):
|
|
|
258
281
|
return {"jump_to": "end", "messages": [limit_ai_message]}
|
|
259
282
|
|
|
260
283
|
return None
|
|
284
|
+
|
|
285
|
+
def after_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
|
|
286
|
+
"""Increment tool call counts after a model call (when tool calls are made).
|
|
287
|
+
|
|
288
|
+
Args:
|
|
289
|
+
state: The current agent state.
|
|
290
|
+
runtime: The langgraph runtime.
|
|
291
|
+
|
|
292
|
+
Returns:
|
|
293
|
+
State updates with incremented tool call counts if tool calls were made.
|
|
294
|
+
"""
|
|
295
|
+
# Get the last AIMessage to check for tool calls
|
|
296
|
+
messages = state.get("messages", [])
|
|
297
|
+
if not messages:
|
|
298
|
+
return None
|
|
299
|
+
|
|
300
|
+
# Find the last AIMessage
|
|
301
|
+
last_ai_message = None
|
|
302
|
+
for message in reversed(messages):
|
|
303
|
+
if isinstance(message, AIMessage):
|
|
304
|
+
last_ai_message = message
|
|
305
|
+
break
|
|
306
|
+
|
|
307
|
+
if not last_ai_message or not last_ai_message.tool_calls:
|
|
308
|
+
return None
|
|
309
|
+
|
|
310
|
+
# Count relevant tool calls (filter by tool_name if specified)
|
|
311
|
+
tool_call_count = 0
|
|
312
|
+
for tool_call in last_ai_message.tool_calls:
|
|
313
|
+
if self.tool_name is None or tool_call["name"] == self.tool_name:
|
|
314
|
+
tool_call_count += 1
|
|
315
|
+
|
|
316
|
+
if tool_call_count == 0:
|
|
317
|
+
return None
|
|
318
|
+
|
|
319
|
+
# Get the count key for this middleware instance
|
|
320
|
+
count_key = self.tool_name if self.tool_name else "__all__"
|
|
321
|
+
|
|
322
|
+
# Get current counts
|
|
323
|
+
thread_counts = state.get("thread_tool_call_count", {}).copy()
|
|
324
|
+
run_counts = state.get("run_tool_call_count", {}).copy()
|
|
325
|
+
|
|
326
|
+
# Increment counts for this key
|
|
327
|
+
thread_counts[count_key] = thread_counts.get(count_key, 0) + tool_call_count
|
|
328
|
+
run_counts[count_key] = run_counts.get(count_key, 0) + tool_call_count
|
|
329
|
+
|
|
330
|
+
return {
|
|
331
|
+
"thread_tool_call_count": thread_counts,
|
|
332
|
+
"run_tool_call_count": run_counts,
|
|
333
|
+
}
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from collections.abc import Awaitable, Callable
|
|
6
|
-
from dataclasses import dataclass, field
|
|
6
|
+
from dataclasses import dataclass, field, replace
|
|
7
7
|
from inspect import iscoroutinefunction
|
|
8
8
|
from typing import (
|
|
9
9
|
TYPE_CHECKING,
|
|
@@ -21,16 +21,15 @@ if TYPE_CHECKING:
|
|
|
21
21
|
|
|
22
22
|
from langchain.tools.tool_node import ToolCallRequest
|
|
23
23
|
|
|
24
|
-
#
|
|
24
|
+
# Needed as top level import for Pydantic schema generation on AgentState
|
|
25
25
|
from typing import TypeAlias
|
|
26
26
|
|
|
27
27
|
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002
|
|
28
28
|
from langgraph.channels.ephemeral_value import EphemeralValue
|
|
29
|
-
from langgraph.channels.untracked_value import UntrackedValue
|
|
30
29
|
from langgraph.graph.message import add_messages
|
|
31
30
|
from langgraph.types import Command # noqa: TC002
|
|
32
31
|
from langgraph.typing import ContextT
|
|
33
|
-
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
|
|
32
|
+
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
|
|
34
33
|
|
|
35
34
|
if TYPE_CHECKING:
|
|
36
35
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
@@ -62,6 +61,18 @@ JumpTo = Literal["tools", "model", "end"]
|
|
|
62
61
|
ResponseT = TypeVar("ResponseT")
|
|
63
62
|
|
|
64
63
|
|
|
64
|
+
class _ModelRequestOverrides(TypedDict, total=False):
|
|
65
|
+
"""Possible overrides for ModelRequest.override() method."""
|
|
66
|
+
|
|
67
|
+
model: BaseChatModel
|
|
68
|
+
system_prompt: str | None
|
|
69
|
+
messages: list[AnyMessage]
|
|
70
|
+
tool_choice: Any | None
|
|
71
|
+
tools: list[BaseTool | dict]
|
|
72
|
+
response_format: ResponseFormat | None
|
|
73
|
+
model_settings: dict[str, Any]
|
|
74
|
+
|
|
75
|
+
|
|
65
76
|
@dataclass
|
|
66
77
|
class ModelRequest:
|
|
67
78
|
"""Model request information for the agent."""
|
|
@@ -76,6 +87,36 @@ class ModelRequest:
|
|
|
76
87
|
runtime: Runtime[ContextT] # type: ignore[valid-type]
|
|
77
88
|
model_settings: dict[str, Any] = field(default_factory=dict)
|
|
78
89
|
|
|
90
|
+
def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
|
|
91
|
+
"""Replace the request with a new request with the given overrides.
|
|
92
|
+
|
|
93
|
+
Returns a new `ModelRequest` instance with the specified attributes replaced.
|
|
94
|
+
This follows an immutable pattern, leaving the original request unchanged.
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
**overrides: Keyword arguments for attributes to override. Supported keys:
|
|
98
|
+
- model: BaseChatModel instance
|
|
99
|
+
- system_prompt: Optional system prompt string
|
|
100
|
+
- messages: List of messages
|
|
101
|
+
- tool_choice: Tool choice configuration
|
|
102
|
+
- tools: List of available tools
|
|
103
|
+
- response_format: Response format specification
|
|
104
|
+
- model_settings: Additional model settings
|
|
105
|
+
|
|
106
|
+
Returns:
|
|
107
|
+
New ModelRequest instance with specified overrides applied.
|
|
108
|
+
|
|
109
|
+
Examples:
|
|
110
|
+
```python
|
|
111
|
+
# Create a new request with different model
|
|
112
|
+
new_request = request.override(model=different_model)
|
|
113
|
+
|
|
114
|
+
# Override multiple attributes
|
|
115
|
+
new_request = request.override(system_prompt="New instructions", tool_choice="auto")
|
|
116
|
+
```
|
|
117
|
+
"""
|
|
118
|
+
return replace(self, **overrides)
|
|
119
|
+
|
|
79
120
|
|
|
80
121
|
@dataclass
|
|
81
122
|
class ModelResponse:
|
|
@@ -129,8 +170,6 @@ class AgentState(TypedDict, Generic[ResponseT]):
|
|
|
129
170
|
messages: Required[Annotated[list[AnyMessage], add_messages]]
|
|
130
171
|
jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
|
|
131
172
|
structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
|
|
132
|
-
thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
|
|
133
|
-
run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
|
|
134
173
|
|
|
135
174
|
|
|
136
175
|
class PublicAgentState(TypedDict, Generic[ResponseT]):
|
|
@@ -263,18 +302,35 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
263
302
|
return AIMessage(content="Simplified response")
|
|
264
303
|
```
|
|
265
304
|
"""
|
|
266
|
-
|
|
305
|
+
msg = (
|
|
306
|
+
"Synchronous implementation of wrap_model_call is not available. "
|
|
307
|
+
"You are likely encountering this error because you defined only the async version "
|
|
308
|
+
"(awrap_model_call) and invoked your agent in a synchronous context "
|
|
309
|
+
"(e.g., using `stream()` or `invoke()`). "
|
|
310
|
+
"To resolve this, either: "
|
|
311
|
+
"(1) subclass AgentMiddleware and implement the synchronous wrap_model_call method, "
|
|
312
|
+
"(2) use the @wrap_model_call decorator on a standalone sync function, or "
|
|
313
|
+
"(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
|
|
314
|
+
)
|
|
315
|
+
raise NotImplementedError(msg)
|
|
267
316
|
|
|
268
317
|
async def awrap_model_call(
|
|
269
318
|
self,
|
|
270
319
|
request: ModelRequest,
|
|
271
320
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
272
321
|
) -> ModelCallResult:
|
|
273
|
-
"""
|
|
322
|
+
"""Intercept and control async model execution via handler callback.
|
|
323
|
+
|
|
324
|
+
The handler callback executes the model request and returns a ModelResponse.
|
|
325
|
+
Middleware can call the handler multiple times for retry logic, skip calling
|
|
326
|
+
it to short-circuit, or modify the request/response. Multiple middleware
|
|
327
|
+
compose with first in list as outermost layer.
|
|
274
328
|
|
|
275
329
|
Args:
|
|
276
330
|
request: Model request to execute (includes state and runtime).
|
|
277
|
-
handler: Async callback that executes the model request.
|
|
331
|
+
handler: Async callback that executes the model request and returns ModelResponse.
|
|
332
|
+
Call this to execute the model. Can be called multiple times
|
|
333
|
+
for retry logic. Can skip calling it to short-circuit.
|
|
278
334
|
|
|
279
335
|
Returns:
|
|
280
336
|
ModelCallResult
|
|
@@ -291,7 +347,17 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
291
347
|
raise
|
|
292
348
|
```
|
|
293
349
|
"""
|
|
294
|
-
|
|
350
|
+
msg = (
|
|
351
|
+
"Asynchronous implementation of awrap_model_call is not available. "
|
|
352
|
+
"You are likely encountering this error because you defined only the sync version "
|
|
353
|
+
"(wrap_model_call) and invoked your agent in an asynchronous context "
|
|
354
|
+
"(e.g., using `astream()` or `ainvoke()`). "
|
|
355
|
+
"To resolve this, either: "
|
|
356
|
+
"(1) subclass AgentMiddleware and implement the asynchronous awrap_model_call method, "
|
|
357
|
+
"(2) use the @wrap_model_call decorator on a standalone async function, or "
|
|
358
|
+
"(3) invoke your agent synchronously using `stream()` or `invoke()`."
|
|
359
|
+
)
|
|
360
|
+
raise NotImplementedError(msg)
|
|
295
361
|
|
|
296
362
|
def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
297
363
|
"""Logic to run after the agent execution completes."""
|
|
@@ -353,7 +419,77 @@ class AgentMiddleware(Generic[StateT, ContextT]):
|
|
|
353
419
|
continue
|
|
354
420
|
return result
|
|
355
421
|
"""
|
|
356
|
-
|
|
422
|
+
msg = (
|
|
423
|
+
"Synchronous implementation of wrap_tool_call is not available. "
|
|
424
|
+
"You are likely encountering this error because you defined only the async version "
|
|
425
|
+
"(awrap_tool_call) and invoked your agent in a synchronous context "
|
|
426
|
+
"(e.g., using `stream()` or `invoke()`). "
|
|
427
|
+
"To resolve this, either: "
|
|
428
|
+
"(1) subclass AgentMiddleware and implement the synchronous wrap_tool_call method, "
|
|
429
|
+
"(2) use the @wrap_tool_call decorator on a standalone sync function, or "
|
|
430
|
+
"(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
|
|
431
|
+
)
|
|
432
|
+
raise NotImplementedError(msg)
|
|
433
|
+
|
|
434
|
+
async def awrap_tool_call(
|
|
435
|
+
self,
|
|
436
|
+
request: ToolCallRequest,
|
|
437
|
+
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
438
|
+
) -> ToolMessage | Command:
|
|
439
|
+
"""Intercept and control async tool execution via handler callback.
|
|
440
|
+
|
|
441
|
+
The handler callback executes the tool call and returns a ToolMessage or Command.
|
|
442
|
+
Middleware can call the handler multiple times for retry logic, skip calling
|
|
443
|
+
it to short-circuit, or modify the request/response. Multiple middleware
|
|
444
|
+
compose with first in list as outermost layer.
|
|
445
|
+
|
|
446
|
+
Args:
|
|
447
|
+
request: Tool call request with call dict, BaseTool, state, and runtime.
|
|
448
|
+
Access state via request.state and runtime via request.runtime.
|
|
449
|
+
handler: Async callable to execute the tool and returns ToolMessage or Command.
|
|
450
|
+
Call this to execute the tool. Can be called multiple times
|
|
451
|
+
for retry logic. Can skip calling it to short-circuit.
|
|
452
|
+
|
|
453
|
+
Returns:
|
|
454
|
+
ToolMessage or Command (the final result).
|
|
455
|
+
|
|
456
|
+
The handler callable can be invoked multiple times for retry logic.
|
|
457
|
+
Each call to handler is independent and stateless.
|
|
458
|
+
|
|
459
|
+
Examples:
|
|
460
|
+
Async retry on error:
|
|
461
|
+
```python
|
|
462
|
+
async def awrap_tool_call(self, request, handler):
|
|
463
|
+
for attempt in range(3):
|
|
464
|
+
try:
|
|
465
|
+
result = await handler(request)
|
|
466
|
+
if is_valid(result):
|
|
467
|
+
return result
|
|
468
|
+
except Exception:
|
|
469
|
+
if attempt == 2:
|
|
470
|
+
raise
|
|
471
|
+
return result
|
|
472
|
+
```
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
async def awrap_tool_call(self, request, handler):
|
|
476
|
+
if cached := await get_cache_async(request):
|
|
477
|
+
return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
|
|
478
|
+
result = await handler(request)
|
|
479
|
+
await save_cache_async(request, result)
|
|
480
|
+
return result
|
|
481
|
+
"""
|
|
482
|
+
msg = (
|
|
483
|
+
"Asynchronous implementation of awrap_tool_call is not available. "
|
|
484
|
+
"You are likely encountering this error because you defined only the sync version "
|
|
485
|
+
"(wrap_tool_call) and invoked your agent in an asynchronous context "
|
|
486
|
+
"(e.g., using `astream()` or `ainvoke()`). "
|
|
487
|
+
"To resolve this, either: "
|
|
488
|
+
"(1) subclass AgentMiddleware and implement the asynchronous awrap_tool_call method, "
|
|
489
|
+
"(2) use the @wrap_tool_call decorator on a standalone async function, or "
|
|
490
|
+
"(3) invoke your agent synchronously using `stream()` or `invoke()`."
|
|
491
|
+
)
|
|
492
|
+
raise NotImplementedError(msg)
|
|
357
493
|
|
|
358
494
|
|
|
359
495
|
class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
|
|
@@ -1104,6 +1240,16 @@ def dynamic_prompt(
|
|
|
1104
1240
|
request.system_prompt = prompt
|
|
1105
1241
|
return handler(request)
|
|
1106
1242
|
|
|
1243
|
+
async def async_wrapped_from_sync(
|
|
1244
|
+
self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
|
|
1245
|
+
request: ModelRequest,
|
|
1246
|
+
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
1247
|
+
) -> ModelCallResult:
|
|
1248
|
+
# Delegate to sync function
|
|
1249
|
+
prompt = cast("str", func(request))
|
|
1250
|
+
request.system_prompt = prompt
|
|
1251
|
+
return await handler(request)
|
|
1252
|
+
|
|
1107
1253
|
middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
|
|
1108
1254
|
|
|
1109
1255
|
return type(
|
|
@@ -1113,6 +1259,7 @@ def dynamic_prompt(
|
|
|
1113
1259
|
"state_schema": AgentState,
|
|
1114
1260
|
"tools": [],
|
|
1115
1261
|
"wrap_model_call": wrapped,
|
|
1262
|
+
"awrap_model_call": async_wrapped_from_sync,
|
|
1116
1263
|
},
|
|
1117
1264
|
)()
|
|
1118
1265
|
|
|
@@ -1309,6 +1456,7 @@ def wrap_tool_call(
|
|
|
1309
1456
|
Args:
|
|
1310
1457
|
func: Function accepting (request, handler) that calls
|
|
1311
1458
|
handler(request) to execute the tool and returns final ToolMessage or Command.
|
|
1459
|
+
Can be sync or async.
|
|
1312
1460
|
tools: Additional tools to register with this middleware.
|
|
1313
1461
|
name: Middleware class name. Defaults to function name.
|
|
1314
1462
|
|
|
@@ -1316,13 +1464,6 @@ def wrap_tool_call(
|
|
|
1316
1464
|
AgentMiddleware instance if func provided, otherwise a decorator.
|
|
1317
1465
|
|
|
1318
1466
|
Examples:
|
|
1319
|
-
Basic passthrough:
|
|
1320
|
-
```python
|
|
1321
|
-
@wrap_tool_call
|
|
1322
|
-
def passthrough(request, handler):
|
|
1323
|
-
return handler(request)
|
|
1324
|
-
```
|
|
1325
|
-
|
|
1326
1467
|
Retry logic:
|
|
1327
1468
|
```python
|
|
1328
1469
|
@wrap_tool_call
|
|
@@ -1336,6 +1477,18 @@ def wrap_tool_call(
|
|
|
1336
1477
|
raise
|
|
1337
1478
|
```
|
|
1338
1479
|
|
|
1480
|
+
Async retry logic:
|
|
1481
|
+
```python
|
|
1482
|
+
@wrap_tool_call
|
|
1483
|
+
async def async_retry(request, handler):
|
|
1484
|
+
for attempt in range(3):
|
|
1485
|
+
try:
|
|
1486
|
+
return await handler(request)
|
|
1487
|
+
except Exception:
|
|
1488
|
+
if attempt == 2:
|
|
1489
|
+
raise
|
|
1490
|
+
```
|
|
1491
|
+
|
|
1339
1492
|
Modify request:
|
|
1340
1493
|
```python
|
|
1341
1494
|
@wrap_tool_call
|
|
@@ -1359,6 +1512,31 @@ def wrap_tool_call(
|
|
|
1359
1512
|
def decorator(
|
|
1360
1513
|
func: _CallableReturningToolResponse,
|
|
1361
1514
|
) -> AgentMiddleware:
|
|
1515
|
+
is_async = iscoroutinefunction(func)
|
|
1516
|
+
|
|
1517
|
+
if is_async:
|
|
1518
|
+
|
|
1519
|
+
async def async_wrapped(
|
|
1520
|
+
self: AgentMiddleware, # noqa: ARG001
|
|
1521
|
+
request: ToolCallRequest,
|
|
1522
|
+
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
|
1523
|
+
) -> ToolMessage | Command:
|
|
1524
|
+
return await func(request, handler) # type: ignore[arg-type,misc]
|
|
1525
|
+
|
|
1526
|
+
middleware_name = name or cast(
|
|
1527
|
+
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
|
|
1528
|
+
)
|
|
1529
|
+
|
|
1530
|
+
return type(
|
|
1531
|
+
middleware_name,
|
|
1532
|
+
(AgentMiddleware,),
|
|
1533
|
+
{
|
|
1534
|
+
"state_schema": AgentState,
|
|
1535
|
+
"tools": tools or [],
|
|
1536
|
+
"awrap_tool_call": async_wrapped,
|
|
1537
|
+
},
|
|
1538
|
+
)()
|
|
1539
|
+
|
|
1362
1540
|
def wrapped(
|
|
1363
1541
|
self: AgentMiddleware, # noqa: ARG001
|
|
1364
1542
|
request: ToolCallRequest,
|
langchain/embeddings/__init__.py
CHANGED
langchain/messages/__init__.py
CHANGED
|
@@ -3,29 +3,61 @@
|
|
|
3
3
|
from langchain_core.messages import (
|
|
4
4
|
AIMessage,
|
|
5
5
|
AIMessageChunk,
|
|
6
|
+
Annotation,
|
|
6
7
|
AnyMessage,
|
|
8
|
+
AudioContentBlock,
|
|
9
|
+
Citation,
|
|
10
|
+
ContentBlock,
|
|
11
|
+
DataContentBlock,
|
|
12
|
+
FileContentBlock,
|
|
7
13
|
HumanMessage,
|
|
14
|
+
ImageContentBlock,
|
|
8
15
|
InvalidToolCall,
|
|
9
16
|
MessageLikeRepresentation,
|
|
17
|
+
NonStandardAnnotation,
|
|
18
|
+
NonStandardContentBlock,
|
|
19
|
+
PlainTextContentBlock,
|
|
20
|
+
ReasoningContentBlock,
|
|
10
21
|
RemoveMessage,
|
|
22
|
+
ServerToolCall,
|
|
23
|
+
ServerToolCallChunk,
|
|
24
|
+
ServerToolResult,
|
|
11
25
|
SystemMessage,
|
|
26
|
+
TextContentBlock,
|
|
12
27
|
ToolCall,
|
|
13
28
|
ToolCallChunk,
|
|
14
29
|
ToolMessage,
|
|
30
|
+
VideoContentBlock,
|
|
15
31
|
trim_messages,
|
|
16
32
|
)
|
|
17
33
|
|
|
18
34
|
__all__ = [
|
|
19
35
|
"AIMessage",
|
|
20
36
|
"AIMessageChunk",
|
|
37
|
+
"Annotation",
|
|
21
38
|
"AnyMessage",
|
|
39
|
+
"AudioContentBlock",
|
|
40
|
+
"Citation",
|
|
41
|
+
"ContentBlock",
|
|
42
|
+
"DataContentBlock",
|
|
43
|
+
"FileContentBlock",
|
|
22
44
|
"HumanMessage",
|
|
45
|
+
"ImageContentBlock",
|
|
23
46
|
"InvalidToolCall",
|
|
24
47
|
"MessageLikeRepresentation",
|
|
48
|
+
"NonStandardAnnotation",
|
|
49
|
+
"NonStandardContentBlock",
|
|
50
|
+
"PlainTextContentBlock",
|
|
51
|
+
"ReasoningContentBlock",
|
|
25
52
|
"RemoveMessage",
|
|
53
|
+
"ServerToolCall",
|
|
54
|
+
"ServerToolCallChunk",
|
|
55
|
+
"ServerToolResult",
|
|
26
56
|
"SystemMessage",
|
|
57
|
+
"TextContentBlock",
|
|
27
58
|
"ToolCall",
|
|
28
59
|
"ToolCallChunk",
|
|
29
60
|
"ToolMessage",
|
|
61
|
+
"VideoContentBlock",
|
|
30
62
|
"trim_messages",
|
|
31
63
|
]
|
langchain/tools/__init__.py
CHANGED
|
@@ -8,11 +8,7 @@ from langchain_core.tools import (
|
|
|
8
8
|
tool,
|
|
9
9
|
)
|
|
10
10
|
|
|
11
|
-
from langchain.tools.tool_node import
|
|
12
|
-
InjectedState,
|
|
13
|
-
InjectedStore,
|
|
14
|
-
ToolNode,
|
|
15
|
-
)
|
|
11
|
+
from langchain.tools.tool_node import InjectedState, InjectedStore
|
|
16
12
|
|
|
17
13
|
__all__ = [
|
|
18
14
|
"BaseTool",
|
|
@@ -21,6 +17,5 @@ __all__ = [
|
|
|
21
17
|
"InjectedToolArg",
|
|
22
18
|
"InjectedToolCallId",
|
|
23
19
|
"ToolException",
|
|
24
|
-
"ToolNode",
|
|
25
20
|
"tool",
|
|
26
21
|
]
|