langchain 1.0.0a13__py3-none-any.whl → 1.0.0a15__py3-none-any.whl

This diff shows the content changes between publicly available package versions released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release.


This version of langchain might be problematic; consult the registry's advisory page for more details.

@@ -2,16 +2,37 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Any, Literal
5
+ from typing import TYPE_CHECKING, Annotated, Any, Literal
6
6
 
7
7
  from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
8
+ from langgraph.channels.untracked_value import UntrackedValue
9
+ from typing_extensions import NotRequired
8
10
 
9
- from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
11
+ from langchain.agents.middleware.types import (
12
+ AgentMiddleware,
13
+ AgentState,
14
+ PrivateStateAttr,
15
+ hook_config,
16
+ )
10
17
 
11
18
  if TYPE_CHECKING:
12
19
  from langgraph.runtime import Runtime
13
20
 
14
21
 
22
+ class ToolCallLimitState(AgentState):
23
+ """State schema for ToolCallLimitMiddleware.
24
+
25
+ Extends AgentState with tool call tracking fields.
26
+
27
+ The count fields are dictionaries mapping tool names to execution counts.
28
+ This allows multiple middleware instances to track different tools independently.
29
+ The special key "__all__" is used for tracking all tool calls globally.
30
+ """
31
+
32
+ thread_tool_call_count: NotRequired[Annotated[dict[str, int], PrivateStateAttr]]
33
+ run_tool_call_count: NotRequired[Annotated[dict[str, int], UntrackedValue, PrivateStateAttr]]
34
+
35
+
15
36
  def _count_tool_calls_in_messages(messages: list[AnyMessage], tool_name: str | None = None) -> int:
16
37
  """Count tool calls in a list of messages.
17
38
 
@@ -124,18 +145,18 @@ class ToolCallLimitExceededError(Exception):
124
145
  super().__init__(msg)
125
146
 
126
147
 
127
- class ToolCallLimitMiddleware(AgentMiddleware):
148
+ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
128
149
  """Middleware that tracks tool call counts and enforces limits.
129
150
 
130
151
  This middleware monitors the number of tool calls made during agent execution
131
152
  and can terminate the agent when specified limits are reached. It supports
132
153
  both thread-level and run-level call counting with configurable exit behaviors.
133
154
 
134
- Thread-level: The middleware counts all tool calls in the entire message history
135
- and persists this count across multiple runs (invocations) of the agent.
155
+ Thread-level: The middleware tracks the total number of tool calls and persists
156
+ call count across multiple runs (invocations) of the agent.
136
157
 
137
- Run-level: The middleware counts tool calls made after the last HumanMessage,
138
- representing the current run (invocation) of the agent.
158
+ Run-level: The middleware tracks the number of tool calls made during a single
159
+ run (invocation) of the agent.
139
160
 
140
161
  Example:
141
162
  ```python
@@ -157,6 +178,8 @@ class ToolCallLimitMiddleware(AgentMiddleware):
157
178
  ```
158
179
  """
159
180
 
181
+ state_schema = ToolCallLimitState
182
+
160
183
  def __init__(
161
184
  self,
162
185
  *,
@@ -211,11 +234,11 @@ class ToolCallLimitMiddleware(AgentMiddleware):
211
234
  return base_name
212
235
 
213
236
  @hook_config(can_jump_to=["end"])
214
- def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
237
+ def before_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
215
238
  """Check tool call limits before making a model call.
216
239
 
217
240
  Args:
218
- state: The current agent state containing messages.
241
+ state: The current agent state containing tool call counts.
219
242
  runtime: The langgraph runtime.
220
243
 
221
244
  Returns:
@@ -226,14 +249,14 @@ class ToolCallLimitMiddleware(AgentMiddleware):
226
249
  ToolCallLimitExceededError: If limits are exceeded and exit_behavior
227
250
  is "error".
