langchain 1.0.0a14__py3-none-any.whl → 1.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of langchain has been flagged as potentially problematic.

@@ -2,16 +2,33 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Annotated, Any, Literal
 
 from langchain_core.messages import AIMessage
+from langgraph.channels.untracked_value import UntrackedValue
+from typing_extensions import NotRequired
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    AgentState,
+    PrivateStateAttr,
+    hook_config,
+)
 
 if TYPE_CHECKING:
     from langgraph.runtime import Runtime
 
 
+class ModelCallLimitState(AgentState):
+    """State schema for ModelCallLimitMiddleware.
+
+    Extends AgentState with model call tracking fields.
+    """
+
+    thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
+    run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
+
+
 def _build_limit_exceeded_message(
     thread_count: int,
     run_count: int,
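For context on the two markers used in `ModelCallLimitState`: a minimal sketch of the same pattern, assuming `PrivateStateAttr` keeps a key out of the agent's public input/output schema and `UntrackedValue` (a langgraph channel) keeps it out of checkpoints so it resets on each run. The class and field names below are hypothetical.

```python
# Hypothetical state extension reusing the markers from the hunk above.
from typing import Annotated

from langgraph.channels.untracked_value import UntrackedValue
from typing_extensions import NotRequired

from langchain.agents.middleware.types import AgentState, PrivateStateAttr


class MyTrackingState(AgentState):
    # Checkpointed: survives across runs on the same thread.
    my_thread_counter: NotRequired[Annotated[int, PrivateStateAttr]]
    # Not checkpointed: starts fresh on every run.
    my_run_counter: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
```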
@@ -69,7 +86,7 @@ class ModelCallLimitExceededError(Exception):
         super().__init__(msg)
 
 
-class ModelCallLimitMiddleware(AgentMiddleware):
+class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
     """Middleware that tracks model call counts and enforces limits.
 
     This middleware monitors the number of model calls made during agent execution
@@ -97,6 +114,8 @@ class ModelCallLimitMiddleware(AgentMiddleware):
         ```
     """
 
+    state_schema = ModelCallLimitState
+
     def __init__(
         self,
         *,
@@ -108,17 +127,16 @@ class ModelCallLimitMiddleware(AgentMiddleware):
 
         Args:
             thread_limit: Maximum number of model calls allowed per thread.
-                None means no limit. Defaults to `None`.
+                None means no limit.
             run_limit: Maximum number of model calls allowed per run.
-                None means no limit. Defaults to `None`.
+                None means no limit.
             exit_behavior: What to do when limits are exceeded.
                 - "end": Jump to the end of the agent execution and
                   inject an artificial AI message indicating that the limit was exceeded.
-                - "error": Raise a ModelCallLimitExceededError
-                Defaults to "end".
+                - "error": Raise a `ModelCallLimitExceededError`
 
         Raises:
-            ValueError: If both limits are None or if exit_behavior is invalid.
+            ValueError: If both limits are `None` or if `exit_behavior` is invalid.
         """
         super().__init__()
 
@@ -135,7 +153,7 @@ class ModelCallLimitMiddleware(AgentMiddleware):
         self.exit_behavior = exit_behavior
 
     @hook_config(can_jump_to=["end"])
-    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+    def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
         """Check model call limits before making a model call.
 
         Args:
@@ -175,3 +193,18 @@ class ModelCallLimitMiddleware(AgentMiddleware):
             return {"jump_to": "end", "messages": [limit_ai_message]}
 
         return None
+
+    def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+        """Increment model call counts after a model call.
+
+        Args:
+            state: The current agent state.
+            runtime: The langgraph runtime.
+
+        Returns:
+            State updates with incremented call counts.
+        """
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+        }
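With `state_schema`, the retyped `before_model`, and the new `after_model` hook in place, call counting now lives in agent state instead of being recomputed from the message history. A hedged usage sketch; the package-level import path and model string are assumptions:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelCallLimitMiddleware  # assumed re-export

agent = create_agent(
    "openai:gpt-4o",  # illustrative model id
    middleware=[ModelCallLimitMiddleware(thread_limit=10, run_limit=3, exit_behavior="end")],
)
# after_model increments both counters once per model call; before_model jumps
# to "end" with an explanatory AIMessage once either limit is reached.
```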
@@ -13,7 +13,7 @@ from langchain.agents.middleware.types import (
 from langchain.chat_models import init_chat_model
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Awaitable, Callable
 
     from langchain_core.language_models.chat_models import BaseChatModel
 
@@ -75,8 +75,6 @@ class ModelFallbackMiddleware(AgentMiddleware):
 
         Args:
             request: Initial model request.
-            state: Current agent state.
-            runtime: LangGraph runtime.
             handler: Callback to execute the model.
 
         Returns:
@@ -102,3 +100,38 @@ class ModelFallbackMiddleware(AgentMiddleware):
                 continue
 
         raise last_exception
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Try fallback models in sequence on errors (async version).
+
+        Args:
+            request: Initial model request.
+            handler: Async callback to execute the model.
+
+        Returns:
+            AIMessage from successful model call.
+
+        Raises:
+            Exception: If all models fail, re-raises last exception.
+        """
+        # Try primary model first
+        last_exception: Exception
+        try:
+            return await handler(request)
+        except Exception as e:  # noqa: BLE001
+            last_exception = e
+
+        # Try fallback models
+        for fallback_model in self.models:
+            request.model = fallback_model
+            try:
+                return await handler(request)
+            except Exception as e:  # noqa: BLE001
+                last_exception = e
+                continue
+
+        raise last_exception
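The async hook mirrors the sync `wrap_model_call` line for line. A construction sketch, inferring the varargs constructor from the `self.models` loop above; the import path and model ids are assumptions:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelFallbackMiddleware  # assumed re-export

fallbacks = ModelFallbackMiddleware(
    "openai:gpt-4o-mini",           # tried when the primary model raises
    "anthropic:claude-sonnet-4-5",  # tried next; the last error re-raised if all fail
)
agent = create_agent("openai:gpt-4o", middleware=[fallbacks])
```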
@@ -421,7 +421,7 @@ class PIIMiddleware(AgentMiddleware):
         - `credit_card`: Credit card numbers (validated with Luhn algorithm)
         - `ip`: IP addresses (validated with stdlib)
         - `mac_address`: MAC addresses
-        - `url`: URLs (both http/https and bare URLs)
+        - `url`: URLs (both `http`/`https` and bare URLs)
 
     Strategies:
         - `block`: Raise an exception when PII is detected
@@ -431,12 +431,12 @@ class PIIMiddleware(AgentMiddleware):
 
     Strategy Selection Guide:
 
-    | Strategy | Preserves Identity? | Best For                                 |
-    | -------- | ------------------- | ---------------------------------------- |
-    | `block`  | N/A                 | Avoid PII completely                     |
-    | `redact` | No                  | General compliance, log sanitization     |
-    | `mask`   | No                  | Human readability, customer service UIs  |
-    | `hash`   | Yes (pseudonymous)  | Analytics, debugging                     |
+    | Strategy | Preserves Identity? | Best For                                 |
+    | -------- | ------------------- | ---------------------------------------- |
+    | `block`  | N/A                 | Avoid PII completely                     |
+    | `redact` | No                  | General compliance, log sanitization     |
+    | `mask`   | No                  | Human readability, customer service UIs  |
+    | `hash`   | Yes (pseudonymous)  | Analytics, debugging                     |
 
     Example:
         ```python
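To make the strategy table concrete, a hedged sketch pairing detection types with strategies; the `PIIMiddleware(pii_type, strategy=...)` constructor shape is an assumption based on the vocabulary documented above:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import PIIMiddleware  # assumed re-export

agent = create_agent(
    "openai:gpt-4o",
    middleware=[
        PIIMiddleware("email", strategy="redact"),      # compliance, log sanitization
        PIIMiddleware("credit_card", strategy="mask"),  # human-readable UIs
        PIIMiddleware("ip", strategy="hash"),           # pseudonymous analytics
    ],
)
```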
@@ -6,7 +6,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Annotated, Literal
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Awaitable, Callable
 
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import tool
@@ -126,7 +126,7 @@ def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCall
     )
 
 
-class PlanningMiddleware(AgentMiddleware):
+class TodoListMiddleware(AgentMiddleware):
     """Middleware that provides todo list management capabilities to agents.
 
     This middleware adds a `write_todos` tool that allows agents to create and manage
@@ -139,10 +139,10 @@ class PlanningMiddleware(AgentMiddleware):
 
     Example:
         ```python
-        from langchain.agents.middleware.planning import PlanningMiddleware
+        from langchain.agents.middleware.todo import TodoListMiddleware
         from langchain.agents import create_agent
 
-        agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])
+        agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])
 
         # Agent now has access to write_todos tool and todo state tracking
         result = await agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
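Because this is a rename plus a module move (`planning` → `todo`), downstream imports need updating; a migration sketch, assuming no back-compat alias is exported:

```python
# Before (1.0.0a14):
# from langchain.agents.middleware.planning import PlanningMiddleware
# After (1.0.0rc1):
from langchain.agents.middleware.todo import TodoListMiddleware
```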
@@ -165,7 +165,7 @@ class PlanningMiddleware(AgentMiddleware):
         system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
         tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
     ) -> None:
-        """Initialize the PlanningMiddleware with optional custom prompts.
+        """Initialize the TodoListMiddleware with optional custom prompts.
 
         Args:
             system_prompt: Custom system prompt to guide the agent on using the todo tool.
@@ -204,3 +204,16 @@ class PlanningMiddleware(AgentMiddleware):
             else self.system_prompt
         )
         return handler(request)
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Update the system prompt to include the todo system prompt (async version)."""
+        request.system_prompt = (
+            request.system_prompt + "\n\n" + self.system_prompt
+            if request.system_prompt
+            else self.system_prompt
+        )
+        return await handler(request)
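The prompt-update rule is identical in the sync and async hooks: append the todo prompt after a blank line when a system prompt already exists, otherwise use it alone. A pure-Python restatement of the conditional:

```python
def combine(existing: str | None, todo_prompt: str) -> str:
    # Mirrors the conditional expression in (a)wrap_model_call above.
    return existing + "\n\n" + todo_prompt if existing else todo_prompt


assert combine(None, "TODO RULES") == "TODO RULES"
assert combine("Base prompt.", "TODO RULES") == "Base prompt.\n\nTODO RULES"
```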
@@ -2,16 +2,37 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Annotated, Any, Literal
 
 from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
+from langgraph.channels.untracked_value import UntrackedValue
+from typing_extensions import NotRequired
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    AgentState,
+    PrivateStateAttr,
+    hook_config,
+)
 
 if TYPE_CHECKING:
     from langgraph.runtime import Runtime
 
 
+class ToolCallLimitState(AgentState):
+    """State schema for ToolCallLimitMiddleware.
+
+    Extends AgentState with tool call tracking fields.
+
+    The count fields are dictionaries mapping tool names to execution counts.
+    This allows multiple middleware instances to track different tools independently.
+    The special key "__all__" is used for tracking all tool calls globally.
+    """
+
+    thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
+    run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
+
+
 def _count_tool_calls_in_messages(messages: list[AnyMessage], tool_name: str | None = None) -> int:
     """Count tool calls in a list of messages.
 
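Since counts are keyed by tool name, with `"__all__"` as the global bucket, several limiter instances can coexist without clobbering each other's state. A hedged sketch; the re-export path is an assumption and the `tool_name` keyword is inferred from the `self.tool_name` references below:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ToolCallLimitMiddleware  # assumed re-export

agent = create_agent(
    "openai:gpt-4o",
    middleware=[
        ToolCallLimitMiddleware(thread_limit=20),                  # counts under "__all__"
        ToolCallLimitMiddleware(tool_name="search", run_limit=3),  # counts under "search"
    ],
)
```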
@@ -124,18 +145,18 @@ class ToolCallLimitExceededError(Exception):
         super().__init__(msg)
 
 
-class ToolCallLimitMiddleware(AgentMiddleware):
+class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
     """Middleware that tracks tool call counts and enforces limits.
 
     This middleware monitors the number of tool calls made during agent execution
     and can terminate the agent when specified limits are reached. It supports
     both thread-level and run-level call counting with configurable exit behaviors.
 
-    Thread-level: The middleware counts all tool calls in the entire message history
-    and persists this count across multiple runs (invocations) of the agent.
+    Thread-level: The middleware tracks the total number of tool calls and persists
+    call count across multiple runs (invocations) of the agent.
 
-    Run-level: The middleware counts tool calls made after the last HumanMessage,
-    representing the current run (invocation) of the agent.
+    Run-level: The middleware tracks the number of tool calls made during a single
+    run (invocation) of the agent.
 
     Example:
         ```python
@@ -157,6 +178,8 @@ class ToolCallLimitMiddleware(AgentMiddleware):
     ```
     """
 
+    state_schema = ToolCallLimitState
+
     def __init__(
         self,
         *,
@@ -181,7 +204,7 @@ class ToolCallLimitMiddleware(AgentMiddleware):
                 Defaults to "end".
 
         Raises:
-            ValueError: If both limits are None or if exit_behavior is invalid.
+            ValueError: If both limits are `None` or if `exit_behavior` is invalid.
         """
         super().__init__()
 
@@ -211,11 +234,11 @@ class ToolCallLimitMiddleware(AgentMiddleware):
         return base_name
 
     @hook_config(can_jump_to=["end"])
-    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+    def before_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
         """Check tool call limits before making a model call.
 
         Args:
-            state: The current agent state containing messages.
+            state: The current agent state containing tool call counts.
             runtime: The langgraph runtime.
 
         Returns:
@@ -226,14 +249,14 @@ class ToolCallLimitMiddleware(AgentMiddleware):
             ToolCallLimitExceededError: If limits are exceeded and exit_behavior
                 is "error".
         """
-        messages = state.get("messages", [])
+        # Get the count key for this middleware instance
+        count_key = self.tool_name if self.tool_name else "__all__"
 
-        # Count tool calls in entire thread
-        thread_count = _count_tool_calls_in_messages(messages, self.tool_name)
+        thread_counts = state.get("thread_tool_call_count", {})
+        run_counts = state.get("run_tool_call_count", {})
 
-        # Count tool calls in current run (after last HumanMessage)
-        run_messages = _get_run_messages(messages)
-        run_count = _count_tool_calls_in_messages(run_messages, self.tool_name)
+        thread_count = thread_counts.get(count_key, 0)
+        run_count = run_counts.get(count_key, 0)
 
         # Check if any limits are exceeded
         thread_limit_exceeded = self.thread_limit is not None and thread_count >= self.thread_limit
@@ -258,3 +281,53 @@ class ToolCallLimitMiddleware(AgentMiddleware):
             return {"jump_to": "end", "messages": [limit_ai_message]}
 
         return None
+
+    def after_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+        """Increment tool call counts after a model call (when tool calls are made).
+
+        Args:
+            state: The current agent state.
+            runtime: The langgraph runtime.
+
+        Returns:
+            State updates with incremented tool call counts if tool calls were made.
+        """
+        # Get the last AIMessage to check for tool calls
+        messages = state.get("messages", [])
+        if not messages:
+            return None
+
+        # Find the last AIMessage
+        last_ai_message = None
+        for message in reversed(messages):
+            if isinstance(message, AIMessage):
+                last_ai_message = message
+                break
+
+        if not last_ai_message or not last_ai_message.tool_calls:
+            return None
+
+        # Count relevant tool calls (filter by tool_name if specified)
+        tool_call_count = 0
+        for tool_call in last_ai_message.tool_calls:
+            if self.tool_name is None or tool_call["name"] == self.tool_name:
+                tool_call_count += 1
+
+        if tool_call_count == 0:
+            return None
+
+        # Get the count key for this middleware instance
+        count_key = self.tool_name if self.tool_name else "__all__"
+
+        # Get current counts
+        thread_counts = state.get("thread_tool_call_count", {}).copy()
+        run_counts = state.get("run_tool_call_count", {}).copy()
+
+        # Increment counts for this key
+        thread_counts[count_key] = thread_counts.get(count_key, 0) + tool_call_count
+        run_counts[count_key] = run_counts.get(count_key, 0) + tool_call_count
+
+        return {
+            "thread_tool_call_count": thread_counts,
+            "run_tool_call_count": run_counts,
+        }
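A worked example of the `after_model` bookkeeping, reduced to plain dict manipulation:

```python
# Suppose this instance was built with tool_name="search" and the last
# AIMessage carried two "search" tool calls.
state = {"thread_tool_call_count": {"search": 2}, "run_tool_call_count": {}}
count_key, new_calls = "search", 2

thread_counts = state.get("thread_tool_call_count", {}).copy()
run_counts = state.get("run_tool_call_count", {}).copy()
thread_counts[count_key] = thread_counts.get(count_key, 0) + new_calls  # {"search": 4}
run_counts[count_key] = run_counts.get(count_key, 0) + new_calls        # {"search": 2}
```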
@@ -123,7 +123,7 @@ class LLMToolEmulator(AgentMiddleware):
 
         # Extract tool information for emulation
         tool_args = request.tool_call["args"]
-        tool_description = request.tool.description
+        tool_description = request.tool.description if request.tool else "No description available"
 
         # Build prompt for emulator LLM
         prompt = (
@@ -175,7 +175,7 @@ class LLMToolEmulator(AgentMiddleware):
 
         # Extract tool information for emulation
         tool_args = request.tool_call["args"]
-        tool_description = request.tool.description
+        tool_description = request.tool.description if request.tool else "No description available"
 
         # Build prompt for emulator LLM
         prompt = (
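Both emulator hunks apply the same None-guard: `request.tool` may be unset, so the description falls back to a placeholder instead of raising `AttributeError`. The guard in isolation, as a hypothetical helper for illustration:

```python
def _describe(tool) -> str:
    # Same fallback string as the hunks above; tool may be None.
    return tool.description if tool else "No description available"
```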