langchain 1.0.0a13__py3-none-any.whl → 1.0.0a15__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

langchain/__init__.py CHANGED
@@ -1,3 +1,3 @@
 """Main entrypoint into LangChain."""
 
-__version__ = "1.0.0a13"
+__version__ = "1.0.0a14"
@@ -13,9 +13,6 @@ from typing import (
     get_type_hints,
 )
 
-if TYPE_CHECKING:
-    from collections.abc import Awaitable
-
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
 from langchain_core.tools import BaseTool
@@ -47,11 +44,10 @@ from langchain.agents.structured_output import (
     ToolStrategy,
 )
 from langchain.chat_models import init_chat_model
-from langchain.tools import ToolNode
-from langchain.tools.tool_node import ToolCallWithContext
+from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Sequence
+    from collections.abc import Awaitable, Callable, Sequence
 
     from langchain_core.runnables import Runnable
     from langgraph.cache.base import BaseCache
@@ -449,6 +445,70 @@ def _chain_tool_call_wrappers(
     return result
 
 
+def _chain_async_tool_call_wrappers(
+    wrappers: Sequence[
+        Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ]
+    ],
+) -> (
+    Callable[
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+        Awaitable[ToolMessage | Command],
+    ]
+    | None
+):
+    """Compose async wrappers into middleware stack (first = outermost).
+
+    Args:
+        wrappers: Async wrappers in middleware order.
+
+    Returns:
+        Composed async wrapper, or None if empty.
+    """
+    if not wrappers:
+        return None
+
+    if len(wrappers) == 1:
+        return wrappers[0]
+
+    def compose_two(
+        outer: Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ],
+        inner: Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ],
+    ) -> Callable[
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+        Awaitable[ToolMessage | Command],
+    ]:
+        """Compose two async wrappers where outer wraps inner."""
+
+        async def composed(
+            request: ToolCallRequest,
+            execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+        ) -> ToolMessage | Command:
+            # Create an async callable that invokes inner with the original execute
+            async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+                return await inner(req, execute)
+
+            # Outer can call call_inner multiple times
+            return await outer(request, call_inner)
+
+        return composed
+
+    # Chain all wrappers: first -> second -> ... -> last
+    result = wrappers[-1]
+    for wrapper in reversed(wrappers[:-1]):
+        result = compose_two(wrapper, result)
+
+    return result
+
+
 def create_agent(  # noqa: PLR0915
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
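The new `_chain_async_tool_call_wrappers` mirrors the existing sync helper: the first wrapper in the list becomes the outermost layer, so it sees the request first and the result last, and any wrapper may await its handler more than once (e.g., for retries). A minimal self-contained sketch of the same composition semantics, with plain strings standing in for `ToolCallRequest` and `ToolMessage | Command`:

```python
import asyncio
from collections.abc import Awaitable, Callable

# Plain-string stand-ins keep the ordering demo self-contained; the real
# wrappers use ToolCallRequest and ToolMessage | Command instead.
Wrapper = Callable[[str, Callable[[str], Awaitable[str]]], Awaitable[str]]


async def outer(request: str, handler: Callable[[str], Awaitable[str]]) -> str:
    print("outer: before")
    result = await handler(request)  # may be awaited multiple times for retries
    print("outer: after")
    return result


async def inner(request: str, handler: Callable[[str], Awaitable[str]]) -> str:
    print("inner: before")
    result = await handler(request)
    print("inner: after")
    return result


def compose(outer_w: Wrapper, inner_w: Wrapper) -> Wrapper:
    # Same shape as compose_two above: outer wraps inner, inner wraps execute.
    async def composed(request: str, execute: Callable[[str], Awaitable[str]]) -> str:
        async def call_inner(req: str) -> str:
            return await inner_w(req, execute)

        return await outer_w(request, call_inner)

    return composed


async def execute(request: str) -> str:
    return f"executed {request}"


# First wrapper is outermost: "before" prints top-down, "after" bottom-up.
print(asyncio.run(compose(outer, inner)("search", execute)))
```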
@@ -477,13 +537,12 @@ def create_agent(  # noqa: PLR0915
             (e.g., `"openai:gpt-4"`), a chat model instance (e.g., `ChatOpenAI()`).
         tools: A list of tools, dicts, or callables. If `None` or an empty list,
             the agent will consist of a model node without a tool calling loop.
-        system_prompt: An optional system prompt for the LLM. If provided as a string,
-            it will be converted to a SystemMessage and added to the beginning
-            of the message list.
+        system_prompt: An optional system prompt for the LLM. Prompts are converted to a
+            `SystemMessage` and added to the beginning of the message list.
         middleware: A sequence of middleware instances to apply to the agent.
             Middleware can intercept and modify agent behavior at various stages.
         response_format: An optional configuration for structured responses.
-            Can be a ToolStrategy, ProviderStrategy, or a Pydantic model class.
+            Can be a `ToolStrategy`, `ProviderStrategy`, or a Pydantic model class.
             If provided, the agent will handle structured output during the
             conversation flow. Raw schemas will be wrapped in an appropriate strategy
             based on model capabilities.
@@ -500,14 +559,14 @@ def create_agent(  # noqa: PLR0915
             This is useful if you want to return directly or run additional processing
             on an output.
         debug: A flag indicating whether to enable debug mode.
-        name: An optional name for the CompiledStateGraph.
+        name: An optional name for the `CompiledStateGraph`.
             This name will be automatically used when adding the agent graph to
             another graph as a subgraph node - particularly useful for building
             multi-agent systems.
-        cache: An optional BaseCache instance to enable caching of graph execution.
+        cache: An optional `BaseCache` instance to enable caching of graph execution.
 
     Returns:
-        A compiled StateGraph that can be used for chat interactions.
+        A compiled `StateGraph` that can be used for chat interactions.
 
     The agent node calls the language model with the messages list (after applying
     the system prompt). If the resulting AIMessage contains `tool_calls`, the graph will
@@ -576,9 +635,14 @@ def create_agent(  # noqa: PLR0915
         structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
     middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
 
-    # Collect middleware with wrap_tool_call hooks
+    # Collect middleware with wrap_tool_call or awrap_tool_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
     middleware_w_wrap_tool_call = [
-        m for m in middleware if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+        m
+        for m in middleware
+        if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+        or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
     ]
 
     # Chain all wrap_tool_call handlers into a single composed handler
@@ -587,8 +651,24 @@ def create_agent(  # noqa: PLR0915
         wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
         wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
 
+    # Collect middleware with awrap_tool_call or wrap_tool_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
+    middleware_w_awrap_tool_call = [
+        m
+        for m in middleware
+        if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
+        or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+    ]
+
+    # Chain all awrap_tool_call handlers into a single composed async handler
+    awrap_tool_call_wrapper = None
+    if middleware_w_awrap_tool_call:
+        async_wrappers = [m.awrap_tool_call for m in middleware_w_awrap_tool_call]
+        awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
+
     # Setup tools
-    tool_node: ToolNode | None = None
+    tool_node: _ToolNode | None = None
     # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
     built_in_tools = [t for t in tools if isinstance(t, dict)]
     regular_tools = [t for t in tools if not isinstance(t, dict)]
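Note the deliberate symmetry between the two collections: a middleware that overrides only one of `wrap_tool_call`/`awrap_tool_call` is still included in both chains, so the base-class implementation of the missing hook can raise `NotImplementedError` when the agent is driven down an execution path the middleware does not support. A hedged sketch of a middleware that supports both paths (an illustrative class, not part of the released API):

```python
from langchain.agents.middleware import AgentMiddleware


class LoggingToolMiddleware(AgentMiddleware):
    """Illustrative middleware overriding both the sync and async tool-call hooks."""

    def wrap_tool_call(self, request, handler):
        # request is a ToolCallRequest; handler executes the tool (sync path)
        print(f"tool call (sync): {request}")
        return handler(request)

    async def awrap_tool_call(self, request, handler):
        # Same contract on the async path; handler must be awaited
        print(f"tool call (async): {request}")
        return await handler(request)
```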
@@ -598,7 +678,11 @@ def create_agent(  # noqa: PLR0915
 
         # Only create ToolNode if we have client-side tools
         tool_node = (
-            ToolNode(tools=available_tools, wrap_tool_call=wrap_tool_call_wrapper)
+            _ToolNode(
+                tools=available_tools,
+                wrap_tool_call=wrap_tool_call_wrapper,
+                awrap_tool_call=awrap_tool_call_wrapper,
+            )
             if available_tools
             else None
         )
@@ -640,13 +724,23 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.after_agent is not AgentMiddleware.after_agent
         or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
     ]
+    # Collect middleware with wrap_model_call or awrap_model_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
     middleware_w_wrap_model_call = [
-        m for m in middleware if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
+        m
+        for m in middleware
+        if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
+        or m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
     ]
+    # Collect middleware with awrap_model_call or wrap_model_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
     middleware_w_awrap_model_call = [
         m
         for m in middleware
         if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
+        or m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
     ]
 
     # Compose wrap_model_call handlers into a single middleware stack (sync)
@@ -937,11 +1031,7 @@ def create_agent(  # noqa: PLR0915
         if response.structured_response is not None:
             state_updates["structured_response"] = response.structured_response
 
-        return {
-            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-            **state_updates,
-        }
+        return state_updates
 
     async def _execute_model_async(request: ModelRequest) -> ModelResponse:
         """Execute model asynchronously and return response.
@@ -992,11 +1082,7 @@ def create_agent(  # noqa: PLR0915
         if response.structured_response is not None:
             state_updates["structured_response"] = response.structured_response
 
-        return {
-            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-            **state_updates,
-        }
+        return state_updates
 
     # Use sync or async based on model capabilities
     graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
@@ -1378,7 +1464,7 @@ def _make_model_to_model_edge(
 
 def _make_tools_to_model_edge(
     *,
-    tool_node: ToolNode,
+    tool_node: _ToolNode,
     model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
     end_destination: str,
@@ -11,9 +11,8 @@ from .human_in_the_loop import (
 from .model_call_limit import ModelCallLimitMiddleware
 from .model_fallback import ModelFallbackMiddleware
 from .pii import PIIDetectionError, PIIMiddleware
-from .planning import PlanningMiddleware
-from .prompt_caching import AnthropicPromptCachingMiddleware
 from .summarization import SummarizationMiddleware
+from .todo import TodoListMiddleware
 from .tool_call_limit import ToolCallLimitMiddleware
 from .tool_emulator import LLMToolEmulator
 from .tool_selection import LLMToolSelectorMiddleware
@@ -21,6 +20,7 @@ from .types import (
     AgentMiddleware,
     AgentState,
     ModelRequest,
+    ModelResponse,
    after_agent,
     after_model,
     before_agent,
@@ -28,13 +28,12 @@ from .types import (
     dynamic_prompt,
     hook_config,
     wrap_model_call,
+    wrap_tool_call,
 )
 
 __all__ = [
     "AgentMiddleware",
     "AgentState",
-    # should move to langchain-anthropic if we decide to keep it
-    "AnthropicPromptCachingMiddleware",
     "ClearToolUsesEdit",
     "ContextEditingMiddleware",
     "HumanInTheLoopMiddleware",
@@ -44,10 +43,11 @@ __all__ = [
     "ModelCallLimitMiddleware",
     "ModelFallbackMiddleware",
     "ModelRequest",
+    "ModelResponse",
     "PIIDetectionError",
     "PIIMiddleware",
-    "PlanningMiddleware",
     "SummarizationMiddleware",
+    "TodoListMiddleware",
     "ToolCallLimitMiddleware",
     "after_agent",
     "after_model",
@@ -56,4 +56,5 @@ __all__ = [
     "dynamic_prompt",
     "hook_config",
     "wrap_model_call",
+    "wrap_tool_call",
 ]
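After this change the module's public surface gains `ModelResponse`, `TodoListMiddleware`, and `wrap_tool_call`, and drops `PlanningMiddleware` and `AnthropicPromptCachingMiddleware`. A hedged import sketch; the decorator usage is an assumption based on the exported name and the `(request, handler)` wrapper signature used throughout this diff, not a confirmed API:

```python
from langchain.agents.middleware import (
    ModelResponse,
    TodoListMiddleware,
    wrap_tool_call,
)


@wrap_tool_call  # assumed usage: turns a (request, handler) function into middleware
def passthrough(request, handler):
    # Forward the tool call unchanged; a real wrapper could log, retry, or rewrite.
    return handler(request)
```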
@@ -8,7 +8,7 @@ with any LangChain chat model.
 
 from __future__ import annotations
 
-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Awaitable, Callable, Iterable, Sequence
 from dataclasses import dataclass
 from typing import Literal
 
@@ -239,6 +239,34 @@ class ContextEditingMiddleware(AgentMiddleware):
 
         return handler(request)
 
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Apply context edits before invoking the model via handler (async version)."""
+        if not request.messages:
+            return await handler(request)
+
+        if self.token_count_method == "approximate":  # noqa: S105
+
+            def count_tokens(messages: Sequence[BaseMessage]) -> int:
+                return count_tokens_approximately(messages)
+        else:
+            system_msg = (
+                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
+            )
+
+            def count_tokens(messages: Sequence[BaseMessage]) -> int:
+                return request.model.get_num_tokens_from_messages(
+                    system_msg + list(messages), request.tools
+                )
+
+        for edit in self.edits:
+            edit.apply(request.messages, count_tokens=count_tokens)
+
+        return await handler(request)
+
 
 __all__ = [
     "ClearToolUsesEdit",
@@ -11,23 +11,23 @@ from langchain.agents.middleware.types import AgentMiddleware, AgentState
 
 
 class Action(TypedDict):
-    """Represents an action with a name and arguments."""
+    """Represents an action with a name and args."""
 
     name: str
     """The type or name of action being requested (e.g., "add_numbers")."""
 
-    arguments: dict[str, Any]
-    """Key-value pairs of arguments needed for the action (e.g., {"a": 1, "b": 2})."""
+    args: dict[str, Any]
+    """Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
 
 
 class ActionRequest(TypedDict):
-    """Represents an action request with a name, arguments, and description."""
+    """Represents an action request with a name, args, and description."""
 
     name: str
     """The name of the action being requested."""
 
-    arguments: dict[str, Any]
-    """Key-value pairs of arguments needed for the action (e.g., {"a": 1, "b": 2})."""
+    args: dict[str, Any]
+    """Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
 
     description: NotRequired[str]
     """The description of the action to be reviewed."""
@@ -45,8 +45,8 @@ class ReviewConfig(TypedDict):
     allowed_decisions: list[DecisionType]
     """The decisions that are allowed for this request."""
 
-    arguments_schema: NotRequired[dict[str, Any]]
-    """JSON schema for the arguments associated with the action, if edits are allowed."""
+    args_schema: NotRequired[dict[str, Any]]
+    """JSON schema for the args associated with the action, if edits are allowed."""
 
 
 class HITLRequest(TypedDict):
@@ -150,8 +150,8 @@ class InterruptOnConfig(TypedDict):
     )
     ```
     """
-    arguments_schema: NotRequired[dict[str, Any]]
-    """JSON schema for the arguments associated with the action, if edits are allowed."""
+    args_schema: NotRequired[dict[str, Any]]
+    """JSON schema for the args associated with the action, if edits are allowed."""
 
 
 class HumanInTheLoopMiddleware(AgentMiddleware):
@@ -214,12 +214,12 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
             # Create ActionRequest with description
             action_request = ActionRequest(
                 name=tool_name,
-                arguments=tool_args,
+                args=tool_args,
                 description=description,
             )
 
             # Create ReviewConfig
-            # eventually can get tool information and populate arguments_schema from there
+            # eventually can get tool information and populate args_schema from there
             review_config = ReviewConfig(
                 action_name=tool_name,
                 allowed_decisions=config["allowed_decisions"],
@@ -244,7 +244,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
                 ToolCall(
                     type="tool_call",
                     name=edited_action["name"],
-                    args=edited_action["arguments"],
+                    args=edited_action["args"],
                     id=tool_call["id"],
                 ),
                 None,
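The `arguments` → `args` rename is user-facing: both the interrupt payload the agent emits and the edited action a reviewer sends back now use `args`. A hedged sketch of resuming with an edit decision; the exact resume payload shape is an assumption inferred from the `edited_action["args"]` lookup above, not a documented contract:

```python
from langgraph.types import Command

# Hypothetical edit decision for a pending tool call named "add_numbers".
decision = {
    "type": "edit",  # assumed decision-type label
    "edited_action": {
        "name": "add_numbers",
        "args": {"a": 1, "b": 2},  # was "arguments" before this release
    },
}
agent.invoke(Command(resume=[decision]), config={"configurable": {"thread_id": "1"}})
```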
@@ -2,16 +2,33 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Annotated, Any, Literal
 
 from langchain_core.messages import AIMessage
+from langgraph.channels.untracked_value import UntrackedValue
+from typing_extensions import NotRequired
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    AgentState,
+    PrivateStateAttr,
+    hook_config,
+)
 
 if TYPE_CHECKING:
     from langgraph.runtime import Runtime
 
 
+class ModelCallLimitState(AgentState):
+    """State schema for ModelCallLimitMiddleware.
+
+    Extends AgentState with model call tracking fields.
+    """
+
+    thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
+    run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
+
+
 def _build_limit_exceeded_message(
     thread_count: int,
     run_count: int,
@@ -69,7 +86,7 @@ class ModelCallLimitExceededError(Exception):
         super().__init__(msg)
 
 
-class ModelCallLimitMiddleware(AgentMiddleware):
+class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
     """Middleware that tracks model call counts and enforces limits.
 
     This middleware monitors the number of model calls made during agent execution
@@ -97,6 +114,8 @@ class ModelCallLimitMiddleware(AgentMiddleware):
     ```
     """
 
+    state_schema = ModelCallLimitState
+
     def __init__(
         self,
         *,
@@ -135,7 +154,7 @@
         self.exit_behavior = exit_behavior
 
     @hook_config(can_jump_to=["end"])
-    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+    def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
        """Check model call limits before making a model call.

        Args:
@@ -175,3 +194,18 @@
             return {"jump_to": "end", "messages": [limit_ai_message]}
 
         return None
+
+    def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+        """Increment model call counts after a model call.
+
+        Args:
+            state: The current agent state.
+            runtime: The langgraph runtime.
+
+        Returns:
+            State updates with incremented call counts.
+        """
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+        }
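This new `after_model` hook is the counterpart to the deletions in `create_agent` above: the call counters move out of the shared model node and into the middleware's own `ModelCallLimitState`, marked `PrivateStateAttr` (with `UntrackedValue` for the per-run counter). A hedged usage sketch; the keyword names are assumed from the limit checks and `exit_behavior` shown in this diff:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelCallLimitMiddleware

agent = create_agent(
    "openai:gpt-4o",  # placeholder model identifier
    middleware=[
        # assumed parameter names: cap calls per thread and per run, end gracefully
        ModelCallLimitMiddleware(thread_limit=10, run_limit=5, exit_behavior="end")
    ],
)
```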
@@ -13,7 +13,7 @@ from langchain.agents.middleware.types import (
 from langchain.chat_models import init_chat_model
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Awaitable, Callable
 
     from langchain_core.language_models.chat_models import BaseChatModel
 
@@ -102,3 +102,38 @@ class ModelFallbackMiddleware(AgentMiddleware):
             continue
 
         raise last_exception
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Try fallback models in sequence on errors (async version).
+
+        Args:
+            request: Initial model request.
+            handler: Async callback to execute the model.
+
+        Returns:
+            AIMessage from successful model call.
+
+        Raises:
+            Exception: If all models fail, re-raises last exception.
+        """
+        # Try primary model first
+        last_exception: Exception
+        try:
+            return await handler(request)
+        except Exception as e:  # noqa: BLE001
+            last_exception = e
+
+        # Try fallback models
+        for fallback_model in self.models:
+            request.model = fallback_model
+            try:
+                return await handler(request)
+            except Exception as e:  # noqa: BLE001
+                last_exception = e
+                continue
+
+        raise last_exception
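With the async variant in place, fallback behavior is consistent whether the agent runs via `invoke` or `ainvoke`. A hedged usage sketch (model identifiers are placeholders; the constructor is assumed to take fallback models in trial order, matching the `self.models` loop above):

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelFallbackMiddleware

agent = create_agent(
    "openai:gpt-4o",  # primary model (placeholder)
    middleware=[
        ModelFallbackMiddleware(
            "openai:gpt-4o-mini",  # tried first if the primary raises
            "anthropic:claude-sonnet-4-5",  # then this one
        )
    ],
)
```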
@@ -431,14 +431,12 @@ class PIIMiddleware(AgentMiddleware):
 
     Strategy Selection Guide:
 
-    ======== =================== =======================================
-    Strategy Preserves Identity? Best For
-    ======== =================== =======================================
-    `block`  N/A                 Avoid PII completely
-    `redact` No                  General compliance, log sanitization
-    `mask`   No                  Human readability, customer service UIs
-    `hash`   Yes (pseudonymous)  Analytics, debugging
-    ======== =================== =======================================
+    | Strategy | Preserves Identity? | Best For                                |
+    | -------- | ------------------- | --------------------------------------- |
+    | `block`  | N/A                 | Avoid PII completely                     |
+    | `redact` | No                  | General compliance, log sanitization     |
+    | `mask`   | No                  | Human readability, customer service UIs  |
+    | `hash`   | Yes (pseudonymous)  | Analytics, debugging                     |
 
     Example:
     ```python
@@ -6,7 +6,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Annotated, Literal
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Awaitable, Callable
 
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import tool
@@ -126,7 +126,7 @@ def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCall
     )
 
 
-class PlanningMiddleware(AgentMiddleware):
+class TodoListMiddleware(AgentMiddleware):
     """Middleware that provides todo list management capabilities to agents.
 
     This middleware adds a `write_todos` tool that allows agents to create and manage
@@ -139,10 +139,10 @@
 
     Example:
     ```python
-    from langchain.agents.middleware.planning import PlanningMiddleware
+    from langchain.agents.middleware.todo import TodoListMiddleware
     from langchain.agents import create_agent
 
-    agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])
+    agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])
 
     # Agent now has access to write_todos tool and todo state tracking
     result = await agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
@@ -165,7 +165,7 @@
         system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
         tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
     ) -> None:
-        """Initialize the PlanningMiddleware with optional custom prompts.
+        """Initialize the TodoListMiddleware with optional custom prompts.
 
         Args:
             system_prompt: Custom system prompt to guide the agent on using the todo tool.
@@ -204,3 +204,16 @@
             else self.system_prompt
         )
         return handler(request)
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Update the system prompt to include the todo system prompt (async version)."""
+        request.system_prompt = (
+            request.system_prompt + "\n\n" + self.system_prompt
+            if request.system_prompt
+            else self.system_prompt
+        )
+        return await handler(request)