228
251
  """
229
- messages = state.get("messages", [])
252
+ # Get the count key for this middleware instance
253
+ count_key = self.tool_name if self.tool_name else "__all__"
230
254
 
231
- # Count tool calls in entire thread
232
- thread_count = _count_tool_calls_in_messages(messages, self.tool_name)
255
+ thread_counts = state.get("thread_tool_call_count", {})
256
+ run_counts = state.get("run_tool_call_count", {})
233
257
 
234
- # Count tool calls in current run (after last HumanMessage)
235
- run_messages = _get_run_messages(messages)
236
- run_count = _count_tool_calls_in_messages(run_messages, self.tool_name)
258
+ thread_count = thread_counts.get(count_key, 0)
259
+ run_count = run_counts.get(count_key, 0)
237
260
 
238
261
  # Check if any limits are exceeded
239
262
  thread_limit_exceeded = self.thread_limit is not None and thread_count >= self.thread_limit
@@ -258,3 +281,53 @@ class ToolCallLimitMiddleware(AgentMiddleware):
258
281
  return {"jump_to": "end", "messages": [limit_ai_message]}
259
282
 
260
283
  return None
284
+
285
+ def after_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
286
+ """Increment tool call counts after a model call (when tool calls are made).
287
+
288
+ Args:
289
+ state: The current agent state.
290
+ runtime: The langgraph runtime.
291
+
292
+ Returns:
293
+ State updates with incremented tool call counts if tool calls were made.
294
+ """
295
+ # Get the last AIMessage to check for tool calls
296
+ messages = state.get("messages", [])
297
+ if not messages:
298
+ return None
299
+
300
+ # Find the last AIMessage
301
+ last_ai_message = None
302
+ for message in reversed(messages):
303
+ if isinstance(message, AIMessage):
304
+ last_ai_message = message
305
+ break
306
+
307
+ if not last_ai_message or not last_ai_message.tool_calls:
308
+ return None
309
+
310
+ # Count relevant tool calls (filter by tool_name if specified)
311
+ tool_call_count = 0
312
+ for tool_call in last_ai_message.tool_calls:
313
+ if self.tool_name is None or tool_call["name"] == self.tool_name:
314
+ tool_call_count += 1
315
+
316
+ if tool_call_count == 0:
317
+ return None
318
+
319
+ # Get the count key for this middleware instance
320
+ count_key = self.tool_name if self.tool_name else "__all__"
321
+
322
+ # Get current counts
323
+ thread_counts = state.get("thread_tool_call_count", {}).copy()
324
+ run_counts = state.get("run_tool_call_count", {}).copy()
325
+
326
+ # Increment counts for this key
327
+ thread_counts[count_key] = thread_counts.get(count_key, 0) + tool_call_count
328
+ run_counts[count_key] = run_counts.get(count_key, 0) + tool_call_count
329
+
330
+ return {
331
+ "thread_tool_call_count": thread_counts,
332
+ "run_tool_call_count": run_counts,
333
+ }
@@ -3,7 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from collections.abc import Awaitable, Callable
6
- from dataclasses import dataclass, field
6
+ from dataclasses import dataclass, field, replace
7
7
  from inspect import iscoroutinefunction
8
8
  from typing import (
9
9
  TYPE_CHECKING,
@@ -21,16 +21,15 @@ if TYPE_CHECKING:
21
21
 
22
22
  from langchain.tools.tool_node import ToolCallRequest
23
23
 
24
- # needed as top level import for pydantic schema generation on AgentState
24
+ # Needed as top level import for Pydantic schema generation on AgentState
25
25
  from typing import TypeAlias
26
26
 
27
27
  from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage # noqa: TC002
28
28
  from langgraph.channels.ephemeral_value import EphemeralValue
29
- from langgraph.channels.untracked_value import UntrackedValue
30
29
  from langgraph.graph.message import add_messages
31
30
  from langgraph.types import Command # noqa: TC002
32
31
  from langgraph.typing import ContextT
33
- from typing_extensions import NotRequired, Required, TypedDict, TypeVar
32
+ from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack
34
33
 
35
34
  if TYPE_CHECKING:
36
35
  from langchain_core.language_models.chat_models import BaseChatModel
@@ -62,6 +61,18 @@ JumpTo = Literal["tools", "model", "end"]
62
61
  ResponseT = TypeVar("ResponseT")
63
62
 
64
63
 
64
+ class _ModelRequestOverrides(TypedDict, total=False):
65
+ """Possible overrides for ModelRequest.override() method."""
66
+
67
+ model: BaseChatModel
68
+ system_prompt: str | None
69
+ messages: list[AnyMessage]
70
+ tool_choice: Any | None
71
+ tools: list[BaseTool | dict]
72
+ response_format: ResponseFormat | None
73
+ model_settings: dict[str, Any]
74
+
75
+
65
76
  @dataclass
66
77
  class ModelRequest:
67
78
  """Model request information for the agent."""
@@ -76,6 +87,36 @@ class ModelRequest:
76
87
  runtime: Runtime[ContextT] # type: ignore[valid-type]
77
88
  model_settings: dict[str, Any] = field(default_factory=dict)
78
89
 
90
+ def override(self, **overrides: Unpack[_ModelRequestOverrides]) -> ModelRequest:
91
+ """Replace the request with a new request with the given overrides.
92
+
93
+ Returns a new `ModelRequest` instance with the specified attributes replaced.
94
+ This follows an immutable pattern, leaving the original request unchanged.
95
+
96
+ Args:
97
+ **overrides: Keyword arguments for attributes to override. Supported keys:
98
+ - model: BaseChatModel instance
99
+ - system_prompt: Optional system prompt string
100
+ - messages: List of messages
101
+ - tool_choice: Tool choice configuration
102
+ - tools: List of available tools
103
+ - response_format: Response format specification
104
+ - model_settings: Additional model settings
105
+
106
+ Returns:
107
+ New ModelRequest instance with specified overrides applied.
108
+
109
+ Examples:
110
+ ```python
111
+ # Create a new request with different model
112
+ new_request = request.override(model=different_model)
113
+
114
+ # Override multiple attributes
115
+ new_request = request.override(system_prompt="New instructions", tool_choice="auto")
116
+ ```
117
+ """
118
+ return replace(self, **overrides)
119
+
79
120
 
80
121
  @dataclass
81
122
  class ModelResponse:
@@ -129,8 +170,6 @@ class AgentState(TypedDict, Generic[ResponseT]):
129
170
  messages: Required[Annotated[list[AnyMessage], add_messages]]
130
171
  jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
131
172
  structured_response: NotRequired[Annotated[ResponseT, OmitFromInput]]
132
- thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
133
- run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
134
173
 
135
174
 
136
175
  class PublicAgentState(TypedDict, Generic[ResponseT]):
@@ -263,18 +302,35 @@ class AgentMiddleware(Generic[StateT, ContextT]):
263
302
  return AIMessage(content="Simplified response")
264
303
  ```
265
304
  """
