langchain 1.0.5__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +1 -7
  3. langchain/agents/factory.py +99 -40
  4. langchain/agents/middleware/__init__.py +5 -7
  5. langchain/agents/middleware/_execution.py +21 -20
  6. langchain/agents/middleware/_redaction.py +27 -12
  7. langchain/agents/middleware/_retry.py +123 -0
  8. langchain/agents/middleware/context_editing.py +26 -22
  9. langchain/agents/middleware/file_search.py +18 -13
  10. langchain/agents/middleware/human_in_the_loop.py +60 -54
  11. langchain/agents/middleware/model_call_limit.py +63 -17
  12. langchain/agents/middleware/model_fallback.py +7 -9
  13. langchain/agents/middleware/model_retry.py +300 -0
  14. langchain/agents/middleware/pii.py +80 -27
  15. langchain/agents/middleware/shell_tool.py +230 -103
  16. langchain/agents/middleware/summarization.py +439 -90
  17. langchain/agents/middleware/todo.py +111 -27
  18. langchain/agents/middleware/tool_call_limit.py +105 -71
  19. langchain/agents/middleware/tool_emulator.py +42 -33
  20. langchain/agents/middleware/tool_retry.py +171 -159
  21. langchain/agents/middleware/tool_selection.py +37 -27
  22. langchain/agents/middleware/types.py +754 -392
  23. langchain/agents/structured_output.py +22 -12
  24. langchain/chat_models/__init__.py +1 -7
  25. langchain/chat_models/base.py +233 -184
  26. langchain/embeddings/__init__.py +0 -5
  27. langchain/embeddings/base.py +79 -65
  28. langchain/messages/__init__.py +0 -5
  29. langchain/tools/__init__.py +1 -7
  30. {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/METADATA +3 -5
  31. langchain-1.2.3.dist-info/RECORD +36 -0
  32. {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/WHEEL +1 -1
  33. langchain-1.0.5.dist-info/RECORD +0 -34
  34. {langchain-1.0.5.dist-info → langchain-1.2.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,123 @@
+ """Shared retry utilities for agent middleware.
+
+ This module contains common constants, utilities, and logic used by both
+ model and tool retry middleware implementations.
+ """
+
+ from __future__ import annotations
+
+ import random
+ from collections.abc import Callable
+ from typing import Literal
+
+ # Type aliases
+ RetryOn = tuple[type[Exception], ...] | Callable[[Exception], bool]
+ """Type for specifying which exceptions to retry on.
+
+ Can be either:
+ - A tuple of exception types to retry on (based on `isinstance` checks)
+ - A callable that takes an exception and returns `True` if it should be retried
+ """
+
+ OnFailure = Literal["error", "continue"] | Callable[[Exception], str]
+ """Type for specifying failure handling behavior.
+
+ Can be either:
+ - A literal action string (`'error'` or `'continue'`)
+     - `'error'`: Re-raise the exception, stopping agent execution.
+     - `'continue'`: Inject a message with the error details, allowing the agent to continue.
+       For tool retries, a `ToolMessage` with the error details will be injected.
+       For model retries, an `AIMessage` with the error details will be returned.
+ - A callable that takes an exception and returns a string for error message content
+ """
+
+
+ def validate_retry_params(
+     max_retries: int,
+     initial_delay: float,
+     max_delay: float,
+     backoff_factor: float,
+ ) -> None:
+     """Validate retry parameters.
+
+     Args:
+         max_retries: Maximum number of retry attempts.
+         initial_delay: Initial delay in seconds before first retry.
+         max_delay: Maximum delay in seconds between retries.
+         backoff_factor: Multiplier for exponential backoff.
+
+     Raises:
+         ValueError: If any parameter is invalid (negative values).
+     """
+     if max_retries < 0:
+         msg = "max_retries must be >= 0"
+         raise ValueError(msg)
+     if initial_delay < 0:
+         msg = "initial_delay must be >= 0"
+         raise ValueError(msg)
+     if max_delay < 0:
+         msg = "max_delay must be >= 0"
+         raise ValueError(msg)
+     if backoff_factor < 0:
+         msg = "backoff_factor must be >= 0"
+         raise ValueError(msg)
+
+
+ def should_retry_exception(
+     exc: Exception,
+     retry_on: RetryOn,
+ ) -> bool:
+     """Check if an exception should trigger a retry.
+
+     Args:
+         exc: The exception that occurred.
+         retry_on: Either a tuple of exception types to retry on, or a callable
+             that takes an exception and returns `True` if it should be retried.
+
+     Returns:
+         `True` if the exception should be retried, `False` otherwise.
+     """
+     if callable(retry_on):
+         return retry_on(exc)
+     return isinstance(exc, retry_on)
+
+
+ def calculate_delay(
+     retry_number: int,
+     *,
+     backoff_factor: float,
+     initial_delay: float,
+     max_delay: float,
+     jitter: bool,
+ ) -> float:
+     """Calculate delay for a retry attempt with exponential backoff and optional jitter.
+
+     Args:
+         retry_number: The retry attempt number (0-indexed).
+         backoff_factor: Multiplier for exponential backoff.
+
+             Set to `0.0` for constant delay.
+         initial_delay: Initial delay in seconds before first retry.
+         max_delay: Maximum delay in seconds between retries.
+
+             Caps exponential backoff growth.
+         jitter: Whether to add random jitter to delay to avoid thundering herd.
+
+     Returns:
+         Delay in seconds before next retry.
+     """
+     if backoff_factor == 0.0:
+         delay = initial_delay
+     else:
+         delay = initial_delay * (backoff_factor**retry_number)
+
+     # Cap at max_delay
+     delay = min(delay, max_delay)
+
+     if jitter and delay > 0:
+         jitter_amount = delay * 0.25  # ±25% jitter
+         delay += random.uniform(-jitter_amount, jitter_amount)  # noqa: S311
+         # Ensure delay is not negative after jitter
+         delay = max(0, delay)
+
+     return delay
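
Note: taken together, these helpers implement the usual exponential-backoff-with-jitter schedule, `delay = min(initial_delay * backoff_factor ** n, max_delay)` plus up to ±25% random jitter. A minimal sketch of the resulting schedule, assuming the definitions above are in scope (the module is private, so these names are not a stable import surface):

```python
# Exponential backoff: 1s, 2s, 4s, then capped at max_delay
for attempt in range(4):
    delay = calculate_delay(
        attempt, backoff_factor=2.0, initial_delay=1.0, max_delay=8.0, jitter=False
    )
    print(f"retry {attempt}: {delay:.2f}s")  # 1.00, 2.00, 4.00, 8.00

# Tuple form of RetryOn: isinstance-based matching
assert should_retry_exception(TimeoutError("slow"), (TimeoutError,))
assert not should_retry_exception(ValueError("bad"), (TimeoutError,))
```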
@@ -1,14 +1,16 @@
  """Context editing middleware.

- This middleware mirrors Anthropic's context editing capabilities by clearing
- older tool results once the conversation grows beyond a configurable token
- threshold. The implementation is intentionally model-agnostic so it can be used
- with any LangChain chat model.
+ Mirrors Anthropic's context editing capabilities by clearing older tool results once the
+ conversation grows beyond a configurable token threshold.
+
+ The implementation is intentionally model-agnostic so it can be used with any LangChain
+ chat model.
  """

  from __future__ import annotations

  from collections.abc import Awaitable, Callable, Iterable, Sequence
+ from copy import deepcopy
  from dataclasses import dataclass
  from typing import Literal

@@ -16,7 +18,6 @@ from langchain_core.messages import (
      AIMessage,
      AnyMessage,
      BaseMessage,
-     SystemMessage,
      ToolMessage,
  )
  from langchain_core.messages.utils import count_tokens_approximately
@@ -182,11 +183,13 @@ class ClearToolUsesEdit(ContextEdit):


  class ContextEditingMiddleware(AgentMiddleware):
-     """Automatically prunes tool results to manage context size.
+     """Automatically prune tool results to manage context size.
+
+     The middleware applies a sequence of edits when the total input token count exceeds
+     configured thresholds.

-     The middleware applies a sequence of edits when the total input token count
-     exceeds configured thresholds. Currently the `ClearToolUsesEdit` strategy is
-     supported, aligning with Anthropic's `clear_tool_uses_20250919` behaviour.
+     Currently the `ClearToolUsesEdit` strategy is supported, aligning with Anthropic's
+     `clear_tool_uses_20250919` behavior [(read more)](https://platform.claude.com/docs/en/agents-and-tools/tool-use/memory-tool).
      """

      edits: list[ContextEdit]
@@ -198,11 +201,12 @@ class ContextEditingMiddleware(AgentMiddleware):
          edits: Iterable[ContextEdit] | None = None,
          token_count_method: Literal["approximate", "model"] = "approximate",  # noqa: S107
      ) -> None:
-         """Initializes a context editing middleware instance.
+         """Initialize an instance of context editing middleware.

          Args:
-             edits: Sequence of edit strategies to apply. Defaults to a single
-                 `ClearToolUsesEdit` mirroring Anthropic defaults.
+             edits: Sequence of edit strategies to apply.
+
+                 Defaults to a single `ClearToolUsesEdit` mirroring Anthropic defaults.
              token_count_method: Whether to use approximate token counting
                  (faster, less accurate) or exact counting implemented by the
                  chat model (potentially slower, more accurate).
@@ -224,20 +228,20 @@ class ContextEditingMiddleware(AgentMiddleware):

              def count_tokens(messages: Sequence[BaseMessage]) -> int:
                  return count_tokens_approximately(messages)
+
          else:
-             system_msg = (
-                 [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
-             )
+             system_msg = [request.system_message] if request.system_message else []

              def count_tokens(messages: Sequence[BaseMessage]) -> int:
                  return request.model.get_num_tokens_from_messages(
                      system_msg + list(messages), request.tools
                  )

+         edited_messages = deepcopy(list(request.messages))
          for edit in self.edits:
-             edit.apply(request.messages, count_tokens=count_tokens)
+             edit.apply(edited_messages, count_tokens=count_tokens)

-         return handler(request)
+         return handler(request.override(messages=edited_messages))

      async def awrap_model_call(
          self,
@@ -252,20 +256,20 @@ class ContextEditingMiddleware(AgentMiddleware):

              def count_tokens(messages: Sequence[BaseMessage]) -> int:
                  return count_tokens_approximately(messages)
+
          else:
-             system_msg = (
-                 [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
-             )
+             system_msg = [request.system_message] if request.system_message else []

              def count_tokens(messages: Sequence[BaseMessage]) -> int:
                  return request.model.get_num_tokens_from_messages(
                      system_msg + list(messages), request.tools
                  )

+         edited_messages = deepcopy(list(request.messages))
          for edit in self.edits:
-             edit.apply(request.messages, count_tokens=count_tokens)
+             edit.apply(edited_messages, count_tokens=count_tokens)

-         return await handler(request)
+         return await handler(request.override(messages=edited_messages))


  __all__ = [
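
Note: the 1.2.3 version now deep-copies messages and passes the edited copy via `request.override(...)` rather than mutating `request.messages` in place. A minimal sketch of attaching this middleware to an agent; the import paths, model identifier, and `trigger` threshold below are illustrative assumptions, not documented defaults:

```python
from langchain.agents import create_agent
from langchain.agents.middleware.context_editing import (  # assumed import path
    ClearToolUsesEdit,
    ContextEditingMiddleware,
)

agent = create_agent(
    model="openai:gpt-4o",  # assumed model identifier
    tools=[],
    middleware=[
        ContextEditingMiddleware(
            edits=[ClearToolUsesEdit(trigger=100_000)],  # assumed threshold field
            token_count_method="approximate",
        ),
    ],
)
```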
@@ -21,7 +21,7 @@ from langchain.agents.middleware.types import AgentMiddleware


  def _expand_include_patterns(pattern: str) -> list[str] | None:
-     """Expand brace patterns like ``*.{py,pyi}`` into a list of globs."""
+     """Expand brace patterns like `*.{py,pyi}` into a list of globs."""
      if "}" in pattern and "{" not in pattern:
          return None

@@ -88,6 +88,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
      """Provides Glob and Grep search over filesystem files.

      This middleware adds two tools that search through local filesystem:
+
      - Glob: Fast file pattern matching by file path
      - Grep: Fast content search using ripgrep or Python fallback

@@ -100,7 +101,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):

          agent = create_agent(
              model=model,
-             tools=[],
+             tools=[],  # Add tools as needed
              middleware=[
                  FilesystemFileSearchMiddleware(root_path="/workspace"),
              ],
@@ -119,9 +120,10 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):

          Args:
              root_path: Root directory to search.
-             use_ripgrep: Whether to use ripgrep for search (default: True).
-                 Falls back to Python if ripgrep unavailable.
-             max_file_size_mb: Maximum file size to search in MB (default: 10).
+             use_ripgrep: Whether to use `ripgrep` for search.
+
+                 Falls back to Python if `ripgrep` unavailable.
+             max_file_size_mb: Maximum file size to search in MB.
          """
          self.root_path = Path(root_path).resolve()
          self.use_ripgrep = use_ripgrep
@@ -132,8 +134,10 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):

          def glob_search(pattern: str, path: str = "/") -> str:
              """Fast file pattern matching tool that works with any codebase size.
-             Supports glob patterns like **/*.js or src/**/*.ts.
+             Supports glob patterns like `**/*.js` or `src/**/*.ts`.
+
              Returns matching file paths sorted by modification time.
+
              Use this tool when you need to find files by name patterns.

              Args:
@@ -142,7 +146,7 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):

              Returns:
                  Newline-separated list of matching file paths, sorted by modification
-                 time (most recently modified first). Returns "No files found" if no
+                 time (most recently modified first). Returns `'No files found'` if no
                  matches.
              """
              try:
@@ -184,15 +188,16 @@ class FilesystemFileSearchMiddleware(AgentMiddleware):
              Args:
                  pattern: The regular expression pattern to search for in file contents.
                  path: The directory to search in. If not specified, searches from root.
-                 include: File pattern to filter (e.g., "*.js", "*.{ts,tsx}").
+                 include: File pattern to filter (e.g., `'*.js'`, `'*.{ts,tsx}'`).
                  output_mode: Output format:
-                     - "files_with_matches": Only file paths containing matches (default)
-                     - "content": Matching lines with file:line:content format
-                     - "count": Count of matches per file
+
+                     - `'files_with_matches'`: Only file paths containing matches
+                     - `'content'`: Matching lines with `file:line:content` format
+                     - `'count'`: Count of matches per file

              Returns:
-                 Search results formatted according to output_mode. Returns "No matches
-                 found" if no results.
+                 Search results formatted according to `output_mode`.
+                 Returns `'No matches found'` if no results.
              """
              # Compile regex pattern (for validation)
              try:
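
Note: these tools are exposed to the model rather than called directly from user code. A sketch of the calls an agent might issue and the output shapes described in the docstrings above (the tool names, arguments, and `output_mode` values come from this file; the example paths and counts are made up for illustration):

```python
# glob_search(pattern="**/*.py", path="/")
#   -> "src/app/main.py\nsrc/app/util.py"        # newest first, or "No files found"
#
# grep_search(pattern=r"def \w+", include="*.{py,pyi}", output_mode="content")
#   -> "src/app/main.py:12:def main() -> None:"  # file:line:content
#
# grep_search(pattern=r"TODO", output_mode="count")
#   -> "src/app/main.py:3"                       # match count per file
```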
@@ -7,17 +7,17 @@ from langgraph.runtime import Runtime
  from langgraph.types import interrupt
  from typing_extensions import NotRequired, TypedDict

- from langchain.agents.middleware.types import AgentMiddleware, AgentState
+ from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, StateT


  class Action(TypedDict):
      """Represents an action with a name and args."""

      name: str
-     """The type or name of action being requested (e.g., "add_numbers")."""
+     """The type or name of action being requested (e.g., `'add_numbers'`)."""

      args: dict[str, Any]
-     """Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
+     """Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""


  class ActionRequest(TypedDict):
@@ -27,7 +27,7 @@ class ActionRequest(TypedDict):
      """The name of the action being requested."""

      args: dict[str, Any]
-     """Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
+     """Key-value pairs of args needed for the action (e.g., `{"a": 1, "b": 2}`)."""

      description: NotRequired[str]
      """The description of the action to be reviewed."""
@@ -102,7 +102,7 @@ class HITLResponse(TypedDict):
  class _DescriptionFactory(Protocol):
      """Callable that generates a description for a tool call."""

-     def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
+     def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str:
          """Generate a description for a tool call."""
          ...

@@ -138,7 +138,7 @@ class InterruptOnConfig(TypedDict):
          def format_tool_description(
              tool_call: ToolCall,
              state: AgentState,
-             runtime: Runtime
+             runtime: Runtime[ContextT]
          ) -> str:
              import json
              return (
@@ -156,7 +156,7 @@ class InterruptOnConfig(TypedDict):
      """JSON schema for the args associated with the action, if edits are allowed."""


- class HumanInTheLoopMiddleware(AgentMiddleware):
+ class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
      """Human in the loop middleware."""

      def __init__(
@@ -169,18 +169,22 @@ class HumanInTheLoopMiddleware(AgentMiddleware):

          Args:
              interrupt_on: Mapping of tool name to allowed actions.
+
                  If a tool doesn't have an entry, it's auto-approved by default.

                  * `True` indicates all decisions are allowed: approve, edit, and reject.
                  * `False` indicates that the tool is auto-approved.
                  * `InterruptOnConfig` indicates the specific decisions allowed for this
                      tool.
-                 The InterruptOnConfig can include a `description` field (`str` or
+
+                 The `InterruptOnConfig` can include a `description` field (`str` or
                  `Callable`) for custom formatting of the interrupt description.
              description_prefix: The prefix to use when constructing action requests.
+
                  This is used to provide context about the tool call and the action being
-                 requested. Not used if a tool has a `description` in its
-                 `InterruptOnConfig`.
+                 requested.
+
+                 Not used if a tool has a `description` in its `InterruptOnConfig`.
          """
          super().__init__()
          resolved_configs: dict[str, InterruptOnConfig] = {}
@@ -200,7 +204,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
          tool_call: ToolCall,
          config: InterruptOnConfig,
          state: AgentState,
-         runtime: Runtime,
+         runtime: Runtime[ContextT],
      ) -> tuple[ActionRequest, ReviewConfig]:
          """Create an ActionRequest and ReviewConfig for a tool call."""
          tool_name = tool_call["name"]
@@ -273,7 +277,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
              )
              raise ValueError(msg)

-     def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+     def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
          """Trigger interrupt flows for relevant tool calls after an `AIMessage`."""
          messages = state["messages"]
          if not messages:
@@ -283,36 +287,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
          if not last_ai_msg or not last_ai_msg.tool_calls:
              return None

-         # Separate tool calls that need interrupts from those that don't
-         interrupt_tool_calls: list[ToolCall] = []
-         auto_approved_tool_calls = []
-
-         for tool_call in last_ai_msg.tool_calls:
-             interrupt_tool_calls.append(tool_call) if tool_call[
-                 "name"
-             ] in self.interrupt_on else auto_approved_tool_calls.append(tool_call)
-
-         # If no interrupts needed, return early
-         if not interrupt_tool_calls:
-             return None
-
-         # Process all tool calls that require interrupts
-         revised_tool_calls: list[ToolCall] = auto_approved_tool_calls.copy()
-         artificial_tool_messages: list[ToolMessage] = []
-
-         # Create action requests and review configs for all tools that need approval
+         # Create action requests and review configs for tools that need approval
          action_requests: list[ActionRequest] = []
          review_configs: list[ReviewConfig] = []
+         interrupt_indices: list[int] = []

-         for tool_call in interrupt_tool_calls:
-             config = self.interrupt_on[tool_call["name"]]
+         for idx, tool_call in enumerate(last_ai_msg.tool_calls):
+             if (config := self.interrupt_on.get(tool_call["name"])) is not None:
+                 action_request, review_config = self._create_action_and_config(
+                     tool_call, config, state, runtime
+                 )
+                 action_requests.append(action_request)
+                 review_configs.append(review_config)
+                 interrupt_indices.append(idx)

-             # Create ActionRequest and ReviewConfig using helper method
-             action_request, review_config = self._create_action_and_config(
-                 tool_call, config, state, runtime
-             )
-             action_requests.append(action_request)
-             review_configs.append(review_config)
+         # If no interrupts needed, return early
+         if not action_requests:
+             return None

          # Create single HITLRequest with all actions and configs
          hitl_request = HITLRequest(
@@ -321,31 +312,46 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
          )

          # Send interrupt and get response
-         hitl_response: HITLResponse = interrupt(hitl_request)
-         decisions = hitl_response["decisions"]
+         decisions = interrupt(hitl_request)["decisions"]

          # Validate that the number of decisions matches the number of interrupt tool calls
-         if (decisions_len := len(decisions)) != (
-             interrupt_tool_calls_len := len(interrupt_tool_calls)
-         ):
+         if (decisions_len := len(decisions)) != (interrupt_count := len(interrupt_indices)):
              msg = (
                  f"Number of human decisions ({decisions_len}) does not match "
-                 f"number of hanging tool calls ({interrupt_tool_calls_len})."
+                 f"number of hanging tool calls ({interrupt_count})."
              )
              raise ValueError(msg)

-         # Process each decision using helper method
-         for i, decision in enumerate(decisions):
-             tool_call = interrupt_tool_calls[i]
-             config = self.interrupt_on[tool_call["name"]]
-
-             revised_tool_call, tool_message = self._process_decision(decision, tool_call, config)
-             if revised_tool_call:
-                 revised_tool_calls.append(revised_tool_call)
-             if tool_message:
-                 artificial_tool_messages.append(tool_message)
+         # Process decisions and rebuild tool calls in original order
+         revised_tool_calls: list[ToolCall] = []
+         artificial_tool_messages: list[ToolMessage] = []
+         decision_idx = 0
+
+         for idx, tool_call in enumerate(last_ai_msg.tool_calls):
+             if idx in interrupt_indices:
+                 # This was an interrupt tool call - process the decision
+                 config = self.interrupt_on[tool_call["name"]]
+                 decision = decisions[decision_idx]
+                 decision_idx += 1
+
+                 revised_tool_call, tool_message = self._process_decision(
+                     decision, tool_call, config
+                 )
+                 if revised_tool_call is not None:
+                     revised_tool_calls.append(revised_tool_call)
+                 if tool_message:
+                     artificial_tool_messages.append(tool_message)
+             else:
+                 # This was auto-approved - keep original
+                 revised_tool_calls.append(tool_call)

          # Update the AI message to only include approved tool calls
          last_ai_msg.tool_calls = revised_tool_calls

          return {"messages": [last_ai_msg, *artificial_tool_messages]}
+
+     async def aafter_model(
+         self, state: AgentState, runtime: Runtime[ContextT]
+     ) -> dict[str, Any] | None:
+         """Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
+         return self.after_model(state, runtime)
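
Note: a minimal sketch of wiring the middleware in. The tool definitions, import path, and model identifier are illustrative assumptions; the `interrupt_on` semantics are as documented above:

```python
from langchain.agents import create_agent
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware

agent = create_agent(
    model="openai:gpt-4o",  # assumed model identifier
    tools=[read_file, delete_file],  # hypothetical tools defined elsewhere
    middleware=[
        HumanInTheLoopMiddleware(
            interrupt_on={
                "delete_file": True,  # pause for approve / edit / reject
                "read_file": False,   # auto-approved, never interrupts
            },
        ),
    ],
)
# Execution pauses at a langgraph interrupt; resuming supplies
# {"decisions": [...]} with one decision per interrupted tool call.
```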
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Literal

  from langchain_core.messages import AIMessage
  from langgraph.channels.untracked_value import UntrackedValue
- from typing_extensions import NotRequired
+ from typing_extensions import NotRequired, override

  from langchain.agents.middleware.types import (
      AgentMiddleware,
@@ -20,9 +20,9 @@ if TYPE_CHECKING:


  class ModelCallLimitState(AgentState):
-     """State schema for ModelCallLimitMiddleware.
+     """State schema for `ModelCallLimitMiddleware`.

-     Extends AgentState with model call tracking fields.
+     Extends `AgentState` with model call tracking fields.
      """

      thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
@@ -58,8 +58,8 @@ def _build_limit_exceeded_message(
  class ModelCallLimitExceededError(Exception):
      """Exception raised when model call limits are exceeded.

-     This exception is raised when the configured exit behavior is 'error'
-     and either the thread or run model call limit has been exceeded.
+     This exception is raised when the configured exit behavior is `'error'` and either
+     the thread or run model call limit has been exceeded.
      """

      def __init__(
@@ -127,13 +127,17 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):

          Args:
              thread_limit: Maximum number of model calls allowed per thread.
-                 None means no limit.
+
+                 `None` means no limit.
              run_limit: Maximum number of model calls allowed per run.
-                 None means no limit.
+
+                 `None` means no limit.
              exit_behavior: What to do when limits are exceeded.
-                 - "end": Jump to the end of the agent execution and
-                     inject an artificial AI message indicating that the limit was exceeded.
-                 - "error": Raise a `ModelCallLimitExceededError`
+
+                 - `'end'`: Jump to the end of the agent execution and
+                     inject an artificial AI message indicating that the limit was
+                     exceeded.
+                 - `'error'`: Raise a `ModelCallLimitExceededError`

          Raises:
              ValueError: If both limits are `None` or if `exit_behavior` is invalid.
@@ -144,7 +148,7 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
              msg = "At least one limit must be specified (thread_limit or run_limit)"
              raise ValueError(msg)

-         if exit_behavior not in ("end", "error"):
+         if exit_behavior not in {"end", "error"}:
              msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
              raise ValueError(msg)

@@ -153,7 +157,8 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
          self.exit_behavior = exit_behavior

      @hook_config(can_jump_to=["end"])
-     def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+     @override
+     def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:
          """Check model call limits before making a model call.

          Args:
@@ -161,12 +166,13 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
              runtime: The langgraph runtime.

          Returns:
-             If limits are exceeded and exit_behavior is "end", returns
-             a Command to jump to the end with a limit exceeded message. Otherwise returns None.
+             If limits are exceeded and exit_behavior is `'end'`, returns
+             a `Command` to jump to the end with a limit exceeded message. Otherwise
+             returns `None`.

          Raises:
-             ModelCallLimitExceededError: If limits are exceeded and exit_behavior
-                 is "error".
+             ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
+                 is `'error'`.
          """
          thread_count = state.get("thread_model_call_count", 0)
          run_count = state.get("run_model_call_count", 0)
@@ -194,7 +200,31 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):

          return None

-     def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+     @hook_config(can_jump_to=["end"])
+     async def abefore_model(
+         self,
+         state: ModelCallLimitState,
+         runtime: Runtime,
+     ) -> dict[str, Any] | None:
+         """Async check model call limits before making a model call.
+
+         Args:
+             state: The current agent state containing call counts.
+             runtime: The langgraph runtime.
+
+         Returns:
+             If limits are exceeded and exit_behavior is `'end'`, returns
+             a `Command` to jump to the end with a limit exceeded message. Otherwise
+             returns `None`.
+
+         Raises:
+             ModelCallLimitExceededError: If limits are exceeded and `exit_behavior`
+                 is `'error'`.
+         """
+         return self.before_model(state, runtime)
+
+     @override
+     def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:
          """Increment model call counts after a model call.

          Args:
@@ -208,3 +238,19 @@ class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
              "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
              "run_model_call_count": state.get("run_model_call_count", 0) + 1,
          }
+
+     async def aafter_model(
+         self,
+         state: ModelCallLimitState,
+         runtime: Runtime,
+     ) -> dict[str, Any] | None:
+         """Async increment model call counts after a model call.
+
+         Args:
+             state: The current agent state.
+             runtime: The langgraph runtime.
+
+         Returns:
+             State updates with incremented call counts.
+         """
+         return self.after_model(state, runtime)
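
Note: a closing sketch for the call-limit middleware, using the `thread_limit`, `run_limit`, and `exit_behavior` semantics documented above; the import path and model identifier are again assumptions:

```python
from langchain.agents import create_agent
from langchain.agents.middleware.model_call_limit import ModelCallLimitMiddleware

agent = create_agent(
    model="openai:gpt-4o",  # assumed model identifier
    tools=[],
    middleware=[
        ModelCallLimitMiddleware(
            thread_limit=20,      # cumulative across a thread; None = no limit
            run_limit=5,          # per invocation; None = no limit
            exit_behavior="end",  # inject a limit-exceeded AIMessage instead of raising
        ),
    ],
)
```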