langchain 1.0.0a12__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
Files changed (49)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +7 -1
  3. langchain/agents/factory.py +722 -226
  4. langchain/agents/middleware/__init__.py +36 -9
  5. langchain/agents/middleware/_execution.py +388 -0
  6. langchain/agents/middleware/_redaction.py +350 -0
  7. langchain/agents/middleware/context_editing.py +46 -17
  8. langchain/agents/middleware/file_search.py +382 -0
  9. langchain/agents/middleware/human_in_the_loop.py +220 -173
  10. langchain/agents/middleware/model_call_limit.py +43 -10
  11. langchain/agents/middleware/model_fallback.py +79 -36
  12. langchain/agents/middleware/pii.py +68 -504
  13. langchain/agents/middleware/shell_tool.py +718 -0
  14. langchain/agents/middleware/summarization.py +2 -2
  15. langchain/agents/middleware/{planning.py → todo.py} +35 -16
  16. langchain/agents/middleware/tool_call_limit.py +308 -114
  17. langchain/agents/middleware/tool_emulator.py +200 -0
  18. langchain/agents/middleware/tool_retry.py +384 -0
  19. langchain/agents/middleware/tool_selection.py +25 -21
  20. langchain/agents/middleware/types.py +714 -257
  21. langchain/agents/structured_output.py +37 -27
  22. langchain/chat_models/__init__.py +7 -1
  23. langchain/chat_models/base.py +192 -190
  24. langchain/embeddings/__init__.py +13 -3
  25. langchain/embeddings/base.py +49 -29
  26. langchain/messages/__init__.py +50 -1
  27. langchain/tools/__init__.py +9 -7
  28. langchain/tools/tool_node.py +16 -1174
  29. langchain-1.0.4.dist-info/METADATA +92 -0
  30. langchain-1.0.4.dist-info/RECORD +34 -0
  31. langchain/_internal/__init__.py +0 -0
  32. langchain/_internal/_documents.py +0 -35
  33. langchain/_internal/_lazy_import.py +0 -35
  34. langchain/_internal/_prompts.py +0 -158
  35. langchain/_internal/_typing.py +0 -70
  36. langchain/_internal/_utils.py +0 -7
  37. langchain/agents/_internal/__init__.py +0 -1
  38. langchain/agents/_internal/_typing.py +0 -13
  39. langchain/agents/middleware/prompt_caching.py +0 -86
  40. langchain/documents/__init__.py +0 -7
  41. langchain/embeddings/cache.py +0 -361
  42. langchain/storage/__init__.py +0 -22
  43. langchain/storage/encoder_backed.py +0 -123
  44. langchain/storage/exceptions.py +0 -5
  45. langchain/storage/in_memory.py +0 -13
  46. langchain-1.0.0a12.dist-info/METADATA +0 -122
  47. langchain-1.0.0a12.dist-info/RECORD +0 -43
  48. {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/WHEEL +0 -0
  49. {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/licenses/LICENSE +0 -0
@@ -19,18 +19,23 @@ from langchain_core.tools import BaseTool
 from langgraph._internal._runnable import RunnableCallable
 from langgraph.constants import END, START
 from langgraph.graph.state import StateGraph
+from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
 from langgraph.runtime import Runtime  # noqa: TC002
-from langgraph.types import Send
+from langgraph.types import Command, Send
 from langgraph.typing import ContextT  # noqa: TC002
-from typing_extensions import NotRequired, Required, TypedDict, TypeVar
+from typing_extensions import NotRequired, Required, TypedDict
 
 from langchain.agents.middleware.types import (
     AgentMiddleware,
     AgentState,
     JumpTo,
     ModelRequest,
+    ModelResponse,
     OmitFromSchema,
-    PublicAgentState,
+    ResponseT,
+    StateT_co,
+    _InputAgentState,
+    _OutputAgentState,
 )
 from langchain.agents.structured_output import (
     AutoStrategy,
@@ -39,14 +44,14 @@ from langchain.agents.structured_output import (
     ProviderStrategy,
     ProviderStrategyBinding,
     ResponseFormat,
+    StructuredOutputError,
     StructuredOutputValidationError,
     ToolStrategy,
 )
 from langchain.chat_models import init_chat_model
-from langchain.tools import ToolNode
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Sequence
+    from collections.abc import Awaitable, Callable, Sequence
 
     from langchain_core.runnables import Runnable
     from langgraph.cache.base import BaseCache
@@ -54,18 +59,223 @@ if TYPE_CHECKING:
     from langgraph.store.base import BaseStore
     from langgraph.types import Checkpointer
 
+    from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper
+
 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
 
-ResponseT = TypeVar("ResponseT")
+
+def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
+    """Normalize middleware return value to ModelResponse."""
+    if isinstance(result, AIMessage):
+        return ModelResponse(result=[result], structured_response=None)
+    return result
+
+
+def _chain_model_call_handlers(
+    handlers: Sequence[
+        Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ]
+    ],
+) -> (
+    Callable[
+        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+        ModelResponse,
+    ]
+    | None
+):
+    """Compose multiple `wrap_model_call` handlers into a single middleware stack.
+
+    Composes handlers so the first in the list becomes the outermost layer. Each
+    handler receives a callback that executes the inner layers.
+
+    Args:
+        handlers: List of handlers. The first handler wraps all others.
+
+    Returns:
+        Composed handler, or `None` if `handlers` is empty.
+
+    Example:
+        ```python
+        # handlers=[auth, retry] means: auth wraps retry
+        # Flow: auth calls retry, retry calls base handler
+        def auth(req, handler):
+            try:
+                return handler(req)
+            except UnauthorizedError:
+                refresh_token()
+                return handler(req)
+
+
+        def retry(req, handler):
+            for attempt in range(3):
+                try:
+                    return handler(req)
+                except Exception:
+                    if attempt == 2:
+                        raise
+
+
+        handler = _chain_model_call_handlers([auth, retry])
+        ```
+    """
+    if not handlers:
+        return None
+
+    if len(handlers) == 1:
+        # Single handler - wrap to normalize output
+        single_handler = handlers[0]
+
+        def normalized_single(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelResponse:
+            result = single_handler(request, handler)
+            return _normalize_to_model_response(result)
+
+        return normalized_single
+
+    def compose_two(
+        outer: Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ],
+        inner: Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ],
+    ) -> Callable[
+        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+        ModelResponse,
+    ]:
+        """Compose two handlers where outer wraps inner."""
+
+        def composed(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelResponse:
+            # Create a wrapper that calls inner with the base handler and normalizes
+            def inner_handler(req: ModelRequest) -> ModelResponse:
+                inner_result = inner(req, handler)
+                return _normalize_to_model_response(inner_result)
+
+            # Call outer with the wrapped inner as its handler and normalize
+            outer_result = outer(request, inner_handler)
+            return _normalize_to_model_response(outer_result)
+
+        return composed
+
+    # Compose right-to-left: outer(inner(innermost(handler)))
+    result = handlers[-1]
+    for handler in reversed(handlers[:-1]):
+        result = compose_two(handler, result)
+
+    # Wrap to ensure final return type is exactly ModelResponse
+    def final_normalized(
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelResponse:
+        # result here is typed as returning ModelResponse | AIMessage, but compose_two normalizes
+        final_result = result(request, handler)
+        return _normalize_to_model_response(final_result)
+
+    return final_normalized
+
+
+def _chain_async_model_call_handlers(
+    handlers: Sequence[
+        Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ]
+    ],
+) -> (
+    Callable[
+        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+        Awaitable[ModelResponse],
+    ]
+    | None
+):
+    """Compose multiple async `wrap_model_call` handlers into a single middleware stack.
+
+    Args:
+        handlers: List of async handlers. The first handler wraps all others.
+
+    Returns:
+        Composed async handler, or `None` if `handlers` is empty.
+    """
+    if not handlers:
+        return None
+
+    if len(handlers) == 1:
+        # Single handler - wrap to normalize output
+        single_handler = handlers[0]
+
+        async def normalized_single(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
+            result = await single_handler(request, handler)
+            return _normalize_to_model_response(result)
+
+        return normalized_single
+
+    def compose_two(
+        outer: Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ],
+        inner: Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ],
+    ) -> Callable[
+        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+        Awaitable[ModelResponse],
+    ]:
+        """Compose two async handlers where outer wraps inner."""
+
+        async def composed(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
+            # Create a wrapper that calls inner with the base handler and normalizes
+            async def inner_handler(req: ModelRequest) -> ModelResponse:
+                inner_result = await inner(req, handler)
+                return _normalize_to_model_response(inner_result)
+
+            # Call outer with the wrapped inner as its handler and normalize
+            outer_result = await outer(request, inner_handler)
+            return _normalize_to_model_response(outer_result)
+
+        return composed
+
+    # Compose right-to-left: outer(inner(innermost(handler)))
+    result = handlers[-1]
+    for handler in reversed(handlers[:-1]):
+        result = compose_two(handler, result)
+
+    # Wrap to ensure final return type is exactly ModelResponse
+    async def final_normalized(
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelResponse:
+        # result here is typed as returning ModelResponse | AIMessage, but compose_two normalizes
+        final_result = await result(request, handler)
+        return _normalize_to_model_response(final_result)
+
+    return final_normalized
 
 
 def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
-    """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
+    """Resolve schema by merging schemas and optionally respecting `OmitFromSchema` annotations.
 
     Args:
        schemas: Set of schema types to merge
-        schema_name: Name for the generated TypedDict
-        omit_flag: If specified, omit fields with this flag set ('input' or 'output')
+        schema_name: Name for the generated `TypedDict`
+        omit_flag: If specified, omit fields with this flag set (`'input'` or
+            `'output'`)
     """
    all_annotations = {}
 
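The two chainers above are the same right fold specialized to sync and async handlers; the first handler in the list always ends up outermost. Below is a minimal, self-contained sketch of that fold, using plain strings in place of the package's `ModelRequest`/`ModelResponse` types (all names in the sketch are illustrative, not part of the package):

```python
from collections.abc import Callable, Sequence

# A handler receives the request plus a callback that runs the layers beneath it.
Handler = Callable[[str, Callable[[str], str]], str]


def chain(handlers: Sequence[Handler]) -> Handler | None:
    """Fold handlers right-to-left so handlers[0] becomes the outermost layer."""
    if not handlers:
        return None

    def compose_two(outer: Handler, inner: Handler) -> Handler:
        # `outer` gets a callback that runs `inner`, which in turn gets `base`.
        return lambda request, base: outer(request, lambda req: inner(req, base))

    result = handlers[-1]
    for handler in reversed(handlers[:-1]):
        result = compose_two(handler, result)
    return result


def layer(name: str) -> Handler:
    return lambda request, call: f"{name}({call(request)})"


stack = chain([layer("auth"), layer("retry")])
assert stack is not None
print(stack("req", lambda r: f"model[{r}]"))  # -> auth(retry(model[req]))
```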
@@ -105,11 +315,11 @@ def _extract_metadata(type_: type) -> list:
 
 
 def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> list[JumpTo]:
-    """Get the can_jump_to list from either sync or async hook methods.
+    """Get the `can_jump_to` list from either sync or async hook methods.
 
     Args:
         middleware: The middleware instance to inspect.
-        hook_name: The name of the hook ('before_model' or 'after_model').
+        hook_name: The name of the hook (`'before_model'` or `'after_model'`).
 
     Returns:
         List of jump destinations, or empty list if not configured.
@@ -143,10 +353,10 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
     """Check if a model supports provider-specific structured output.
 
     Args:
-        model: Model name string or BaseChatModel instance.
+        model: Model name string or `BaseChatModel` instance.
 
     Returns:
-        ``True`` if the model supports provider-specific structured output, ``False`` otherwise.
+        `True` if the model supports provider-specific structured output, `False` otherwise.
     """
     model_name: str | None = None
     if isinstance(model, str):
@@ -166,7 +376,7 @@ def _handle_structured_output_error(
     exception: Exception,
     response_format: ResponseFormat,
 ) -> tuple[bool, str]:
-    """Handle structured output error. Returns (should_retry, retry_tool_message)."""
+    """Handle structured output error. Returns `(should_retry, retry_tool_message)`."""
     if not isinstance(response_format, ToolStrategy):
         return False, ""
 
@@ -192,13 +402,124 @@ def _handle_structured_output_error(
     return False, ""
 
 
+def _chain_tool_call_wrappers(
+    wrappers: Sequence[ToolCallWrapper],
+) -> ToolCallWrapper | None:
+    """Compose wrappers into a middleware stack (first = outermost).
+
+    Args:
+        wrappers: Wrappers in middleware order.
+
+    Returns:
+        Composed wrapper, or `None` if empty.
+
+    Example:
+        wrapper = _chain_tool_call_wrappers([auth, cache, retry])
+        # Request flows: auth -> cache -> retry -> tool
+        # Response flows: tool -> retry -> cache -> auth
+    """
+    if not wrappers:
+        return None
+
+    if len(wrappers) == 1:
+        return wrappers[0]
+
+    def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapper:
+        """Compose two wrappers where outer wraps inner."""
+
+        def composed(
+            request: ToolCallRequest,
+            execute: Callable[[ToolCallRequest], ToolMessage | Command],
+        ) -> ToolMessage | Command:
+            # Create a callable that invokes inner with the original execute
+            def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+                return inner(req, execute)
+
+            # Outer can call call_inner multiple times
+            return outer(request, call_inner)
+
+        return composed
+
+    # Chain all wrappers: first -> second -> ... -> last
+    result = wrappers[-1]
+    for wrapper in reversed(wrappers[:-1]):
+        result = compose_two(wrapper, result)
+
+    return result
+
+
+def _chain_async_tool_call_wrappers(
+    wrappers: Sequence[
+        Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ]
+    ],
+) -> (
+    Callable[
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+        Awaitable[ToolMessage | Command],
+    ]
+    | None
+):
+    """Compose async wrappers into a middleware stack (first = outermost).
+
+    Args:
+        wrappers: Async wrappers in middleware order.
+
+    Returns:
+        Composed async wrapper, or `None` if empty.
+    """
+    if not wrappers:
+        return None
+
+    if len(wrappers) == 1:
+        return wrappers[0]
+
+    def compose_two(
+        outer: Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ],
+        inner: Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ],
+    ) -> Callable[
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+        Awaitable[ToolMessage | Command],
+    ]:
+        """Compose two async wrappers where outer wraps inner."""
+
+        async def composed(
+            request: ToolCallRequest,
+            execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+        ) -> ToolMessage | Command:
+            # Create an async callable that invokes inner with the original execute
+            async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+                return await inner(req, execute)
+
+            # Outer can call call_inner multiple times
+            return await outer(request, call_inner)
+
+        return composed
+
+    # Chain all wrappers: first -> second -> ... -> last
+    result = wrappers[-1]
+    for wrapper in reversed(wrappers[:-1]):
+        result = compose_two(wrapper, result)
+
+    return result
+
+
 def create_agent(  # noqa: PLR0915
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
     *,
     system_prompt: str | None = None,
-    middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
+    middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
     response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
+    state_schema: type[AgentState[ResponseT]] | None = None,
     context_schema: type[ContextT] | None = None,
     checkpointer: Checkpointer | None = None,
     store: BaseStore | None = None,
@@ -208,56 +529,89 @@ def create_agent(  # noqa: PLR0915
     name: str | None = None,
     cache: BaseCache | None = None,
 ) -> CompiledStateGraph[
-    AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
+    AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
 ]:
     """Creates an agent graph that calls tools in a loop until a stopping condition is met.
 
-    For more details on using ``create_agent``,
-    visit [Agents](https://docs.langchain.com/oss/python/langchain/agents) documentation.
+    For more details on using `create_agent`,
+    visit the [Agents](https://docs.langchain.com/oss/python/langchain/agents) docs.
 
     Args:
         model: The language model for the agent. Can be a string identifier
-            (e.g., ``"openai:gpt-4"``), a chat model instance (e.g., ``ChatOpenAI()``).
-        tools: A list of tools, dicts, or callables. If ``None`` or an empty list,
-            the agent will consist of a model node without a tool calling loop.
-        system_prompt: An optional system prompt for the LLM. If provided as a string,
-            it will be converted to a SystemMessage and added to the beginning
-            of the message list.
+            (e.g., `"openai:gpt-4"`) or a direct chat model instance (e.g.,
+            [`ChatOpenAI`][langchain_openai.ChatOpenAI] or another
+            [chat model](https://docs.langchain.com/oss/python/integrations/chat)).
+
+            For a full list of supported model strings, see
+            [`init_chat_model`][langchain.chat_models.init_chat_model(model_provider)].
+        tools: A list of tools, `dicts`, or `Callable`.
+
+            If `None` or an empty list, the agent will consist of a model node without a
+            tool calling loop.
+        system_prompt: An optional system prompt for the LLM.
+
+            Prompts are converted to a
+            [`SystemMessage`][langchain.messages.SystemMessage] and added to the
+            beginning of the message list.
         middleware: A sequence of middleware instances to apply to the agent.
-            Middleware can intercept and modify agent behavior at various stages.
+
+            Middleware can intercept and modify agent behavior at various stages. See
+            the [full guide](https://docs.langchain.com/oss/python/langchain/middleware).
         response_format: An optional configuration for structured responses.
-            Can be a ToolStrategy, ProviderStrategy, or a Pydantic model class.
+
+            Can be a `ToolStrategy`, `ProviderStrategy`, or a Pydantic model class.
+
             If provided, the agent will handle structured output during the
             conversation flow. Raw schemas will be wrapped in an appropriate strategy
             based on model capabilities.
+        state_schema: An optional `TypedDict` schema that extends `AgentState`.
+
+            When provided, this schema is used instead of `AgentState` as the base
+            schema for merging with middleware state schemas. This allows users to
+            add custom state fields without needing to create custom middleware.
+            Generally, it's recommended to use `state_schema` extensions via middleware
+            to keep relevant extensions scoped to corresponding hooks / tools.
+
+            The schema must be a subclass of `AgentState[ResponseT]`.
         context_schema: An optional schema for runtime context.
-        checkpointer: An optional checkpoint saver object. This is used for persisting
-            the state of the graph (e.g., as chat memory) for a single thread
-            (e.g., a single conversation).
-        store: An optional store object. This is used for persisting data
-            across multiple threads (e.g., multiple conversations / users).
+        checkpointer: An optional checkpoint saver object.
+
+            Used for persisting the state of the graph (e.g., as chat memory) for a
+            single thread (e.g., a single conversation).
+        store: An optional store object.
+
+            Used for persisting data across multiple threads (e.g., multiple
+            conversations / users).
         interrupt_before: An optional list of node names to interrupt before.
-            This is useful if you want to add a user confirmation or other interrupt
+
+            Useful if you want to add a user confirmation or other interrupt
             before taking an action.
         interrupt_after: An optional list of node names to interrupt after.
-            This is useful if you want to return directly or run additional processing
+
+            Useful if you want to return directly or run additional processing
             on an output.
-        debug: A flag indicating whether to enable debug mode.
-        name: An optional name for the CompiledStateGraph.
+        debug: Whether to enable verbose logging for graph execution.
+
+            When enabled, prints detailed information about each node execution, state
+            updates, and transitions during agent runtime. Useful for debugging
+            middleware behavior and understanding agent execution flow.
+        name: An optional name for the `CompiledStateGraph`.
+
             This name will be automatically used when adding the agent graph to
             another graph as a subgraph node - particularly useful for building
             multi-agent systems.
-        cache: An optional BaseCache instance to enable caching of graph execution.
+        cache: An optional `BaseCache` instance to enable caching of graph execution.
 
     Returns:
-        A compiled StateGraph that can be used for chat interactions.
+        A compiled `StateGraph` that can be used for chat interactions.
 
     The agent node calls the language model with the messages list (after applying
-    the system prompt). If the resulting AIMessage contains ``tool_calls``, the graph will
-    then call the tools. The tools node executes the tools and adds the responses
-    to the messages list as ``ToolMessage`` objects. The agent node then calls the
-    language model again. The process repeats until no more ``tool_calls`` are
-    present in the response. The agent then returns the full list of messages.
+    the system prompt). If the resulting [`AIMessage`][langchain.messages.AIMessage]
+    contains `tool_calls`, the graph will then call the tools. The tools node executes
+    the tools and adds the responses to the messages list as
+    [`ToolMessage`][langchain.messages.ToolMessage] objects. The agent node then calls
+    the language model again. The process repeats until no more `tool_calls` are present
+    in the response. The agent then returns the full list of messages.
 
     Example:
         ```python
@@ -270,7 +624,7 @@ def create_agent(  # noqa: PLR0915
 
 
         graph = create_agent(
-            model="anthropic:claude-3-7-sonnet-latest",
+            model="anthropic:claude-sonnet-4-5-20250929",
             tools=[check_weather],
             system_prompt="You are a helpful assistant",
        )
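A hedged sketch of the new `state_schema` parameter documented above; the extra field `user_tier` is a hypothetical example, and running the snippet requires model credentials:

```python
from typing_extensions import NotRequired

from langchain.agents import create_agent
from langchain.agents.middleware.types import AgentState


class MyState(AgentState):
    # Hypothetical custom field; not part of the package's AgentState.
    user_tier: NotRequired[str]


agent = create_agent(
    model="anthropic:claude-sonnet-4-5-20250929",
    state_schema=MyState,
)
result = agent.invoke({"messages": [("user", "hi")], "user_tier": "pro"})
```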
@@ -319,6 +673,38 @@ def create_agent(  # noqa: PLR0915
         structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
     middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
 
+    # Collect middleware with wrap_tool_call or awrap_tool_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
+    middleware_w_wrap_tool_call = [
+        m
+        for m in middleware
+        if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+        or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
+    ]
+
+    # Chain all wrap_tool_call handlers into a single composed handler
+    wrap_tool_call_wrapper = None
+    if middleware_w_wrap_tool_call:
+        wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
+        wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
+
+    # Collect middleware with awrap_tool_call or wrap_tool_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
+    middleware_w_awrap_tool_call = [
+        m
+        for m in middleware
+        if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
+        or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+    ]
+
+    # Chain all awrap_tool_call handlers into a single composed async handler
+    awrap_tool_call_wrapper = None
+    if middleware_w_awrap_tool_call:
+        async_wrappers = [m.awrap_tool_call for m in middleware_w_awrap_tool_call]
+        awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
+
     # Setup tools
     tool_node: ToolNode | None = None
     # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
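A wrapper in this chain receives the inner layer as a plain callable, so it may invoke it several times (retries) or skip it entirely (a cache hit). A hedged sketch of both behaviors with simplified stand-in types, not the package's `ToolCallRequest`/`ToolMessage`:

```python
from collections.abc import Callable

Execute = Callable[[dict], str]
Wrapper = Callable[[dict, Execute], str]


def retry(request: dict, execute: Execute) -> str:
    # May call the inner layer several times before giving up.
    for attempt in range(3):
        try:
            return execute(request)
        except RuntimeError:
            if attempt == 2:
                raise
    raise AssertionError("unreachable")


_results: dict[str, str] = {}


def cache(request: dict, execute: Execute) -> str:
    # May skip the inner layer entirely on a hit.
    key = request["tool"]
    if key not in _results:
        _results[key] = execute(request)
    return _results[key]


def compose_two(outer: Wrapper, inner: Wrapper) -> Wrapper:
    return lambda request, execute: outer(request, lambda req: inner(req, execute))


stack = compose_two(cache, retry)  # request flows: cache -> retry -> tool
print(stack({"tool": "search"}, lambda req: f"ran {req['tool']}"))
print(stack({"tool": "search"}, lambda req: "never called"))  # cache hit
```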
@@ -329,7 +715,15 @@ def create_agent(  # noqa: PLR0915
     available_tools = middleware_tools + regular_tools
 
     # Only create ToolNode if we have client-side tools
-    tool_node = ToolNode(tools=available_tools) if available_tools else None
+    tool_node = (
+        ToolNode(
+            tools=available_tools,
+            wrap_tool_call=wrap_tool_call_wrapper,
+            awrap_tool_call=awrap_tool_call_wrapper,
+        )
+        if available_tools
+        else None
+    )
 
     # Default tools for ModelRequest initialization
     # Use converted BaseTool instances from ToolNode (not raw callables)
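The wrappers passed to `ToolNode` above come from middleware that override the hook. A hedged sketch of one such middleware (import path taken from this diff; the exact hook signatures are abbreviated):

```python
from langchain.agents.middleware.types import AgentMiddleware


class RetryToolCalls(AgentMiddleware):
    """Retry a failed tool call up to three times before re-raising."""

    def wrap_tool_call(self, request, handler):
        # `handler` runs the inner layers (and ultimately the tool itself),
        # so calling it again re-executes the tool call.
        for attempt in range(3):
            try:
                return handler(request)
            except Exception:
                if attempt == 2:
                    raise
        return None  # unreachable
```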
@@ -356,12 +750,6 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.before_model is not AgentMiddleware.before_model
         or m.__class__.abefore_model is not AgentMiddleware.abefore_model
     ]
-    middleware_w_modify_model_request = [
-        m
-        for m in middleware
-        if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
-        or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
-    ]
     middleware_w_after_model = [
         m
         for m in middleware
@@ -374,25 +762,51 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.after_agent is not AgentMiddleware.after_agent
         or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
     ]
-    middleware_w_retry = [
+    # Collect middleware with wrap_model_call or awrap_model_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
+    middleware_w_wrap_model_call = [
+        m
+        for m in middleware
+        if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
+        or m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
+    ]
+    # Collect middleware with awrap_model_call or wrap_model_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
+    middleware_w_awrap_model_call = [
         m
         for m in middleware
-        if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request
-        or m.__class__.aretry_model_request is not AgentMiddleware.aretry_model_request
+        if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
+        or m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
     ]
 
-    state_schemas = {m.state_schema for m in middleware}
-    state_schemas.add(AgentState)
+    # Compose wrap_model_call handlers into a single middleware stack (sync)
+    wrap_model_call_handler = None
+    if middleware_w_wrap_model_call:
+        sync_handlers = [m.wrap_model_call for m in middleware_w_wrap_model_call]
+        wrap_model_call_handler = _chain_model_call_handlers(sync_handlers)
 
-    state_schema = _resolve_schema(state_schemas, "StateSchema", None)
+    # Compose awrap_model_call handlers into a single middleware stack (async)
+    awrap_model_call_handler = None
+    if middleware_w_awrap_model_call:
+        async_handlers = [m.awrap_model_call for m in middleware_w_awrap_model_call]
+        awrap_model_call_handler = _chain_async_model_call_handlers(async_handlers)
+
+    state_schemas: set[type] = {m.state_schema for m in middleware}
+    # Use provided state_schema if available, otherwise use base AgentState
+    base_state = state_schema if state_schema is not None else AgentState
+    state_schemas.add(base_state)
+
+    resolved_state_schema = _resolve_schema(state_schemas, "StateSchema", None)
     input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
     output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")
 
     # create graph, add nodes
     graph: StateGraph[
-        AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
+        AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
     ] = StateGraph(
-        state_schema=state_schema,
+        state_schema=resolved_state_schema,
         input_schema=input_schema,
         output_schema=output_schema,
         context_schema=context_schema,
@@ -414,8 +828,16 @@ def create_agent(  # noqa: PLR0915
                 provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
                     effective_response_format.schema_spec
                 )
-                structured_response = provider_strategy_binding.parse(output)
-                return {"messages": [output], "structured_response": structured_response}
+                try:
+                    structured_response = provider_strategy_binding.parse(output)
+                except Exception as exc:  # noqa: BLE001
+                    schema_name = getattr(
+                        effective_response_format.schema_spec.schema, "__name__", "response_format"
+                    )
+                    validation_error = StructuredOutputValidationError(schema_name, exc, output)
+                    raise validation_error
+                else:
+                    return {"messages": [output], "structured_response": structured_response}
             return {"messages": [output]}
 
         # Handle structured output with tool strategy
@@ -429,11 +851,11 @@ def create_agent(  # noqa: PLR0915
         ]
 
         if structured_tool_calls:
-            exception: Exception | None = None
+            exception: StructuredOutputError | None = None
             if len(structured_tool_calls) > 1:
                 # Handle multiple structured outputs error
                 tool_names = [tc["name"] for tc in structured_tool_calls]
-                exception = MultipleStructuredOutputsError(tool_names)
+                exception = MultipleStructuredOutputsError(tool_names, output)
                 should_retry, error_message = _handle_structured_output_error(
                     exception, effective_response_format
                 )
@@ -475,7 +897,7 @@ def create_agent(  # noqa: PLR0915
                             "structured_response": structured_response,
                         }
                     except Exception as exc:  # noqa: BLE001
-                        exception = StructuredOutputValidationError(tool_call["name"], exc)
+                        exception = StructuredOutputValidationError(tool_call["name"], exc, output)
                         should_retry, error_message = _handle_structured_output_error(
                             exception, effective_response_format
                         )
@@ -504,8 +926,9 @@ def create_agent(  # noqa: PLR0915
             request: The model request containing model, tools, and response format.
 
         Returns:
-            Tuple of (bound_model, effective_response_format) where ``effective_response_format``
-            is the actual strategy used (may differ from initial if auto-detected).
+            Tuple of `(bound_model, effective_response_format)` where
+            `effective_response_format` is the actual strategy used (may differ from
+            initial if auto-detected).
         """
         # Validate ONLY client-side tools that need to exist in tool_node
         # Build map of available client-side tools from the ToolNode
@@ -608,6 +1031,30 @@ def create_agent(  # noqa: PLR0915
         )
         return request.model.bind(**request.model_settings), None
 
+    def _execute_model_sync(request: ModelRequest) -> ModelResponse:
+        """Execute model and return response.
+
+        This is the core model execution logic wrapped by `wrap_model_call` handlers.
+        Raises any exceptions that occur during model invocation.
+        """
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
+        output = model_.invoke(messages)
+
+        # Handle model output to get messages and structured_response
+        handled_output = _handle_model_output(output, effective_response_format)
+        messages_list = handled_output["messages"]
+        structured_response = handled_output.get("structured_response")
+
+        return ModelResponse(
+            result=messages_list,
+            structured_response=structured_response,
+        )
+
     def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
@@ -617,62 +1064,49 @@ def create_agent(  # noqa: PLR0915
             response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
+            state=state,
+            runtime=runtime,
         )
 
-        # Apply modify_model_request middleware in sequence
-        for m in middleware_w_modify_model_request:
-            if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request:
-                m.modify_model_request(request, state, runtime)
-            else:
-                msg = (
-                    f"No synchronous function provided for "
-                    f'{m.__class__.__name__}.amodify_model_request".'
-                    "\nEither initialize with a synchronous function or invoke"
-                    " via the async API (ainvoke, astream, etc.)"
-                )
-                raise TypeError(msg)
+        if wrap_model_call_handler is None:
+            # No handlers - execute directly
+            response = _execute_model_sync(request)
+        else:
+            # Call composed handler with base handler
+            response = wrap_model_call_handler(request, _execute_model_sync)
 
-        # Retry loop for model invocation with error handling
-        # Hard limit of 100 attempts to prevent infinite loops from buggy middleware
-        max_attempts = 100
-        for attempt in range(1, max_attempts + 1):
-            try:
-                # Get the bound model (with auto-detection if needed)
-                model_, effective_response_format = _get_bound_model(request)
-                messages = request.messages
-                if request.system_prompt:
-                    messages = [SystemMessage(request.system_prompt), *messages]
-
-                output = model_.invoke(messages)
-                return {
-                    "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-                    "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-                    **_handle_model_output(output, effective_response_format),
-                }
-            except Exception as error:
-                # Try retry_model_request on each middleware
-                for m in middleware_w_retry:
-                    if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
-                        if retry_request := m.retry_model_request(
-                            error, request, state, runtime, attempt
-                        ):
-                            # Break on first middleware that wants to retry
-                            request = retry_request
-                            break
-                    else:
-                        msg = (
-                            f"No synchronous function provided for "
-                            f'{m.__class__.__name__}.aretry_model_request".'
-                            "\nEither initialize with a synchronous function or invoke"
-                            " via the async API (ainvoke, astream, etc.)"
-                        )
-                        raise TypeError(msg)
-                else:
-                    raise
+        # Extract state updates from ModelResponse
+        state_updates = {"messages": response.result}
+        if response.structured_response is not None:
+            state_updates["structured_response"] = response.structured_response
+
+        return state_updates
+
+    async def _execute_model_async(request: ModelRequest) -> ModelResponse:
+        """Execute model asynchronously and return response.
 
-        # If we exit the loop, max attempts exceeded
-        msg = f"Maximum retry attempts ({max_attempts}) exceeded"
-        raise RuntimeError(msg)
+        This is the core async model execution logic wrapped by `wrap_model_call`
+        handlers.
+
+        Raises any exceptions that occur during model invocation.
+        """
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
+        output = await model_.ainvoke(messages)
+
+        # Handle model output to get messages and structured_response
+        handled_output = _handle_model_output(output, effective_response_format)
+        messages_list = handled_output["messages"]
+        structured_response = handled_output.get("structured_response")
+
+        return ModelResponse(
+            result=messages_list,
+            structured_response=structured_response,
+        )
 
     async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
@@ -683,45 +1117,23 @@ def create_agent(  # noqa: PLR0915
             response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
+            state=state,
+            runtime=runtime,
         )
 
-        # Apply modify_model_request middleware in sequence
-        for m in middleware_w_modify_model_request:
-            await m.amodify_model_request(request, state, runtime)
+        if awrap_model_call_handler is None:
+            # No async handlers - execute directly
+            response = await _execute_model_async(request)
+        else:
+            # Call composed async handler with base handler
+            response = await awrap_model_call_handler(request, _execute_model_async)
 
-        # Retry loop for model invocation with error handling
-        # Hard limit of 100 attempts to prevent infinite loops from buggy middleware
-        max_attempts = 100
-        for attempt in range(1, max_attempts + 1):
-            try:
-                # Get the bound model (with auto-detection if needed)
-                model_, effective_response_format = _get_bound_model(request)
-                messages = request.messages
-                if request.system_prompt:
-                    messages = [SystemMessage(request.system_prompt), *messages]
-
-                output = await model_.ainvoke(messages)
-                return {
-                    "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-                    "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-                    **_handle_model_output(output, effective_response_format),
-                }
-            except Exception as error:
-                # Try retry_model_request on each middleware
-                for m in middleware_w_retry:
-                    if retry_request := await m.aretry_model_request(
-                        error, request, state, runtime, attempt
-                    ):
-                        # Break on first middleware that wants to retry
-                        request = retry_request
-                        break
-                else:
-                    # If no middleware wants to retry, re-raise the error
-                    raise
+        # Extract state updates from ModelResponse
+        state_updates = {"messages": response.result}
+        if response.structured_response is not None:
+            state_updates["structured_response"] = response.structured_response
 
-        # If we exit the loop, max attempts exceeded
-        msg = f"Maximum retry attempts ({max_attempts}) exceeded"
-        raise RuntimeError(msg)
+        return state_updates
 
     # Use sync or async based on model capabilities
     graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
@@ -749,7 +1161,9 @@ def create_agent(  # noqa: PLR0915
             else None
         )
         before_agent_node = RunnableCallable(sync_before_agent, async_before_agent, trace=False)
-        graph.add_node(f"{m.name}.before_agent", before_agent_node, input_schema=state_schema)
+        graph.add_node(
+            f"{m.name}.before_agent", before_agent_node, input_schema=resolved_state_schema
+        )
 
         if (
             m.__class__.before_model is not AgentMiddleware.before_model
@@ -768,7 +1182,9 @@ def create_agent(  # noqa: PLR0915
             else None
         )
         before_node = RunnableCallable(sync_before, async_before, trace=False)
-        graph.add_node(f"{m.name}.before_model", before_node, input_schema=state_schema)
+        graph.add_node(
+            f"{m.name}.before_model", before_node, input_schema=resolved_state_schema
+        )
 
         if (
             m.__class__.after_model is not AgentMiddleware.after_model
@@ -787,7 +1203,7 @@ def create_agent(  # noqa: PLR0915
             else None
         )
         after_node = RunnableCallable(sync_after, async_after, trace=False)
-        graph.add_node(f"{m.name}.after_model", after_node, input_schema=state_schema)
+        graph.add_node(f"{m.name}.after_model", after_node, input_schema=resolved_state_schema)
 
         if (
             m.__class__.after_agent is not AgentMiddleware.after_agent
@@ -806,7 +1222,9 @@ def create_agent(  # noqa: PLR0915
             else None
         )
         after_agent_node = RunnableCallable(sync_after_agent, async_after_agent, trace=False)
-        graph.add_node(f"{m.name}.after_agent", after_agent_node, input_schema=state_schema)
+        graph.add_node(
+            f"{m.name}.after_agent", after_agent_node, input_schema=resolved_state_schema
+        )
 
     # Determine the entry node (runs once at start): before_agent -> before_model -> model
     if middleware_w_before_agent:
@@ -839,25 +1257,61 @@ def create_agent(  # noqa: PLR0915
     graph.add_edge(START, entry_node)
     # add conditional edges only if tools exist
     if tool_node is not None:
+        # Only include exit_node in destinations if any tool has return_direct=True
+        # or if there are structured output tools
+        tools_to_model_destinations = [loop_entry_node]
+        if (
+            any(tool.return_direct for tool in tool_node.tools_by_name.values())
+            or structured_output_tools
+        ):
+            tools_to_model_destinations.append(exit_node)
+
         graph.add_conditional_edges(
             "tools",
-            _make_tools_to_model_edge(
-                tool_node, loop_entry_node, structured_output_tools, exit_node
+            RunnableCallable(
+                _make_tools_to_model_edge(
+                    tool_node=tool_node,
+                    model_destination=loop_entry_node,
+                    structured_output_tools=structured_output_tools,
+                    end_destination=exit_node,
+                ),
+                trace=False,
             ),
-            [loop_entry_node, exit_node],
+            tools_to_model_destinations,
         )
 
+        # base destinations are tools and exit_node
+        # we add the loop_entry node to edge destinations if:
+        # - there is an after_model hook(s) -- allows jump_to to model after
+        #   potentially artificially injected tool messages, e.g. HITL
+        # - there is a response format -- to allow jumping to model to handle
+        #   regenerating structured output tool calls
+        model_to_tools_destinations = ["tools", exit_node]
+        if response_format or loop_exit_node != "model":
+            model_to_tools_destinations.append(loop_entry_node)
+
         graph.add_conditional_edges(
             loop_exit_node,
-            _make_model_to_tools_edge(
-                loop_entry_node, structured_output_tools, tool_node, exit_node
+            RunnableCallable(
+                _make_model_to_tools_edge(
+                    model_destination=loop_entry_node,
+                    structured_output_tools=structured_output_tools,
+                    end_destination=exit_node,
+                ),
+                trace=False,
             ),
-            [loop_entry_node, "tools", exit_node],
+            model_to_tools_destinations,
        )
     elif len(structured_output_tools) > 0:
         graph.add_conditional_edges(
             loop_exit_node,
-            _make_model_to_model_edge(loop_entry_node, exit_node),
+            RunnableCallable(
+                _make_model_to_model_edge(
+                    model_destination=loop_entry_node,
+                    end_destination=exit_node,
+                ),
+                trace=False,
+            ),
             [loop_entry_node, exit_node],
         )
     elif loop_exit_node == "model":
@@ -867,9 +1321,10 @@ def create_agent(  # noqa: PLR0915
     else:
         _add_middleware_edge(
             graph,
-            f"{middleware_w_after_model[0].name}.after_model",
-            exit_node,
-            loop_entry_node,
+            name=f"{middleware_w_after_model[0].name}.after_model",
+            default_destination=exit_node,
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"),
         )
 
@@ -878,17 +1333,19 @@ def create_agent(  # noqa: PLR0915
     for m1, m2 in itertools.pairwise(middleware_w_before_agent):
         _add_middleware_edge(
             graph,
-            f"{m1.name}.before_agent",
-            f"{m2.name}.before_agent",
-            loop_entry_node,
+            name=f"{m1.name}.before_agent",
+            default_destination=f"{m2.name}.before_agent",
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(m1, "before_agent"),
         )
     # Connect last before_agent to loop_entry_node (before_model or model)
     _add_middleware_edge(
         graph,
-        f"{middleware_w_before_agent[-1].name}.before_agent",
-        loop_entry_node,
-        loop_entry_node,
+        name=f"{middleware_w_before_agent[-1].name}.before_agent",
+        default_destination=loop_entry_node,
+        model_destination=loop_entry_node,
+        end_destination=exit_node,
         can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
     )
 
@@ -897,17 +1354,19 @@ def create_agent(  # noqa: PLR0915
     for m1, m2 in itertools.pairwise(middleware_w_before_model):
         _add_middleware_edge(
             graph,
-            f"{m1.name}.before_model",
-            f"{m2.name}.before_model",
-            loop_entry_node,
+            name=f"{m1.name}.before_model",
+            default_destination=f"{m2.name}.before_model",
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(m1, "before_model"),
         )
     # Go directly to model after the last before_model
     _add_middleware_edge(
         graph,
-        f"{middleware_w_before_model[-1].name}.before_model",
-        "model",
-        loop_entry_node,
+        name=f"{middleware_w_before_model[-1].name}.before_model",
+        default_destination="model",
+        model_destination=loop_entry_node,
+        end_destination=exit_node,
         can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"),
     )
 
@@ -919,9 +1378,10 @@ def create_agent(  # noqa: PLR0915
         m2 = middleware_w_after_model[idx - 1]
         _add_middleware_edge(
             graph,
-            f"{m1.name}.after_model",
-            f"{m2.name}.after_model",
-            loop_entry_node,
+            name=f"{m1.name}.after_model",
+            default_destination=f"{m2.name}.after_model",
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(m1, "after_model"),
         )
     # Note: Connection from after_model to after_agent/END is handled above
@@ -935,18 +1395,20 @@ def create_agent(  # noqa: PLR0915
         m2 = middleware_w_after_agent[idx - 1]
         _add_middleware_edge(
             graph,
-            f"{m1.name}.after_agent",
-            f"{m2.name}.after_agent",
-            loop_entry_node,
+            name=f"{m1.name}.after_agent",
+            default_destination=f"{m2.name}.after_agent",
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(m1, "after_agent"),
         )
 
     # Connect the last after_agent to END
     _add_middleware_edge(
         graph,
-        f"{middleware_w_after_agent[0].name}.after_agent",
-        END,
-        loop_entry_node,
+        name=f"{middleware_w_after_agent[0].name}.after_agent",
+        default_destination=END,
+        model_destination=loop_entry_node,
+        end_destination=exit_node,
         can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
     )
 
@@ -961,11 +1423,16 @@ def create_agent(  # noqa: PLR0915
     )
 
 
-def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
+def _resolve_jump(
+    jump_to: JumpTo | None,
+    *,
+    model_destination: str,
+    end_destination: str,
+) -> str | None:
     if jump_to == "model":
-        return first_node
+        return model_destination
     if jump_to == "end":
-        return "__end__"
+        return end_destination
     if jump_to == "tools":
         return "tools"
     return None
@@ -988,17 +1455,21 @@ def _fetch_last_ai_and_tool_messages(
 
 
 def _make_model_to_tools_edge(
-    first_node: str,
+    *,
+    model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
-    tool_node: ToolNode,
-    exit_node: str,
-) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | list[Send] | None]:
     def model_to_tools(
-        state: dict[str, Any], runtime: Runtime[ContextT]
+        state: dict[str, Any],
     ) -> str | list[Send] | None:
         # 1. if there's an explicit jump_to in the state, use it
         if jump_to := state.get("jump_to"):
-            return _resolve_jump(jump_to, first_node)
+            return _resolve_jump(
+                jump_to,
+                model_destination=model_destination,
+                end_destination=end_destination,
+            )
 
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
         tool_message_ids = [m.tool_call_id for m in tool_messages]
@@ -1006,7 +1477,7 @@ def _make_model_to_tools_edge(
         # 2. if the model hasn't called any tools, exit the loop
         # this is the classic exit condition for an agent loop
         if len(last_ai_message.tool_calls) == 0:
-            return exit_node
+            return end_destination
 
         pending_tool_calls = [
             c
@@ -1016,80 +1487,97 @@ def _make_model_to_tools_edge(
 
         # 3. if there are pending tool calls, jump to the tool node
         if pending_tool_calls:
-            pending_tool_calls = [
-                tool_node.inject_tool_args(call, state, runtime.store)
-                for call in pending_tool_calls
+            return [
+                Send(
+                    "tools",
+                    ToolCallWithContext(
+                        __type="tool_call_with_context",
+                        tool_call=tool_call,
+                        state=state,
+                    ),
+                )
+                for tool_call in pending_tool_calls
             ]
-            return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
 
         # 4. if there is a structured response, exit the loop
         if "structured_response" in state:
-            return exit_node
+            return end_destination
 
         # 5. AIMessage has tool calls, but there are no pending tool calls
-        # which suggests the injection of artificial tool messages. jump to the first node
-        return first_node
+        # which suggests the injection of artificial tool messages. jump to the model node
+        return model_destination
 
     return model_to_tools
 
 
 def _make_model_to_model_edge(
-    first_node: str,
-    exit_node: str,
-) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
+    *,
+    model_destination: str,
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | list[Send] | None]:
     def model_to_model(
         state: dict[str, Any],
-        runtime: Runtime[ContextT],  # noqa: ARG001
     ) -> str | list[Send] | None:
         # 1. Priority: Check for explicit jump_to directive from middleware
         if jump_to := state.get("jump_to"):
-            return _resolve_jump(jump_to, first_node)
+            return _resolve_jump(
+                jump_to,
+                model_destination=model_destination,
+                end_destination=end_destination,
+            )
 
         # 2. Exit condition: A structured response was generated
         if "structured_response" in state:
-            return exit_node
+            return end_destination
 
         # 3. Default: Continue the loop, there may have been an issue
         # with structured output generation, so we need to retry
-        return first_node
+        return model_destination
 
     return model_to_model
 
 
 def _make_tools_to_model_edge(
+    *,
     tool_node: ToolNode,
-    next_node: str,
+    model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
-    exit_node: str,
-) -> Callable[[dict[str, Any], Runtime[ContextT]], str | None]:
-    def tools_to_model(state: dict[str, Any], runtime: Runtime[ContextT]) -> str | None:  # noqa: ARG001
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | None]:
+    def tools_to_model(state: dict[str, Any]) -> str | None:
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
 
         # 1. Exit condition: All executed tools have return_direct=True
-        if all(
-            tool_node.tools_by_name[c["name"]].return_direct
-            for c in last_ai_message.tool_calls
-            if c["name"] in tool_node.tools_by_name
+        # Filter to only client-side tools (provider tools are not in tool_node)
+        client_side_tool_calls = [
+            c for c in last_ai_message.tool_calls if c["name"] in tool_node.tools_by_name
+        ]
+        if client_side_tool_calls and all(
+            tool_node.tools_by_name[c["name"]].return_direct for c in client_side_tool_calls
        ):
-            return exit_node
+            return end_destination
 
         # 2. Exit condition: A structured output tool was executed
         if any(t.name in structured_output_tools for t in tool_messages):
-            return exit_node
+            return end_destination
 
         # 3. Default: Continue the loop
         # Tool execution completed successfully, route back to the model
         # so it can process the tool results and decide the next action.
-        return next_node
+        return model_destination
 
     return tools_to_model
 
 
 def _add_middleware_edge(
-    graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
+    graph: StateGraph[
+        AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
+    ],
+    *,
     name: str,
     default_destination: str,
     model_destination: str,
+    end_destination: str,
     can_jump_to: list[JumpTo] | None,
 ) -> None:
     """Add an edge to the graph for a middleware node.
@@ -1099,23 +1587,31 @@ def _add_middleware_edge(
         name: The name of the middleware node.
         default_destination: The default destination for the edge.
         model_destination: The destination for the edge to the model.
+        end_destination: The destination for the edge to the end.
         can_jump_to: The conditionally jumpable destinations for the edge.
     """
     if can_jump_to:
 
         def jump_edge(state: dict[str, Any]) -> str:
-            return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
+            return (
+                _resolve_jump(
+                    state.get("jump_to"),
+                    model_destination=model_destination,
+                    end_destination=end_destination,
+                )
+                or default_destination
+            )
 
         destinations = [default_destination]
 
         if "end" in can_jump_to:
-            destinations.append(END)
+            destinations.append(end_destination)
         if "tools" in can_jump_to:
             destinations.append("tools")
         if "model" in can_jump_to and name != model_destination:
             destinations.append(model_destination)
 
-        graph.add_conditional_edges(name, jump_edge, destinations)
+        graph.add_conditional_edges(name, RunnableCallable(jump_edge, trace=False), destinations)
 
     else:
         graph.add_edge(name, default_destination)