266
- raise NotImplementedError
305
+ msg = (
306
+ "Synchronous implementation of wrap_model_call is not available. "
307
+ "You are likely encountering this error because you defined only the async version "
308
+ "(awrap_model_call) and invoked your agent in a synchronous context "
309
+ "(e.g., using `stream()` or `invoke()`). "
310
+ "To resolve this, either: "
311
+ "(1) subclass AgentMiddleware and implement the synchronous wrap_model_call method, "
312
+ "(2) use the @wrap_model_call decorator on a standalone sync function, or "
313
+ "(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
314
+ )
315
+ raise NotImplementedError(msg)
267
316
 
268
317
  async def awrap_model_call(
269
318
  self,
270
319
  request: ModelRequest,
271
320
  handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
272
321
  ) -> ModelCallResult:
273
- """Async version of wrap_model_call.
322
+ """Intercept and control async model execution via handler callback.
323
+
324
+ The handler callback executes the model request and returns a ModelResponse.
325
+ Middleware can call the handler multiple times for retry logic, skip calling
326
+ it to short-circuit, or modify the request/response. Multiple middleware
327
+ compose with first in list as outermost layer.
274
328
 
275
329
  Args:
276
330
  request: Model request to execute (includes state and runtime).
277
- handler: Async callback that executes the model request.
331
+ handler: Async callback that executes the model request and returns ModelResponse.
332
+ Call this to execute the model. Can be called multiple times
333
+ for retry logic. Can skip calling it to short-circuit.
278
334
 
