langchain 1.2.3__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
@@ -102,7 +102,9 @@ class HITLResponse(TypedDict):
 class _DescriptionFactory(Protocol):
     """Callable that generates a description for a tool call."""
 
-    def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str:
+    def __call__(
+        self, tool_call: ToolCall, state: AgentState[Any], runtime: Runtime[ContextT]
+    ) -> str:
         """Generate a description for a tool call."""
         ...
 
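The updated protocol threads `AgentState[Any]` through the description factory. As a rough, self-contained sketch (with `dict`/`Any` stand-ins instead of the real langchain imports), a conforming callable just needs the three positional parameters; `ToolCall` is a TypedDict carrying the call's name and arguments:

from typing import Any


# Hypothetical factory matching the updated __call__ signature; the
# parameter types are illustrative stand-ins, not the package's types.
def describe_tool_call(tool_call: dict[str, Any], state: dict[str, Any], runtime: Any) -> str:
    return f"Tool {tool_call['name']} invoked with {tool_call['args']}"


print(describe_tool_call({"name": "search", "args": {"q": "weather"}}, {"messages": []}, None))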
@@ -203,7 +205,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
         self,
         tool_call: ToolCall,
         config: InterruptOnConfig,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime[ContextT],
     ) -> tuple[ActionRequest, ReviewConfig]:
         """Create an ActionRequest and ReviewConfig for a tool call."""
@@ -235,8 +237,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
 
         return action_request, review_config
 
+    @staticmethod
     def _process_decision(
-        self,
         decision: Decision,
         tool_call: ToolCall,
         config: InterruptOnConfig,
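This release converts several private helpers (here `_process_decision`) to `@staticmethod` after dropping `self`. A minimal sketch of why existing call sites keep working unchanged:

class Example:
    @staticmethod
    def _process(value: int) -> int:
        # No `self`: the helper reads no instance state, so it can be a
        # staticmethod; `self._process(...)` still resolves through the class.
        return value * 2


ex = Example()
assert ex._process(21) == Example._process(21) == 42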
@@ -277,8 +279,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
             )
             raise ValueError(msg)
 
-    def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
-        """Trigger interrupt flows for relevant tool calls after an `AIMessage`."""
+    def after_model(
+        self, state: AgentState[Any], runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
+        """Trigger interrupt flows for relevant tool calls after an `AIMessage`.
+
+        Args:
+            state: The current agent state.
+            runtime: The runtime context.
+
+        Returns:
+            Updated message with the revised tool calls.
+
+        Raises:
+            ValueError: If the number of human decisions does not match the number of
+                interrupted tool calls.
+        """
         messages = state["messages"]
         if not messages:
             return None
@@ -351,7 +367,15 @@ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
         return {"messages": [last_ai_msg, *artificial_tool_messages]}
 
     async def aafter_model(
-        self, state: AgentState, runtime: Runtime[ContextT]
+        self, state: AgentState[Any], runtime: Runtime[ContextT]
     ) -> dict[str, Any] | None:
-        """Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
+        """Async trigger interrupt flows for relevant tool calls after an `AIMessage`.
+
+        Args:
+            state: The current agent state.
+            runtime: The runtime context.
+
+        Returns:
+            Updated message with the revised tool calls.
+        """
         return self.after_model(state, runtime)
@@ -19,7 +19,7 @@ if TYPE_CHECKING:
     from langgraph.runtime import Runtime
 
 
-class ModelCallLimitState(AgentState):
+class ModelCallLimitState(AgentState[Any]):
     """State schema for `ModelCallLimitMiddleware`.
 
     Extends `AgentState` with model call tracking fields.
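State schemas such as `ModelCallLimitState` now subscript the base class, which implies `AgentState` is generic (presumably over the runtime context type). A hedged sketch of the pattern using a generic TypedDict stand-in rather than the real `AgentState`:

from typing import Any, Generic, TypeVar

from typing_extensions import NotRequired, TypedDict

ContextT = TypeVar("ContextT")


# Illustrative stand-in for the generic AgentState base.
class StateBase(TypedDict, Generic[ContextT]):
    messages: list[Any]


# Subclasses must now pin the parameter, which is what `AgentState[Any]` does.
class CallLimitState(StateBase[Any]):
    model_call_count: NotRequired[int]  # hypothetical tracking field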
@@ -163,7 +163,8 @@ class ModelRetryMiddleware(AgentMiddleware):
         self.max_delay = max_delay
         self.jitter = jitter
 
-    def _format_failure_message(self, exc: Exception, attempts_made: int) -> AIMessage:
+    @staticmethod
+    def _format_failure_message(exc: Exception, attempts_made: int) -> AIMessage:
         """Format the failure message when retries are exhausted.
 
         Args:
@@ -218,6 +219,9 @@ class ModelRetryMiddleware(AgentMiddleware):
 
         Returns:
             `ModelResponse` or `AIMessage` (the final result).
+
+        Raises:
+            RuntimeError: If the retry loop completes without returning. (This should not happen.)
         """
         # Initial attempt + retries
         for attempt in range(self.max_retries + 1):
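Both retry hooks now document a defensive `RuntimeError`. A minimal sketch (not the package's code) of the documented loop shape, where the final raise is unreachable unless the loop logic regresses:

import random
import time


def call_with_retries(fn, max_retries: int = 2, base_delay: float = 0.5):
    # Initial attempt + retries, with exponential backoff and jitter.
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_retries:
                raise
            time.sleep(base_delay * 2**attempt * (1 + random.random()))
    # The loop always returns or raises; reaching this line would be a bug,
    # which is what the documented RuntimeError guards against.
    raise RuntimeError("Retry loop exited without returning")


print(call_with_retries(lambda: "ok"))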
@@ -265,6 +269,9 @@ class ModelRetryMiddleware(AgentMiddleware):
 
         Returns:
             `ModelResponse` or `AIMessage` (the final result).
+
+        Raises:
+            RuntimeError: If the retry loop completes without returning. (This should not happen.)
         """
         # Initial attempt + retries
         for attempt in range(self.max_retries + 1):
@@ -164,7 +164,7 @@ class PIIMiddleware(AgentMiddleware):
     @override
     def before_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Check user messages and tool results for PII before model invocation.
@@ -259,7 +259,7 @@ class PIIMiddleware(AgentMiddleware):
     @hook_config(can_jump_to=["end"])
     async def abefore_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Async check user messages and tool results for PII before model invocation.
@@ -280,7 +280,7 @@ class PIIMiddleware(AgentMiddleware):
     @override
     def after_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Check AI messages for PII after model invocation.
@@ -339,7 +339,7 @@ class PIIMiddleware(AgentMiddleware):
 
     async def aafter_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Async check AI messages for PII after model invocation.
@@ -78,7 +78,7 @@ class _SessionResources:
     session: ShellSession
    tempdir: tempfile.TemporaryDirectory[str] | None
    policy: BaseExecutionPolicy
-    finalizer: weakref.finalize = field(init=False, repr=False)
+    finalizer: weakref.finalize = field(init=False, repr=False)  # type: ignore[type-arg]
 
     def __post_init__(self) -> None:
         self.finalizer = weakref.finalize(
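The new `# type: ignore[type-arg]` is most likely needed because recent typeshed stubs make `weakref.finalize` generic, while the field annotation leaves it bare. A self-contained sketch of the same finalizer pattern:

import tempfile
import weakref
from dataclasses import dataclass, field


@dataclass
class Resources:
    tempdir: tempfile.TemporaryDirectory  # illustrative managed resource
    finalizer: weakref.finalize = field(init=False, repr=False)  # type: ignore[type-arg]

    def __post_init__(self) -> None:
        # Cleanup runs when the object is garbage-collected, or earlier if
        # the finalizer is invoked explicitly.
        self.finalizer = weakref.finalize(self, self.tempdir.cleanup)


res = Resources(tempfile.TemporaryDirectory())
res.finalizer()  # explicit, idempotent cleanup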
@@ -90,7 +90,7 @@ class _SessionResources:
         )
 
 
-class ShellToolState(AgentState):
+class ShellToolState(AgentState[Any]):
     """Agent state extension for tracking shell session resources."""
 
     shell_session_resources: NotRequired[
@@ -134,7 +134,11 @@ class ShellSession:
         self._terminated = False
 
     def start(self) -> None:
-        """Start the shell subprocess and reader threads."""
+        """Start the shell subprocess and reader threads.
+
+        Raises:
+            RuntimeError: If the shell session pipes cannot be initialized.
+        """
         if self._process and self._process.poll() is None:
             return
 
@@ -604,19 +608,35 @@ class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
         normalized: dict[str, str] = {}
         for key, value in env.items():
             if not isinstance(key, str):
-                msg = "Environment variable names must be strings."
+                msg = "Environment variable names must be strings."  # type: ignore[unreachable]
                 raise TypeError(msg)
             normalized[key] = str(value)
         return normalized
 
     @override
     def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
-        """Start the shell session and run startup commands."""
+        """Start the shell session and run startup commands.
+
+        Args:
+            state: The current agent state.
+            runtime: The runtime context.
+
+        Returns:
+            Shell session resources to be stored in the agent state.
+        """
         resources = self._get_or_create_resources(state)
         return {"shell_session_resources": resources}
 
     async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
-        """Async start the shell session and run startup commands."""
+        """Async start the shell session and run startup commands.
+
+        Args:
+            state: The current agent state.
+            runtime: The runtime context.
+
+        Returns:
+            Shell session resources to be stored in the agent state.
+        """
         return self.before_agent(state, runtime)
 
     @override
@@ -269,8 +269,16 @@ class SummarizationMiddleware(AgentMiddleware):
             raise ValueError(msg)
 
     @override
-    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
-        """Process messages before model invocation, potentially triggering summarization."""
+    def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
+        """Process messages before model invocation, potentially triggering summarization.
+
+        Args:
+            state: The agent state.
+            runtime: The runtime environment.
+
+        Returns:
+            An updated state with summarized messages if summarization was performed.
+        """
         messages = state["messages"]
         self._ensure_message_ids(messages)
 
@@ -297,8 +305,18 @@ class SummarizationMiddleware(AgentMiddleware):
         }
 
     @override
-    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
-        """Process messages before model invocation, potentially triggering summarization."""
+    async def abefore_model(
+        self, state: AgentState[Any], runtime: Runtime
+    ) -> dict[str, Any] | None:
+        """Process messages before model invocation, potentially triggering summarization.
+
+        Args:
+            state: The agent state.
+            runtime: The runtime environment.
+
+        Returns:
+            An updated state with summarized messages if summarization was performed.
+        """
         messages = state["messages"]
         self._ensure_message_ids(messages)
 
@@ -449,7 +467,8 @@ class SummarizationMiddleware(AgentMiddleware):
 
         return max_input_tokens
 
-    def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
+    @staticmethod
+    def _validate_context_size(context: ContextSize, parameter_name: str) -> ContextSize:
         """Validate context configuration tuples."""
         kind, value = context
         if kind == "fraction":
@@ -465,19 +484,24 @@ class SummarizationMiddleware(AgentMiddleware):
             raise ValueError(msg)
         return context
 
-    def _build_new_messages(self, summary: str) -> list[HumanMessage]:
+    @staticmethod
+    def _build_new_messages(summary: str) -> list[HumanMessage]:
         return [
-            HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
+            HumanMessage(
+                content=f"Here is a summary of the conversation to date:\n\n{summary}",
+                additional_kwargs={"lc_source": "summarization"},
+            )
         ]
 
-    def _ensure_message_ids(self, messages: list[AnyMessage]) -> None:
+    @staticmethod
+    def _ensure_message_ids(messages: list[AnyMessage]) -> None:
         """Ensure all messages have unique IDs for the add_messages reducer."""
         for msg in messages:
             if msg.id is None:
                 msg.id = str(uuid.uuid4())
 
+    @staticmethod
     def _partition_messages(
-        self,
         conversation_messages: list[AnyMessage],
         cutoff_index: int,
     ) -> tuple[list[AnyMessage], list[AnyMessage]]:
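The summary message built by `_build_new_messages` is now tagged through `additional_kwargs`, which presumably lets downstream code distinguish injected summaries from genuine user input. For example:

from langchain_core.messages import HumanMessage

summary = HumanMessage(
    content="Here is a summary of the conversation to date:\n\n...",
    additional_kwargs={"lc_source": "summarization"},
)
# Skip injected summaries when inspecting real user turns.
assert summary.additional_kwargs.get("lc_source") == "summarization"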
@@ -502,7 +526,8 @@ class SummarizationMiddleware(AgentMiddleware):
         target_cutoff = len(messages) - messages_to_keep
         return self._find_safe_cutoff_point(messages, target_cutoff)
 
-    def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
+    @staticmethod
+    def _find_safe_cutoff_point(messages: list[AnyMessage], cutoff_index: int) -> int:
         """Find a safe cutoff point that doesn't split AI/Tool message pairs.
 
         If the message at `cutoff_index` is a `ToolMessage`, search backward for the
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
 from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
 from langchain_core.tools import tool
 from langgraph.types import Command
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import NotRequired, TypedDict, override
 
 from langchain.agents.middleware.types import (
     AgentMiddleware,
@@ -35,7 +35,7 @@ class Todo(TypedDict):
     """The current status of the todo item."""
 
 
-class PlanningState(AgentState):
+class PlanningState(AgentState[Any]):
     """State schema for the todo middleware."""
 
     todos: Annotated[NotRequired[list[Todo]], OmitFromInput]
@@ -118,7 +118,9 @@ Writing todos takes time and tokens, use it when it is helpful for managing comp
 
 
 @tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
-def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
+def write_todos(
+    todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
+) -> Command[Any]:
     """Create and manage a structured task list for your current work session."""
     return Command(
         update={
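`Command` is generic in langgraph (parameterized by the node type used for `goto` jumps), so stricter type-checking settings want an explicit parameter; `Command[Any]` changes nothing at runtime. A usage sketch, assuming `langgraph` is installed:

from typing import Any

from langgraph.types import Command

# Same shape as the write_todos return value; the update payload is illustrative.
cmd: Command[Any] = Command(update={"todos": []})
print(cmd)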
@@ -178,7 +180,7 @@ class TodoListMiddleware(AgentMiddleware):
         @tool(description=self.tool_description)
         def write_todos(
             todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
-        ) -> Command:
+        ) -> Command[Any]:
             """Create and manage a structured task list for your current work session."""
             return Command(
                 update={
@@ -196,7 +198,16 @@ class TodoListMiddleware(AgentMiddleware):
         request: ModelRequest,
         handler: Callable[[ModelRequest], ModelResponse],
     ) -> ModelCallResult:
-        """Update the system message to include the todo system prompt."""
+        """Update the system message to include the todo system prompt.
+
+        Args:
+            request: Model request to execute (includes state and runtime).
+            handler: Callback that executes the model request and returns
+                `ModelResponse`.
+
+        Returns:
+            The model call result.
+        """
         if request.system_message is not None:
             new_system_content = [
                 *request.system_message.content_blocks,
@@ -214,7 +225,16 @@ class TodoListMiddleware(AgentMiddleware):
         request: ModelRequest,
         handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
     ) -> ModelCallResult:
-        """Update the system message to include the todo system prompt (async version)."""
+        """Update the system message to include the todo system prompt.
+
+        Args:
+            request: Model request to execute (includes state and runtime).
+            handler: Async callback that executes the model request and returns
+                `ModelResponse`.
+
+        Returns:
+            The model call result.
+        """
         if request.system_message is not None:
             new_system_content = [
                 *request.system_message.content_blocks,
@@ -227,11 +247,8 @@ class TodoListMiddleware(AgentMiddleware):
             )
             return await handler(request.override(system_message=new_system_message))
 
-    def after_model(
-        self,
-        state: AgentState,
-        runtime: Runtime,  # noqa: ARG002
-    ) -> dict[str, Any] | None:
+    @override
+    def after_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
         """Check for parallel write_todos tool calls and return errors if detected.
 
         The todo list is designed to be updated at most once per model turn. Since
@@ -280,11 +297,8 @@ class TodoListMiddleware(AgentMiddleware):
 
         return None
 
-    async def aafter_model(
-        self,
-        state: AgentState,
-        runtime: Runtime,
-    ) -> dict[str, Any] | None:
+    @override
+    async def aafter_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
         """Check for parallel write_todos tool calls and return errors if detected.
 
         Async version of `after_model`. The todo list is designed to be updated at
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage, ToolMessage
@@ -109,8 +109,8 @@ class LLMToolEmulator(AgentMiddleware):
     def wrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], ToolMessage | Command],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+    ) -> ToolMessage | Command[Any]:
         """Emulate tool execution using LLM if tool should be emulated.
 
         Args:
@@ -159,8 +159,8 @@ class LLMToolEmulator(AgentMiddleware):
     async def awrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
+    ) -> ToolMessage | Command[Any]:
         """Async version of `wrap_tool_call`.
 
         Emulate tool execution using LLM if tool should be emulated.
@@ -5,7 +5,7 @@ from __future__ import annotations
 import asyncio
 import time
 import warnings
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from langchain_core.messages import ToolMessage
 
@@ -189,14 +189,14 @@ class ToolRetryMiddleware(AgentMiddleware):
 
         # Handle backwards compatibility for deprecated on_failure values
         if on_failure == "raise":  # type: ignore[comparison-overlap]
-            msg = (
+            msg = (  # type: ignore[unreachable]
                 "on_failure='raise' is deprecated and will be removed in a future version. "
                 "Use on_failure='error' instead."
             )
             warnings.warn(msg, DeprecationWarning, stacklevel=2)
             on_failure = "error"
         elif on_failure == "return_message":  # type: ignore[comparison-overlap]
-            msg = (
+            msg = (  # type: ignore[unreachable]
                 "on_failure='return_message' is deprecated and will be removed "
                 "in a future version. Use on_failure='continue' instead."
             )
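With the deprecated literals removed from `on_failure`'s declared type, mypy now considers these compatibility branches dead, hence the `type: ignore[unreachable]` markers; the runtime checks still matter for old callers. A standalone sketch of the idiom:

import warnings
from typing import Literal


def normalize(on_failure: Literal["error", "continue"]) -> str:
    # "raise" is no longer in the declared type, so mypy flags this branch
    # as unreachable even though untyped callers can still hit it.
    if on_failure == "raise":  # type: ignore[comparison-overlap]
        msg = "on_failure='raise' is deprecated; use 'error' instead."  # type: ignore[unreachable]
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        return "error"
    return on_failure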
@@ -233,7 +233,8 @@ class ToolRetryMiddleware(AgentMiddleware):
233
233
  return True
234
234
  return tool_name in self._tool_filter
235
235
 
236
- def _format_failure_message(self, tool_name: str, exc: Exception, attempts_made: int) -> str:
236
+ @staticmethod
237
+ def _format_failure_message(tool_name: str, exc: Exception, attempts_made: int) -> str:
237
238
  """Format the failure message when retries are exhausted.
238
239
 
239
240
  Args:
@@ -287,8 +288,8 @@ class ToolRetryMiddleware(AgentMiddleware):
     def wrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], ToolMessage | Command],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+    ) -> ToolMessage | Command[Any]:
         """Intercept tool execution and retry on failure.
 
         Args:
@@ -297,6 +298,9 @@ class ToolRetryMiddleware(AgentMiddleware):
 
         Returns:
             `ToolMessage` or `Command` (the final result).
+
+        Raises:
+            RuntimeError: If the retry loop completes without returning. This should not happen.
         """
         tool_name = request.tool.name if request.tool else request.tool_call["name"]
 
@@ -342,8 +346,8 @@ class ToolRetryMiddleware(AgentMiddleware):
     async def awrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
+    ) -> ToolMessage | Command[Any]:
         """Intercept and control async tool execution with retry logic.
 
         Args:
@@ -353,6 +357,9 @@ class ToolRetryMiddleware(AgentMiddleware):
 
         Returns:
             `ToolMessage` or `Command` (the final result).
+
+        Raises:
+            RuntimeError: If the retry loop completes without returning. This should not happen.
         """
         tool_name = request.tool.name if request.tool else request.tool_call["name"]
 
@@ -4,12 +4,7 @@ from __future__ import annotations
 
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Annotated, Literal, Union
-
-if TYPE_CHECKING:
-    from collections.abc import Awaitable, Callable
-
-    from langchain.tools import BaseTool
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
 
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage
@@ -24,6 +19,11 @@ from langchain.agents.middleware.types import (
 )
 from langchain.chat_models.base import init_chat_model
 
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+    from langchain.tools import BaseTool
+
 logger = logging.getLogger(__name__)
 
 DEFAULT_SYSTEM_PROMPT = (
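Relocating the type-only imports below the runtime imports follows the usual `TYPE_CHECKING` idiom: with `from __future__ import annotations`, annotations are evaluated lazily, so these imports never need to run outside the type checker. A minimal sketch:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved only by the type checker; never imported at runtime.
    from collections.abc import Callable


def apply(fn: Callable[[int], int], x: int) -> int:
    return fn(x)


print(apply(lambda n: n + 1, 41))  # 42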
@@ -42,7 +42,7 @@ class _SelectionRequest:
     valid_tool_names: list[str]
 
 
-def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
+def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
     """Create a structured output schema for tool selection.
 
     Args:
@@ -51,6 +51,9 @@ def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
     Returns:
         `TypeAdapter` for a schema where each tool name is a `Literal` with its
         description.
+
+    Raises:
+        AssertionError: If `tools` is empty.
     """
     if not tools:
         msg = "Invalid usage: tools must be non-empty"
@@ -153,9 +156,16 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
     def _prepare_selection_request(self, request: ModelRequest) -> _SelectionRequest | None:
         """Prepare inputs for tool selection.
 
+        Args:
+            request: The model request.
+
         Returns:
             `SelectionRequest` with prepared inputs, or `None` if no selection is
-            needed.
+                needed.
+
+        Raises:
+            ValueError: If tools in `always_include` are not found in the request.
+            AssertionError: If no user message is found in the request messages.
         """
         # If no tools available, return None
         if not request.tools or len(request.tools) == 0:
@@ -217,7 +227,7 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
 
     def _process_selection_response(
         self,
-        response: dict,
+        response: dict[str, Any],
         available_tools: list[BaseTool],
         valid_tool_names: list[str],
         request: ModelRequest,
@@ -262,7 +272,19 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         request: ModelRequest,
         handler: Callable[[ModelRequest], ModelResponse],
     ) -> ModelCallResult:
-        """Filter tools based on LLM selection before invoking the model via handler."""
+        """Filter tools based on LLM selection before invoking the model via handler.
+
+        Args:
+            request: Model request to execute (includes state and runtime).
+            handler: Callback that executes the model request and returns
+                `ModelResponse`.
+
+        Returns:
+            The model call result.
+
+        Raises:
+            AssertionError: If the selection model response is not a dict.
+        """
         selection_request = self._prepare_selection_request(request)
         if selection_request is None:
             return handler(request)
@@ -293,7 +315,19 @@ class LLMToolSelectorMiddleware(AgentMiddleware):
         request: ModelRequest,
         handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
     ) -> ModelCallResult:
-        """Filter tools based on LLM selection before invoking the model via handler."""
+        """Filter tools based on LLM selection before invoking the model via handler.
+
+        Args:
+            request: Model request to execute (includes state and runtime).
+            handler: Async callback that executes the model request and returns
+                `ModelResponse`.
+
+        Returns:
+            The model call result.
+
+        Raises:
+            AssertionError: If the selection model response is not a dict.
+        """
         selection_request = self._prepare_selection_request(request)
         if selection_request is None:
             return await handler(request)