langchain-1.2.3-py3-none-any.whl → langchain-1.2.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +55 -40
- langchain/agents/middleware/__init__.py +15 -18
- langchain/agents/middleware/_execution.py +8 -12
- langchain/agents/middleware/_redaction.py +81 -10
- langchain/agents/middleware/context_editing.py +21 -3
- langchain/agents/middleware/file_search.py +1 -1
- langchain/agents/middleware/human_in_the_loop.py +31 -7
- langchain/agents/middleware/model_call_limit.py +1 -1
- langchain/agents/middleware/model_retry.py +8 -1
- langchain/agents/middleware/pii.py +4 -4
- langchain/agents/middleware/shell_tool.py +26 -6
- langchain/agents/middleware/summarization.py +35 -10
- langchain/agents/middleware/todo.py +30 -16
- langchain/agents/middleware/tool_emulator.py +5 -5
- langchain/agents/middleware/tool_retry.py +15 -8
- langchain/agents/middleware/tool_selection.py +45 -11
- langchain/agents/middleware/types.py +110 -43
- langchain/agents/structured_output.py +43 -30
- langchain/chat_models/base.py +25 -17
- {langchain-1.2.3.dist-info → langchain-1.2.4.dist-info}/METADATA +3 -3
- langchain-1.2.4.dist-info/RECORD +36 -0
- langchain-1.2.3.dist-info/RECORD +0 -36
- {langchain-1.2.3.dist-info → langchain-1.2.4.dist-info}/WHEEL +0 -0
- {langchain-1.2.3.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
langchain/agents/middleware/human_in_the_loop.py

```diff
@@ -102,7 +102,9 @@ class HITLResponse(TypedDict):
 class _DescriptionFactory(Protocol):
     """Callable that generates a description for a tool call."""

-    def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str:
+    def __call__(
+        self, tool_call: ToolCall, state: AgentState[Any], runtime: Runtime[ContextT]
+    ) -> str:
         """Generate a description for a tool call."""
         ...

@@ -203,7 +205,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
         self,
         tool_call: ToolCall,
         config: InterruptOnConfig,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime[ContextT],
     ) -> tuple[ActionRequest, ReviewConfig]:
         """Create an ActionRequest and ReviewConfig for a tool call."""
@@ -235,8 +237,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):

         return action_request, review_config

+    @staticmethod
     def _process_decision(
-        self,
         decision: Decision,
         tool_call: ToolCall,
         config: InterruptOnConfig,
@@ -277,8 +279,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
         )
         raise ValueError(msg)

-    def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
-        """Trigger interrupt flows for relevant tool calls after an `AIMessage`."""
+    def after_model(
+        self, state: AgentState[Any], runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
+        """Trigger interrupt flows for relevant tool calls after an `AIMessage`.
+
+        Args:
+            state: The current agent state.
+            runtime: The runtime context.
+
+        Returns:
+            Updated message with the revised tool calls.
+
+        Raises:
+            ValueError: If the number of human decisions does not match the number of
+                interrupted tool calls.
+        """
         messages = state["messages"]
         if not messages:
             return None
@@ -351,7 +367,15 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
         return {"messages": [last_ai_msg, *artificial_tool_messages]}

     async def aafter_model(
-        self, state: AgentState, runtime: Runtime[ContextT]
+        self, state: AgentState[Any], runtime: Runtime[ContextT]
     ) -> dict[str, Any] | None:
-        """Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
+        """Async trigger interrupt flows for relevant tool calls after an `AIMessage`.
+
+        Args:
+            state: The current agent state.
+            runtime: The runtime context.
+
+        Returns:
+            Updated message with the revised tool calls.
+        """
         return self.after_model(state, runtime)
```
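The `_DescriptionFactory` protocol above now passes a generically typed `AgentState[Any]` to description callables. A minimal sketch of a conforming callable (the tool name, the `interrupt_on` wiring, and the import paths are illustrative assumptions, not taken from this diff):

```python
from typing import Any

from langchain.agents.middleware import HumanInTheLoopMiddleware
from langchain.agents.middleware.types import AgentState
from langchain_core.messages import ToolCall
from langgraph.runtime import Runtime


def describe_tool_call(
    tool_call: ToolCall, state: AgentState[Any], runtime: Runtime[Any]
) -> str:
    # Render a human-readable approval prompt for the pending call.
    args = ", ".join(f"{k}={v!r}" for k, v in tool_call["args"].items())
    return f"Agent wants to run {tool_call['name']}({args})"


# Hypothetical wiring: interrupt on a sensitive tool, using the callable
# above as the description of its InterruptOnConfig.
hitl = HumanInTheLoopMiddleware(
    interrupt_on={
        "delete_file": {
            "allowed_decisions": ["approve", "reject"],
            "description": describe_tool_call,
        }
    }
)
```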
langchain/agents/middleware/model_call_limit.py

```diff
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
     from langgraph.runtime import Runtime


-class ModelCallLimitState(AgentState):
+class ModelCallLimitState(AgentState[Any]):
     """State schema for `ModelCallLimitMiddleware`.

     Extends `AgentState` with model call tracking fields.
```
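`AgentState` is now parameterized wherever it is subclassed, as in `ModelCallLimitState` above. A sketch of the same pattern for a custom state schema (the class and field names are invented; `OmitFromInput` mirrors how `PlanningState` annotates `todos` later in this diff):

```python
from typing import Annotated, Any

from langchain.agents.middleware import AgentState
from langchain.agents.middleware.types import OmitFromInput
from typing_extensions import NotRequired


class AuditState(AgentState[Any]):
    """Hypothetical extension that tracks per-run call counts."""

    # NotRequired: the key may be absent until a middleware first writes it.
    # OmitFromInput keeps the field out of the graph's input schema.
    audit_model_calls: Annotated[NotRequired[int], OmitFromInput]
```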
langchain/agents/middleware/model_retry.py

```diff
@@ -163,7 +163,8 @@ class ModelRetryMiddleware(AgentMiddleware):
         self.max_delay = max_delay
         self.jitter = jitter

-    def _format_failure_message(self, exc: Exception, attempts_made: int) -> AIMessage:
+    @staticmethod
+    def _format_failure_message(exc: Exception, attempts_made: int) -> AIMessage:
         """Format the failure message when retries are exhausted.

         Args:
@@ -218,6 +219,9 @@ class ModelRetryMiddleware(AgentMiddleware):

         Returns:
             `ModelResponse` or `AIMessage` (the final result).
+
+        Raises:
+            RuntimeError: If the retry loop completes without returning. (This should not happen.)
         """
         # Initial attempt + retries
         for attempt in range(self.max_retries + 1):
@@ -265,6 +269,9 @@ class ModelRetryMiddleware(AgentMiddleware):

         Returns:
             `ModelResponse` or `AIMessage` (the final result).
+
+        Raises:
+            RuntimeError: If the retry loop completes without returning. (This should not happen.)
         """
         # Initial attempt + retries
         for attempt in range(self.max_retries + 1):
```
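Both retry loops above run one initial attempt plus `max_retries` retries, and the new `Raises:` sections document the defensive `RuntimeError` after the loop. A standalone sketch of that control flow (parameter names other than `max_retries`, `max_delay`, and `jitter` are assumptions):

```python
import random
import time
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


def call_with_retry(
    func: Callable[[], T],
    *,
    max_retries: int = 2,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    jitter: bool = True,
) -> T:
    # Initial attempt + retries, mirroring `range(self.max_retries + 1)` above.
    for attempt in range(max_retries + 1):
        try:
            return func()
        except Exception:
            if attempt == max_retries:
                raise  # exhausted; the middleware formats a failure AIMessage instead
            delay = min(initial_delay * 2**attempt, max_delay)
            if jitter:
                delay += random.uniform(0, delay * 0.1)
            time.sleep(delay)
    # Documented as unreachable, matching the docstrings added above.
    msg = "Retry loop completed without returning"
    raise RuntimeError(msg)
```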
langchain/agents/middleware/pii.py

```diff
@@ -164,7 +164,7 @@ class PIIMiddleware(AgentMiddleware):
     @override
     def before_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Check user messages and tool results for PII before model invocation.
@@ -259,7 +259,7 @@ class PIIMiddleware(AgentMiddleware):
     @hook_config(can_jump_to=["end"])
     async def abefore_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Async check user messages and tool results for PII before model invocation.
@@ -280,7 +280,7 @@ class PIIMiddleware(AgentMiddleware):
     @override
     def after_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Check AI messages for PII after model invocation.
@@ -339,7 +339,7 @@ class PIIMiddleware(AgentMiddleware):

     async def aafter_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Async check AI messages for PII after model invocation.
```
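For orientation, the `@hook_config(can_jump_to=["end"])` decorator visible above declares that a hook may short-circuit the run. A sketch of a custom hook using the same decorator (assuming `hook_config` is importable from `langchain.agents.middleware` and that returning `{"jump_to": "end"}` performs the jump, as this package's hooks do):

```python
from typing import Any

from langchain.agents.middleware import AgentMiddleware, AgentState, hook_config
from langgraph.runtime import Runtime


class StopWordMiddleware(AgentMiddleware):
    """Hypothetical middleware that ends the run on a stop keyword."""

    @hook_config(can_jump_to=["end"])
    def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
        messages = state["messages"]
        last = messages[-1] if messages else None
        if last is not None and "##stop##" in str(getattr(last, "content", "")):
            return {"jump_to": "end"}  # skip the model call and finish
        return None
```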
langchain/agents/middleware/shell_tool.py

```diff
@@ -78,7 +78,7 @@ class _SessionResources:
     session: ShellSession
     tempdir: tempfile.TemporaryDirectory[str] | None
     policy: BaseExecutionPolicy
-    finalizer: weakref.finalize = field(init=False, repr=False)
+    finalizer: weakref.finalize = field(init=False, repr=False)  # type: ignore[type-arg]

     def __post_init__(self) -> None:
         self.finalizer = weakref.finalize(
@@ -90,7 +90,7 @@ class _SessionResources:
         )


-class ShellToolState(AgentState):
+class ShellToolState(AgentState[Any]):
     """Agent state extension for tracking shell session resources."""

     shell_session_resources: NotRequired[
@@ -134,7 +134,11 @@ class ShellSession:
         self._terminated = False

     def start(self) -> None:
-        """Start the shell subprocess and reader threads."""
+        """Start the shell subprocess and reader threads.
+
+        Raises:
+            RuntimeError: If the shell session pipes cannot be initialized.
+        """
         if self._process and self._process.poll() is None:
             return

@@ -604,19 +608,35 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
         normalized: dict[str, str] = {}
         for key, value in env.items():
             if not isinstance(key, str):
-                msg = "Environment variable names must be strings."
+                msg = "Environment variable names must be strings."  # type: ignore[unreachable]
                 raise TypeError(msg)
             normalized[key] = str(value)
         return normalized

     @override
     def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
-        """Start the shell session and run startup commands."""
+        """Start the shell session and run startup commands.
+
+        Args:
+            state: The current agent state.
+            runtime: The runtime context.
+
+        Returns:
+            Shell session resources to be stored in the agent state.
+        """
         resources = self._get_or_create_resources(state)
         return {"shell_session_resources": resources}

     async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
-        """Async start the shell session and run startup commands."""
+        """Async start the shell session and run startup commands.
+
+        Args:
+            state: The current agent state.
+            runtime: The runtime context.
+
+        Returns:
+            Shell session resources to be stored in the agent state.
+        """
         return self.before_agent(state, runtime)

     @override
```
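The `# type: ignore[type-arg]` on `finalizer` reflects that `weakref.finalize` is generic in recent type stubs. The pattern itself, a dataclass that registers its own cleanup in `__post_init__`, can be sketched in isolation:

```python
import weakref
from dataclasses import dataclass, field


@dataclass
class Session:
    """Minimal sketch of the _SessionResources finalizer pattern."""

    name: str
    finalizer: weakref.finalize = field(init=False, repr=False)  # type: ignore[type-arg]

    def __post_init__(self) -> None:
        # The callback must not hold a reference to `self`, or the object
        # would never be collected; pass plain values instead.
        self.finalizer = weakref.finalize(self, print, f"cleaned up {self.name}")


s = Session("shell-1")
del s  # the finalizer fires here, or at interpreter shutdown at the latest
```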
langchain/agents/middleware/summarization.py

```diff
@@ -269,8 +269,16 @@ class SummarizationMiddleware(AgentMiddleware):
             raise ValueError(msg)

     @override
-    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
-        """Process messages before model invocation, potentially triggering summarization."""
+    def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
+        """Process messages before model invocation, potentially triggering summarization.
+
+        Args:
+            state: The agent state.
+            runtime: The runtime environment.
+
+        Returns:
+            An updated state with summarized messages if summarization was performed.
+        """
         messages = state["messages"]
         self._ensure_message_ids(messages)

@@ -297,8 +305,18 @@ class SummarizationMiddleware(AgentMiddleware):
         }

     @override
-    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
-        """Process messages before model invocation, potentially triggering summarization."""
+    async def abefore_model(
+        self, state: AgentState[Any], runtime: Runtime
+    ) -> dict[str, Any] | None:
+        """Process messages before model invocation, potentially triggering summarization.
+
+        Args:
+            state: The agent state.
+            runtime: The runtime environment.
+
+        Returns:
+            An updated state with summarized messages if summarization was performed.
+        """
         messages = state["messages"]
         self._ensure_message_ids(messages)

@@ -449,7 +467,8 @@ class SummarizationMiddleware(AgentMiddleware):

         return max_input_tokens

-    def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
+    @staticmethod
+    def _validate_context_size(context: ContextSize, parameter_name: str) -> ContextSize:
         """Validate context configuration tuples."""
         kind, value = context
         if kind == "fraction":
@@ -465,19 +484,24 @@ class SummarizationMiddleware(AgentMiddleware):
             raise ValueError(msg)
         return context

-    def _build_new_messages(self, summary: str) -> list[HumanMessage]:
+    @staticmethod
+    def _build_new_messages(summary: str) -> list[HumanMessage]:
         return [
-            HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
+            HumanMessage(
+                content=f"Here is a summary of the conversation to date:\n\n{summary}",
+                additional_kwargs={"lc_source": "summarization"},
+            )
         ]

-    def _ensure_message_ids(self, messages: list[AnyMessage]) -> None:
+    @staticmethod
+    def _ensure_message_ids(messages: list[AnyMessage]) -> None:
         """Ensure all messages have unique IDs for the add_messages reducer."""
         for msg in messages:
             if msg.id is None:
                 msg.id = str(uuid.uuid4())

+    @staticmethod
     def _partition_messages(
-        self,
         conversation_messages: list[AnyMessage],
         cutoff_index: int,
     ) -> tuple[list[AnyMessage], list[AnyMessage]]:
@@ -502,7 +526,8 @@ class SummarizationMiddleware(AgentMiddleware):
         target_cutoff = len(messages) - messages_to_keep
         return self._find_safe_cutoff_point(messages, target_cutoff)

-    def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
+    @staticmethod
+    def _find_safe_cutoff_point(messages: list[AnyMessage], cutoff_index: int) -> int:
        """Find a safe cutoff point that doesn't split AI/Tool message pairs.

        If the message at `cutoff_index` is a `ToolMessage`, search backward for the
```
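`_build_new_messages` now tags the injected summary with `additional_kwargs={"lc_source": "summarization"}`. One thing that marker enables, sketched here as a hypothetical helper, is telling injected summaries apart from real user messages downstream:

```python
from langchain_core.messages import AnyMessage, HumanMessage


def is_injected_summary(msg: AnyMessage) -> bool:
    # Key and value taken from the hunk above; the helper itself is illustrative.
    return (
        isinstance(msg, HumanMessage)
        and msg.additional_kwargs.get("lc_source") == "summarization"
    )


summary = HumanMessage(
    content="Here is a summary of the conversation to date:\n\n...",
    additional_kwargs={"lc_source": "summarization"},
)
assert is_injected_summary(summary)
assert not is_injected_summary(HumanMessage(content="hi"))
```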
langchain/agents/middleware/todo.py

```diff
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
     from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
     from langchain_core.tools import tool
     from langgraph.types import Command
-    from typing_extensions import NotRequired, TypedDict
+    from typing_extensions import NotRequired, TypedDict, override

 from langchain.agents.middleware.types import (
     AgentMiddleware,
@@ -35,7 +35,7 @@ class Todo(TypedDict):
     """The current status of the todo item."""


-class PlanningState(AgentState):
+class PlanningState(AgentState[Any]):
     """State schema for the todo middleware."""

     todos: Annotated[NotRequired[list[Todo]], OmitFromInput]
@@ -118,7 +118,9 @@ Writing todos takes time and tokens, use it when it is helpful for managing comp


 @tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
-def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
+def write_todos(
+    todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
+) -> Command[Any]:
     """Create and manage a structured task list for your current work session."""
     return Command(
         update={
@@ -178,7 +180,7 @@ class TodoListMiddleware(AgentMiddleware):
         @tool(description=self.tool_description)
         def write_todos(
             todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
-        ) -> Command:
+        ) -> Command[Any]:
             """Create and manage a structured task list for your current work session."""
             return Command(
                 update={
@@ -196,7 +198,16 @@ class TodoListMiddleware(AgentMiddleware):
         request: ModelRequest,
         handler: Callable[[ModelRequest], ModelResponse],
     ) -> ModelCallResult:
-        """Update the system message to include the todo system prompt."""
+        """Update the system message to include the todo system prompt.
+
+        Args:
+            request: Model request to execute (includes state and runtime).
+            handler: Async callback that executes the model request and returns
+                `ModelResponse`.
+
+        Returns:
+            The model call result.
+        """
         if request.system_message is not None:
             new_system_content = [
                 *request.system_message.content_blocks,
@@ -214,7 +225,16 @@ class TodoListMiddleware(AgentMiddleware):
         request: ModelRequest,
         handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
     ) -> ModelCallResult:
-        """Update the system message to include the todo system prompt"""
+        """Update the system message to include the todo system prompt.
+
+        Args:
+            request: Model request to execute (includes state and runtime).
+            handler: Async callback that executes the model request and returns
+                `ModelResponse`.
+
+        Returns:
+            The model call result.
+        """
         if request.system_message is not None:
             new_system_content = [
                 *request.system_message.content_blocks,
@@ -227,11 +247,8 @@ class TodoListMiddleware(AgentMiddleware):
             )
         return await handler(request.override(system_message=new_system_message))

-    def after_model(
-        self,
-        state: AgentState,
-        runtime: Runtime,  # noqa: ARG002
-    ) -> dict[str, Any] | None:
+    @override
+    def after_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
         """Check for parallel write_todos tool calls and return errors if detected.

         The todo list is designed to be updated at most once per model turn. Since
@@ -280,11 +297,8 @@ class TodoListMiddleware(AgentMiddleware):

         return None

-    async def aafter_model(
-        self,
-        state: AgentState,
-        runtime: Runtime,
-    ) -> dict[str, Any] | None:
+    @override
+    async def aafter_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
         """Check for parallel write_todos tool calls and return errors if detected.

         Async version of `after_model`. The todo list is designed to be updated at
```
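`write_todos` illustrates the state-updating tool pattern: return a `Command[Any]` whose `update` both writes a state key and resolves the pending tool call with a `ToolMessage`. A self-contained sketch of the same shape (tool name and state key are invented):

```python
from typing import Annotated, Any

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool
from langgraph.types import Command


@tool
def set_notes(
    notes: list[str], tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command[Any]:
    """Replace the notes list for the current session."""
    return Command(
        update={
            "notes": notes,
            # A ToolMessage is required so the tool call is marked resolved.
            "messages": [
                ToolMessage(f"Saved {len(notes)} notes", tool_call_id=tool_call_id)
            ],
        }
    )
```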
langchain/agents/middleware/tool_emulator.py

```diff
@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage, ToolMessage
@@ -109,8 +109,8 @@ class LLMToolEmulator(AgentMiddleware):
     def wrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], ToolMessage | Command],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+    ) -> ToolMessage | Command[Any]:
         """Emulate tool execution using LLM if tool should be emulated.

         Args:
@@ -159,8 +159,8 @@ class LLMToolEmulator(AgentMiddleware):
     async def awrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
+    ) -> ToolMessage | Command[Any]:
         """Async version of `wrap_tool_call`.

         Emulate tool execution using LLM if tool should be emulated.
```
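`wrap_tool_call`/`awrap_tool_call` receive the request plus a handler and may either delegate or return a result directly. A sketch of a short-circuiting middleware built on that contract (class name and blocklist behavior are hypothetical; `ToolCallRequest` is assumed importable from `langchain.agents.middleware.types`, as in this module):

```python
from collections.abc import Callable
from typing import Any

from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ToolCallRequest
from langchain_core.messages import ToolMessage
from langgraph.types import Command


class ToolBlocklistMiddleware(AgentMiddleware):
    """Hypothetical middleware that refuses blocked tools instead of running them."""

    def __init__(self, blocked: set[str]) -> None:
        super().__init__()
        self.blocked = blocked

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
    ) -> ToolMessage | Command[Any]:
        name = request.tool_call["name"]
        if name in self.blocked:
            # Short-circuit: synthesize a result without invoking the tool.
            return ToolMessage(
                f"Tool '{name}' is disabled.",
                tool_call_id=request.tool_call["id"] or "",
            )
        return handler(request)
```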
langchain/agents/middleware/tool_retry.py

```diff
@@ -5,7 +5,7 @@ from __future__ import annotations
 import asyncio
 import time
 import warnings
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from langchain_core.messages import ToolMessage

@@ -189,14 +189,14 @@ class ToolRetryMiddleware(AgentMiddleware):

         # Handle backwards compatibility for deprecated on_failure values
         if on_failure == "raise":  # type: ignore[comparison-overlap]
-            msg = (
+            msg = (  # type: ignore[unreachable]
                 "on_failure='raise' is deprecated and will be removed in a future version. "
                 "Use on_failure='error' instead."
             )
             warnings.warn(msg, DeprecationWarning, stacklevel=2)
             on_failure = "error"
         elif on_failure == "return_message":  # type: ignore[comparison-overlap]
-            msg = (
+            msg = (  # type: ignore[unreachable]
                 "on_failure='return_message' is deprecated and will be removed "
                 "in a future version. Use on_failure='continue' instead."
             )
@@ -233,7 +233,8 @@ class ToolRetryMiddleware(AgentMiddleware):
             return True
         return tool_name in self._tool_filter

-    def _format_failure_message(self, tool_name: str, exc: Exception, attempts_made: int) -> str:
+    @staticmethod
+    def _format_failure_message(tool_name: str, exc: Exception, attempts_made: int) -> str:
         """Format the failure message when retries are exhausted.

         Args:
@@ -287,8 +288,8 @@ class ToolRetryMiddleware(AgentMiddleware):
     def wrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], ToolMessage | Command],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+    ) -> ToolMessage | Command[Any]:
         """Intercept tool execution and retry on failure.

         Args:
@@ -297,6 +298,9 @@ class ToolRetryMiddleware(AgentMiddleware):

         Returns:
             `ToolMessage` or `Command` (the final result).
+
+        Raises:
+            RuntimeError: If the retry loop completes without returning. This should not happen.
         """
         tool_name = request.tool.name if request.tool else request.tool_call["name"]

@@ -342,8 +346,8 @@ class ToolRetryMiddleware(AgentMiddleware):
     async def awrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
+    ) -> ToolMessage | Command[Any]:
         """Intercept and control async tool execution with retry logic.

         Args:
@@ -353,6 +357,9 @@ class ToolRetryMiddleware(AgentMiddleware):

         Returns:
             `ToolMessage` or `Command` (the final result).
+
+        Raises:
+            RuntimeError: If the retry loop completes without returning. This should not happen.
         """
         tool_name = request.tool.name if request.tool else request.tool_call["name"]

```
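The deprecation shims above map old `on_failure` values to the new ones at construction time. In user code the migration is mechanical (sketch showing only the `on_failure` argument; the class is assumed to be exported from `langchain.agents.middleware` like the package's other middleware, and other constructor parameters are omitted):

```python
from langchain.agents.middleware import ToolRetryMiddleware

# Before (now emits DeprecationWarning):
#   ToolRetryMiddleware(on_failure="raise")
#   ToolRetryMiddleware(on_failure="return_message")
# After:
retry_then_error = ToolRetryMiddleware(on_failure="error")        # error out when exhausted
retry_then_continue = ToolRetryMiddleware(on_failure="continue")  # emit failure ToolMessage
```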
langchain/agents/middleware/tool_selection.py

```diff
@@ -4,12 +4,7 @@ from __future__ import annotations

 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Annotated, Literal, Union
-
-if TYPE_CHECKING:
-    from collections.abc import Awaitable, Callable
-
-    from langchain.tools import BaseTool
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Union

 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage
@@ -24,6 +19,11 @@ from langchain.agents.middleware.types import (
 )
 from langchain.chat_models.base import init_chat_model

+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+    from langchain.tools import BaseTool
+
 logger = logging.getLogger(__name__)

 DEFAULT_SYSTEM_PROMPT = (
@@ -42,7 +42,7 @@ class _SelectionRequest:
     valid_tool_names: list[str]


-def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
+def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
     """Create a structured output schema for tool selection.

     Args:
@@ -51,6 +51,9 @@ def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
     Returns:
         `TypeAdapter` for a schema where each tool name is a `Literal` with its
         description.
+
+    Raises:
+        AssertionError: If `tools` is empty.
     """
     if not tools:
         msg = "Invalid usage: tools must be non-empty"
@@ -153,9 +156,16 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
     def _prepare_selection_request(self, request: ModelRequest) -> _SelectionRequest | None:
         """Prepare inputs for tool selection.

+        Args:
+            request: the model request.
+
         Returns:
             `SelectionRequest` with prepared inputs, or `None` if no selection is
-                needed.
+                needed.
+
+        Raises:
+            ValueError: If tools in `always_include` are not found in the request.
+            AssertionError: If no user message is found in the request messages.
         """
         # If no tools available, return None
         if not request.tools or len(request.tools) == 0:
@@ -217,7 +227,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):

     def _process_selection_response(
         self,
-        response: dict,
+        response: dict[str, Any],
         available_tools: list[BaseTool],
         valid_tool_names: list[str],
         request: ModelRequest,
@@ -262,7 +272,19 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         request: ModelRequest,
         handler: Callable[[ModelRequest], ModelResponse],
     ) -> ModelCallResult:
-        """Filter tools based on LLM selection before invoking the model via handler."""
+        """Filter tools based on LLM selection before invoking the model via handler.
+
+        Args:
+            request: Model request to execute (includes state and runtime).
+            handler: Async callback that executes the model request and returns
+                `ModelResponse`.
+
+        Returns:
+            The model call result.
+
+        Raises:
+            AssertionError: If the selection model response is not a dict.
+        """
         selection_request = self._prepare_selection_request(request)
         if selection_request is None:
             return handler(request)
@@ -293,7 +315,19 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         request: ModelRequest,
         handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
     ) -> ModelCallResult:
-        """Filter tools based on LLM selection before invoking the model via handler."""
+        """Filter tools based on LLM selection before invoking the model via handler.
+
+        Args:
+            request: Model request to execute (includes state and runtime).
+            handler: Async callback that executes the model request and returns
+                `ModelResponse`.
+
+        Returns:
+            The model call result.
+
+        Raises:
+            AssertionError: If the selection model response is not a dict.
+        """
         selection_request = self._prepare_selection_request(request)
         if selection_request is None:
             return await handler(request)
```
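`_create_tool_selection_response` returns a `TypeAdapter` over a schema in which each tool name is a `Literal` carrying its description. A sketch of that construction with pydantic (a hypothetical helper working from plain name/description pairs rather than `BaseTool` objects):

```python
from typing import Annotated, Any, Literal, Union

from pydantic import Field, TypeAdapter


def selection_schema(tool_descriptions: dict[str, str]) -> TypeAdapter[Any]:
    # One Literal per tool, annotated with its description so the model
    # sees what each choice means in the generated JSON schema.
    members = tuple(
        Annotated[Literal[name], Field(description=desc)]
        for name, desc in tool_descriptions.items()
    )
    # The model returns a list of selected tool names, constrained to known tools.
    return TypeAdapter(list[Union[members]])


adapter = selection_schema({"search": "Query the web", "calculator": "Do arithmetic"})
print(adapter.json_schema())
```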