279
335
  Returns:
280
336
  ModelCallResult
@@ -291,7 +347,17 @@ class AgentMiddleware(Generic[StateT, ContextT]):
291
347
  raise
292
348
  ```
293
349
  """
294
- raise NotImplementedError
350
+ msg = (
351
+ "Asynchronous implementation of awrap_model_call is not available. "
352
+ "You are likely encountering this error because you defined only the sync version "
353
+ "(wrap_model_call) and invoked your agent in an asynchronous context "
354
+ "(e.g., using `astream()` or `ainvoke()`). "
355
+ "To resolve this, either: "
356
+ "(1) subclass AgentMiddleware and implement the asynchronous awrap_model_call method, "
357
+ "(2) use the @wrap_model_call decorator on a standalone async function, or "
358
+ "(3) invoke your agent synchronously using `stream()` or `invoke()`."
359
+ )
360
+ raise NotImplementedError(msg)
295
361
 
296
362
  def after_agent(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
297
363
  """Logic to run after the agent execution completes."""
@@ -353,7 +419,77 @@ class AgentMiddleware(Generic[StateT, ContextT]):
353
419
  continue
354
420
  return result
355
421
  """
356
- raise NotImplementedError
422
+ msg = (
423
+ "Synchronous implementation of wrap_tool_call is not available. "
424
+ "You are likely encountering this error because you defined only the async version "
425
+ "(awrap_tool_call) and invoked your agent in a synchronous context "
426
+ "(e.g., using `stream()` or `invoke()`). "
427
+ "To resolve this, either: "
428
+ "(1) subclass AgentMiddleware and implement the synchronous wrap_tool_call method, "
429
+ "(2) use the @wrap_tool_call decorator on a standalone sync function, or "
430
+ "(3) invoke your agent asynchronously using `astream()` or `ainvoke()`."
431
+ )
432
+ raise NotImplementedError(msg)
433
+
434
+ async def awrap_tool_call(
435
+ self,
436
+ request: ToolCallRequest,
437
+ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
438
+ ) -> ToolMessage | Command:
439
+ """Intercept and control async tool execution via handler callback.
440
+
441
+ The handler callback executes the tool call and returns a ToolMessage or Command.
442
+ Middleware can call the handler multiple times for retry logic, skip calling
443
+ it to short-circuit, or modify the request/response. Multiple middleware
444
+ compose with first in list as outermost layer.
445
+
446
+ Args:
447
+ request: Tool call request with call dict, BaseTool, state, and runtime.
448
+ Access state via request.state and runtime via request.runtime.
449
+ handler: Async callable to execute the tool and returns ToolMessage or Command.
450
+ Call this to execute the tool. Can be called multiple times
451
+ for retry logic. Can skip calling it to short-circuit.
452
+
453
+ Returns:
454
+ ToolMessage or Command (the final result).
455
+
456
+ The handler callable can be invoked multiple times for retry logic.
457
+ Each call to handler is independent and stateless.
458
+
459
+ Examples:
460
+ Async retry on error:
461
+ ```python
462
+ async def awrap_tool_call(self, request, handler):
463
+ for attempt in range(3):
464
+ try:
465
+ result = await handler(request)
466
+ if is_valid(result):
467
+ return result
468
+ except Exception:
469
+ if attempt == 2:
470
+ raise
471
+ return result
472
+ ```
473
+
474
+
475
+ async def awrap_tool_call(self, request, handler):
476
+ if cached := await get_cache_async(request):
477
+ return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
478
+ result = await handler(request)
479
+ await save_cache_async(request, result)
480
+ return result
481
+ """
482
+ msg = (
483
+ "Asynchronous implementation of awrap_tool_call is not available. "
484
+ "You are likely encountering this error because you defined only the sync version "
485
+ "(wrap_tool_call) and invoked your agent in an asynchronous context "
486
+ "(e.g., using `astream()` or `ainvoke()`). "
487
+ "To resolve this, either: "
488
+ "(1) subclass AgentMiddleware and implement the asynchronous awrap_tool_call method, "
489
+ "(2) use the @wrap_tool_call decorator on a standalone async function, or "
490
+ "(3) invoke your agent synchronously using `stream()` or `invoke()`."
491
+ )
492
+ raise NotImplementedError(msg)
357
493
 
358
494
 
359
495
  class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
@@ -1104,6 +1240,16 @@ def dynamic_prompt(
1104
1240
  request.system_prompt = prompt
1105
1241
  return handler(request)
1106
1242
 
1243
+ async def async_wrapped_from_sync(
1244
+ self: AgentMiddleware[StateT, ContextT], # noqa: ARG001
1245
+ request: ModelRequest,
1246
+ handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
1247
+ ) -> ModelCallResult:
1248
+ # Delegate to sync function
1249
+ prompt = cast("str", func(request))
1250
+ request.system_prompt = prompt
1251
+ return await handler(request)
1252
+
1107
1253
  middleware_name = cast("str", getattr(func, "__name__", "DynamicPromptMiddleware"))
1108
1254
 
1109
1255
  return type(
@@ -1113,6 +1259,7 @@ def dynamic_prompt(
1113
1259
  "state_schema": AgentState,
1114
1260
  "tools": [],
1115
1261
  "wrap_model_call": wrapped,
1262
+ "awrap_model_call": async_wrapped_from_sync,
1116
1263
  },
1117
1264
  )()
1118
1265
 
@@ -1309,6 +1456,7 @@ def wrap_tool_call(
1309
1456
  Args:
1310
1457
  func: Function accepting (request, handler) that calls
1311
1458
  handler(request) to execute the tool and returns final ToolMessage or Command.
1459
+ Can be sync or async.
1312
1460
  tools: Additional tools to register with this middleware.
1313
1461
  name: Middleware class name. Defaults to function name.
1314
1462
 
@@ -1316,13 +1464,6 @@ def wrap_tool_call(
1316
1464
  AgentMiddleware instance if func provided, otherwise a decorator.
1317
1465
 
1318
1466
  Examples:
1319
- Basic passthrough:
1320
- ```python
1321
- @wrap_tool_call
1322
- def passthrough(request, handler):
1323
- return handler(request)
1324
- ```
1325
-
1326
1467
  Retry logic:
1327
1468
  ```python
1328
1469
  @wrap_tool_call
@@ -1336,6 +1477,18 @@ def wrap_tool_call(
1336
1477
  raise
1337
1478
  ```
1338
1479
 
1480
+ Async retry logic:
1481
+ ```python
1482
+ @wrap_tool_call
1483
+ async def async_retry(request, handler):
1484
+ for attempt in range(3):
1485
+ try:
1486
+ return await handler(request)
1487
+ except Exception:
1488
+ if attempt == 2:
1489
+ raise
1490
+ ```
1491
+
1339
1492
  Modify request:
1340
1493
  ```python
1341
1494
  @wrap_tool_call
@@ -1359,6 +1512,31 @@ def wrap_tool_call(
1359
1512
  def decorator(
1360
1513
  func: _CallableReturningToolResponse,
1361
1514
  ) -> AgentMiddleware:
1515
+ is_async = iscoroutinefunction(func)
1516
+
1517
+ if is_async:
1518
+
1519
+ async def async_wrapped(
1520
+ self: AgentMiddleware, # noqa: ARG001
1521
+ request: ToolCallRequest,
1522
+ handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
1523
+ ) -> ToolMessage | Command:
1524
+ return await func(request, handler) # type: ignore[arg-type,misc]
1525
+
1526
+ middleware_name = name or cast(
1527
+ "str", getattr(func, "__name__", "WrapToolCallMiddleware")
1528
+ )
1529
+
1530
+ return type(
1531
+ middleware_name,
1532
+ (AgentMiddleware,),
1533
+ {
1534
+ "state_schema": AgentState,
1535
+ "tools": tools or [],
1536
+ "awrap_tool_call": async_wrapped,
1537
+ },
1538
+ )()
1539
+
1362
1540
  def wrapped(
1363
1541
  self: AgentMiddleware, # noqa: ARG001
1364
1542
  request: ToolCallRequest,
@@ -3,10 +3,8 @@
3
3
  from langchain_core.embeddings import Embeddings
4
4
 
5
5
  from langchain.embeddings.base import init_embeddings
6
- from langchain.embeddings.cache import CacheBackedEmbeddings
7
6
 
8
7
  __all__ = [
9
- "CacheBackedEmbeddings",
10
8
  "Embeddings",
11
9
  "init_embeddings",
12
10
  ]
@@ -3,29 +3,61 @@
3
3
  from langchain_core.messages import (
4
4
  AIMessage,
5
5
  AIMessageChunk,
6
+ Annotation,
6
7
  AnyMessage,
8
+ AudioContentBlock,
9
+ Citation,
10
+ ContentBlock,
11
+ DataContentBlock,
12
+ FileContentBlock,
7
13
  HumanMessage,
14
+ ImageContentBlock,
8
15
  InvalidToolCall,
9
16
  MessageLikeRepresentation,
17
+ NonStandardAnnotation,
18
+ NonStandardContentBlock,
19
+ PlainTextContentBlock,
20
+ ReasoningContentBlock,
10
21
  RemoveMessage,
22
+ ServerToolCall,
23
+ ServerToolCallChunk,
24
+ ServerToolResult,
11
25
  SystemMessage,
26
+ TextContentBlock,
12
27
  ToolCall,
13
28
  ToolCallChunk,
14
29
  ToolMessage,
30
+ VideoContentBlock,
15
31
  trim_messages,
16
32
  )
17
33
 
18
34
  __all__ = [
19
35
  "AIMessage",
20
36
  "AIMessageChunk",
37
+ "Annotation",
21
38
  "AnyMessage",
39
+ "AudioContentBlock",
40
+ "Citation",
41
+ "ContentBlock",
42
+ "DataContentBlock",
43
+ "FileContentBlock",
22
44
  "HumanMessage",
45
+ "ImageContentBlock",
23
46
  "InvalidToolCall",
24
47
  "MessageLikeRepresentation",
48
+ "NonStandardAnnotation",
49
+ "NonStandardContentBlock",
50
+ "PlainTextContentBlock",
51
+ "ReasoningContentBlock",
25
52
  "RemoveMessage",
53
+ "ServerToolCall",
54
+ "ServerToolCallChunk",
55
+ "ServerToolResult",
26
56
  "SystemMessage",
57
+ "TextContentBlock",
27
58
  "ToolCall",
28
59
  "ToolCallChunk",
29
60
  "ToolMessage",
61
+ "VideoContentBlock",
30
62
  "trim_messages",
31
63
  ]
@@ -8,11 +8,7 @@ from langchain_core.tools import (
8
8
  tool,
9
9
  )
10
10
 
11
- from langchain.tools.tool_node import (
12
- InjectedState,
13
- InjectedStore,
14
- ToolNode,
15
- )
11
+ from langchain.tools.tool_node import InjectedState, InjectedStore
16
12
 
17
13
  __all__ = [
18
14
  "BaseTool",
@@ -21,6 +17,5 @@ __all__ = [
21
17
  "InjectedToolArg",
22
18
  "InjectedToolCallId",
23
19
  "ToolException",
24
- "ToolNode",
25
20
  "tool",
26
21
  ]