langchain 1.0.0a12__py3-none-any.whl → 1.0.0a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (34)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/factory.py +498 -167
  3. langchain/agents/middleware/__init__.py +9 -3
  4. langchain/agents/middleware/context_editing.py +15 -14
  5. langchain/agents/middleware/human_in_the_loop.py +213 -170
  6. langchain/agents/middleware/model_call_limit.py +2 -2
  7. langchain/agents/middleware/model_fallback.py +46 -36
  8. langchain/agents/middleware/pii.py +19 -19
  9. langchain/agents/middleware/planning.py +16 -11
  10. langchain/agents/middleware/prompt_caching.py +14 -11
  11. langchain/agents/middleware/summarization.py +1 -1
  12. langchain/agents/middleware/tool_call_limit.py +5 -5
  13. langchain/agents/middleware/tool_emulator.py +200 -0
  14. langchain/agents/middleware/tool_selection.py +25 -21
  15. langchain/agents/middleware/types.py +484 -225
  16. langchain/chat_models/base.py +85 -90
  17. langchain/embeddings/base.py +20 -20
  18. langchain/embeddings/cache.py +21 -21
  19. langchain/messages/__init__.py +2 -0
  20. langchain/storage/encoder_backed.py +22 -23
  21. langchain/tools/tool_node.py +388 -80
  22. {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/METADATA +8 -5
  23. langchain-1.0.0a13.dist-info/RECORD +36 -0
  24. langchain/_internal/__init__.py +0 -0
  25. langchain/_internal/_documents.py +0 -35
  26. langchain/_internal/_lazy_import.py +0 -35
  27. langchain/_internal/_prompts.py +0 -158
  28. langchain/_internal/_typing.py +0 -70
  29. langchain/_internal/_utils.py +0 -7
  30. langchain/agents/_internal/__init__.py +0 -1
  31. langchain/agents/_internal/_typing.py +0 -13
  32. langchain-1.0.0a12.dist-info/RECORD +0 -43
  33. {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/WHEEL +0 -0
  34. {langchain-1.0.0a12.dist-info → langchain-1.0.0a13.dist-info}/licenses/LICENSE +0 -0
@@ -13,6 +13,9 @@ from typing import (
     get_type_hints,
 )
 
+if TYPE_CHECKING:
+    from collections.abc import Awaitable
+
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
 from langchain_core.tools import BaseTool
@@ -20,7 +23,7 @@ from langgraph._internal._runnable import RunnableCallable
 from langgraph.constants import END, START
 from langgraph.graph.state import StateGraph
 from langgraph.runtime import Runtime  # noqa: TC002
-from langgraph.types import Send
+from langgraph.types import Command, Send
 from langgraph.typing import ContextT  # noqa: TC002
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
@@ -29,6 +32,7 @@ from langchain.agents.middleware.types import (
     AgentState,
     JumpTo,
     ModelRequest,
+    ModelResponse,
     OmitFromSchema,
     PublicAgentState,
 )
@@ -44,6 +48,7 @@ from langchain.agents.structured_output import (
 )
 from langchain.chat_models import init_chat_model
 from langchain.tools import ToolNode
+from langchain.tools.tool_node import ToolCallWithContext
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Sequence
@@ -54,11 +59,217 @@ if TYPE_CHECKING:
     from langgraph.store.base import BaseStore
     from langgraph.types import Checkpointer
 
+    from langchain.tools.tool_node import ToolCallRequest, ToolCallWrapper
+
 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
 
 ResponseT = TypeVar("ResponseT")
 
 
+def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
+    """Normalize a middleware return value to a ModelResponse."""
+    if isinstance(result, AIMessage):
+        return ModelResponse(result=[result], structured_response=None)
+    return result
+
+
+def _chain_model_call_handlers(
+    handlers: Sequence[
+        Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ]
+    ],
+) -> (
+    Callable[
+        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+        ModelResponse,
+    ]
+    | None
+):
+    """Compose multiple wrap_model_call handlers into a single middleware stack.
+
+    Composes handlers so that the first in the list becomes the outermost layer.
+    Each handler receives a handler callback to execute the inner layers.
+
+    Args:
+        handlers: List of handlers. The first handler wraps all others.
+
+    Returns:
+        Composed handler, or None if handlers is empty.
+
+    Example:
+        ```python
+        # handlers=[auth, retry] means: auth wraps retry
+        # Flow: auth calls retry, retry calls the base handler
+        def auth(req, handler):
+            try:
+                return handler(req)
+            except UnauthorizedError:
+                refresh_token()
+                return handler(req)
+
+
+        def retry(req, handler):
+            for attempt in range(3):
+                try:
+                    return handler(req)
+                except Exception:
+                    if attempt == 2:
+                        raise
+
+
+        handler = _chain_model_call_handlers([auth, retry])
+        ```
+    """
+    if not handlers:
+        return None
+
+    if len(handlers) == 1:
+        # Single handler - wrap to normalize output
+        single_handler = handlers[0]
+
+        def normalized_single(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelResponse:
+            result = single_handler(request, handler)
+            return _normalize_to_model_response(result)
+
+        return normalized_single
+
+    def compose_two(
+        outer: Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ],
+        inner: Callable[
+            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+            ModelResponse | AIMessage,
+        ],
+    ) -> Callable[
+        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
+        ModelResponse,
+    ]:
+        """Compose two handlers where outer wraps inner."""
+
+        def composed(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelResponse:
+            # Create a wrapper that calls inner with the base handler and normalizes
+            def inner_handler(req: ModelRequest) -> ModelResponse:
+                inner_result = inner(req, handler)
+                return _normalize_to_model_response(inner_result)
+
+            # Call outer with the wrapped inner as its handler and normalize
+            outer_result = outer(request, inner_handler)
+            return _normalize_to_model_response(outer_result)
+
+        return composed
+
+    # Compose right-to-left: outer(inner(innermost(handler)))
+    result = handlers[-1]
+    for handler in reversed(handlers[:-1]):
+        result = compose_two(handler, result)
+
+    # Wrap to ensure the final return type is exactly ModelResponse
+    def final_normalized(
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], ModelResponse],
+    ) -> ModelResponse:
+        # result is typed as returning ModelResponse | AIMessage, but compose_two normalizes
+        final_result = result(request, handler)
+        return _normalize_to_model_response(final_result)
+
+    return final_normalized
+
+
+def _chain_async_model_call_handlers(
+    handlers: Sequence[
+        Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ]
+    ],
+) -> (
+    Callable[
+        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+        Awaitable[ModelResponse],
+    ]
+    | None
+):
+    """Compose multiple async wrap_model_call handlers into a single middleware stack.
+
+    Args:
+        handlers: List of async handlers. The first handler wraps all others.
+
+    Returns:
+        Composed async handler, or None if handlers is empty.
+    """
+    if not handlers:
+        return None
+
+    if len(handlers) == 1:
+        # Single handler - wrap to normalize output
+        single_handler = handlers[0]
+
+        async def normalized_single(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
+            result = await single_handler(request, handler)
+            return _normalize_to_model_response(result)
+
+        return normalized_single
+
+    def compose_two(
+        outer: Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ],
+        inner: Callable[
+            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+            Awaitable[ModelResponse | AIMessage],
+        ],
+    ) -> Callable[
+        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
+        Awaitable[ModelResponse],
+    ]:
+        """Compose two async handlers where outer wraps inner."""
+
+        async def composed(
+            request: ModelRequest,
+            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+        ) -> ModelResponse:
+            # Create a wrapper that calls inner with the base handler and normalizes
+            async def inner_handler(req: ModelRequest) -> ModelResponse:
+                inner_result = await inner(req, handler)
+                return _normalize_to_model_response(inner_result)
+
+            # Call outer with the wrapped inner as its handler and normalize
+            outer_result = await outer(request, inner_handler)
+            return _normalize_to_model_response(outer_result)
+
+        return composed
+
+    # Compose right-to-left: outer(inner(innermost(handler)))
+    result = handlers[-1]
+    for handler in reversed(handlers[:-1]):
+        result = compose_two(handler, result)
+
+    # Wrap to ensure the final return type is exactly ModelResponse
+    async def final_normalized(
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelResponse:
+        # result is typed as returning ModelResponse | AIMessage, but compose_two normalizes
+        final_result = await result(request, handler)
+        return _normalize_to_model_response(final_result)
+
+    return final_normalized
+
+
 def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
     """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
 
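To make the composition order concrete, here is a minimal, self-contained sketch (toy code, not part of the package) of the same right-to-left fold, with plain strings standing in for `ModelRequest` and `ModelResponse`:

```python
# Toy model of the fold in _chain_model_call_handlers: the first handler in
# the list becomes the outermost layer around the base handler.
from collections.abc import Callable, Sequence

Handler = Callable[[str, Callable[[str], str]], str]


def chain(handlers: Sequence[Handler]) -> Handler:
    def compose_two(outer: Handler, inner: Handler) -> Handler:
        # outer receives a callback that runs inner, which eventually runs base
        return lambda req, base: outer(req, lambda r: inner(r, base))

    result = handlers[-1]
    for h in reversed(handlers[:-1]):
        result = compose_two(h, result)
    return result


def auth(req: str, handler: Callable[[str], str]) -> str:
    return f"auth({handler(req)})"


def retry(req: str, handler: Callable[[str], str]) -> str:
    return f"retry({handler(req)})"


def base(req: str) -> str:
    return f"model({req})"


print(chain([auth, retry])("req", base))  # auth(retry(model(req)))
```

The fold is why, in the real composition, the first middleware in the list sees the request first and the response last.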
@@ -146,7 +357,7 @@ def _supports_provider_strategy(model: str | BaseChatModel) -> bool:
         model: Model name string or BaseChatModel instance.
 
     Returns:
-        ``True`` if the model supports provider-specific structured output, ``False`` otherwise.
+        `True` if the model supports provider-specific structured output, `False` otherwise.
     """
     model_name: str | None = None
     if isinstance(model, str):
@@ -192,6 +403,52 @@ def _handle_structured_output_error(
     return False, ""
 
 
+def _chain_tool_call_wrappers(
+    wrappers: Sequence[ToolCallWrapper],
+) -> ToolCallWrapper | None:
+    """Compose wrappers into middleware stack (first = outermost).
+
+    Args:
+        wrappers: Wrappers in middleware order.
+
+    Returns:
+        Composed wrapper, or None if empty.
+
+    Example:
+        wrapper = _chain_tool_call_wrappers([auth, cache, retry])
+        # Request flows: auth -> cache -> retry -> tool
+        # Response flows: tool -> retry -> cache -> auth
+    """
+    if not wrappers:
+        return None
+
+    if len(wrappers) == 1:
+        return wrappers[0]
+
+    def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapper:
+        """Compose two wrappers where outer wraps inner."""
+
+        def composed(
+            request: ToolCallRequest,
+            execute: Callable[[ToolCallRequest], ToolMessage | Command],
+        ) -> ToolMessage | Command:
+            # Create a callable that invokes inner with the original execute
+            def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+                return inner(req, execute)
+
+            # Outer can call call_inner multiple times
+            return outer(request, call_inner)
+
+        return composed
+
+    # Chain all wrappers: first -> second -> ... -> last
+    result = wrappers[-1]
+    for wrapper in reversed(wrappers[:-1]):
+        result = compose_two(wrapper, result)
+
+    return result
+
+
 def create_agent(  # noqa: PLR0915
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
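For illustration, a wrapper compatible with the `(request, execute)` shape composed above could look like the sketch below; the name `retry_tool_calls` and its three-attempt budget are assumptions for the example, not library API:

```python
# Hedged sketch of a ToolCallWrapper-shaped callable. `execute` runs the inner
# wrappers and then the tool itself, and may be invoked more than once.
def retry_tool_calls(request, execute):
    last_exc = None
    for _attempt in range(3):  # illustrative retry budget
        try:
            return execute(request)
        except Exception as exc:  # a real wrapper would catch something narrower
            last_exc = exc
    raise last_exc
```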
@@ -212,13 +469,13 @@ def create_agent(  # noqa: PLR0915
 ]:
     """Creates an agent graph that calls tools in a loop until a stopping condition is met.
 
-    For more details on using ``create_agent``,
+    For more details on using `create_agent`,
     visit [Agents](https://docs.langchain.com/oss/python/langchain/agents) documentation.
 
     Args:
         model: The language model for the agent. Can be a string identifier
-            (e.g., ``"openai:gpt-4"``), a chat model instance (e.g., ``ChatOpenAI()``).
-        tools: A list of tools, dicts, or callables. If ``None`` or an empty list,
+            (e.g., `"openai:gpt-4"`), a chat model instance (e.g., `ChatOpenAI()`).
+        tools: A list of tools, dicts, or callables. If `None` or an empty list,
             the agent will consist of a model node without a tool calling loop.
         system_prompt: An optional system prompt for the LLM. If provided as a string,
             it will be converted to a SystemMessage and added to the beginning
@@ -253,10 +510,10 @@ def create_agent(  # noqa: PLR0915
         A compiled StateGraph that can be used for chat interactions.
 
     The agent node calls the language model with the messages list (after applying
-    the system prompt). If the resulting AIMessage contains ``tool_calls``, the graph will
+    the system prompt). If the resulting AIMessage contains `tool_calls`, the graph will
     then call the tools. The tools node executes the tools and adds the responses
-    to the messages list as ``ToolMessage`` objects. The agent node then calls the
-    language model again. The process repeats until no more ``tool_calls`` are
+    to the messages list as `ToolMessage` objects. The agent node then calls the
+    language model again. The process repeats until no more `tool_calls` are
     present in the response. The agent then returns the full list of messages.
 
     Example:
@@ -319,6 +576,17 @@ def create_agent(  # noqa: PLR0915
         structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
     middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
 
+    # Collect middleware with wrap_tool_call hooks
+    middleware_w_wrap_tool_call = [
+        m for m in middleware if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+    ]
+
+    # Chain all wrap_tool_call handlers into a single composed handler
+    wrap_tool_call_wrapper = None
+    if middleware_w_wrap_tool_call:
+        wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
+        wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
+
     # Setup tools
     tool_node: ToolNode | None = None
     # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
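The identity comparison against `AgentMiddleware.wrap_tool_call` means only middleware that actually overrides the hook is collected. A hedged sketch of such an override (`request.tool_call` is an assumed field, inferred from the `ToolCallRequest` type referenced in this diff):

```python
# Sketch: middleware whose overridden wrap_tool_call would be picked up by the
# collection above and composed into the ToolNode wrapper.
from langchain.agents.middleware.types import AgentMiddleware


class LoggedToolCalls(AgentMiddleware):
    def wrap_tool_call(self, request, handler):
        print(f"tool call: {request.tool_call['name']}")  # assumed field
        result = handler(request)  # run inner wrappers, then the tool
        print("tool call finished")
        return result
```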
@@ -329,7 +597,11 @@ def create_agent(  # noqa: PLR0915
     available_tools = middleware_tools + regular_tools
 
     # Only create ToolNode if we have client-side tools
-    tool_node = ToolNode(tools=available_tools) if available_tools else None
+    tool_node = (
+        ToolNode(tools=available_tools, wrap_tool_call=wrap_tool_call_wrapper)
+        if available_tools
+        else None
+    )
 
     # Default tools for ModelRequest initialization
     # Use converted BaseTool instances from ToolNode (not raw callables)
@@ -356,12 +628,6 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.before_model is not AgentMiddleware.before_model
         or m.__class__.abefore_model is not AgentMiddleware.abefore_model
     ]
-    middleware_w_modify_model_request = [
-        m
-        for m in middleware
-        if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
-        or m.__class__.amodify_model_request is not AgentMiddleware.amodify_model_request
-    ]
     middleware_w_after_model = [
         m
         for m in middleware
@@ -374,13 +640,27 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.after_agent is not AgentMiddleware.after_agent
         or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
     ]
-    middleware_w_retry = [
+    middleware_w_wrap_model_call = [
+        m for m in middleware if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
+    ]
+    middleware_w_awrap_model_call = [
         m
         for m in middleware
-        if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request
-        or m.__class__.aretry_model_request is not AgentMiddleware.aretry_model_request
+        if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
     ]
 
+    # Compose wrap_model_call handlers into a single middleware stack (sync)
+    wrap_model_call_handler = None
+    if middleware_w_wrap_model_call:
+        sync_handlers = [m.wrap_model_call for m in middleware_w_wrap_model_call]
+        wrap_model_call_handler = _chain_model_call_handlers(sync_handlers)
+
+    # Compose awrap_model_call handlers into a single middleware stack (async)
+    awrap_model_call_handler = None
+    if middleware_w_awrap_model_call:
+        async_handlers = [m.awrap_model_call for m in middleware_w_awrap_model_call]
+        awrap_model_call_handler = _chain_async_model_call_handlers(async_handlers)
+
     state_schemas = {m.state_schema for m in middleware}
     state_schemas.add(AgentState)
 
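With `retry_model_request` gone, retrying now lives in `wrap_model_call` itself. A hedged sketch of a middleware this block would collect and compose; the attempt budget and broad `except` are illustrative:

```python
# Sketch: retry behavior expressed through the new wrap_model_call hook.
from langchain.agents.middleware.types import AgentMiddleware


class RetryModelCall(AgentMiddleware):
    def wrap_model_call(self, request, handler):
        for attempt in range(3):
            try:
                return handler(request)  # inner handlers, then the base model call
            except Exception:  # a real middleware would catch something narrower
                if attempt == 2:
                    raise
```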
@@ -504,7 +784,7 @@ def create_agent(  # noqa: PLR0915
             request: The model request containing model, tools, and response format.
 
         Returns:
-            Tuple of (bound_model, effective_response_format) where ``effective_response_format``
+            Tuple of (bound_model, effective_response_format) where `effective_response_format`
             is the actual strategy used (may differ from initial if auto-detected).
         """
         # Validate ONLY client-side tools that need to exist in tool_node
@@ -608,6 +888,30 @@ def create_agent(  # noqa: PLR0915
         )
         return request.model.bind(**request.model_settings), None
 
+    def _execute_model_sync(request: ModelRequest) -> ModelResponse:
+        """Execute model and return response.
+
+        This is the core model execution logic wrapped by wrap_model_call handlers.
+        Raises any exceptions that occur during model invocation.
+        """
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
+        output = model_.invoke(messages)
+
+        # Handle model output to get messages and structured_response
+        handled_output = _handle_model_output(output, effective_response_format)
+        messages_list = handled_output["messages"]
+        structured_response = handled_output.get("structured_response")
+
+        return ModelResponse(
+            result=messages_list,
+            structured_response=structured_response,
+        )
+
     def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
@@ -617,62 +921,51 @@ def create_agent(  # noqa: PLR0915
             response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
+            state=state,
+            runtime=runtime,
         )
 
-        # Apply modify_model_request middleware in sequence
-        for m in middleware_w_modify_model_request:
-            if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request:
-                m.modify_model_request(request, state, runtime)
-            else:
-                msg = (
-                    f"No synchronous function provided for "
-                    f'{m.__class__.__name__}.amodify_model_request".'
-                    "\nEither initialize with a synchronous function or invoke"
-                    " via the async API (ainvoke, astream, etc.)"
-                )
-                raise TypeError(msg)
+        if wrap_model_call_handler is None:
+            # No handlers - execute directly
+            response = _execute_model_sync(request)
+        else:
+            # Call composed handler with base handler
+            response = wrap_model_call_handler(request, _execute_model_sync)
 
-        # Retry loop for model invocation with error handling
-        # Hard limit of 100 attempts to prevent infinite loops from buggy middleware
-        max_attempts = 100
-        for attempt in range(1, max_attempts + 1):
-            try:
-                # Get the bound model (with auto-detection if needed)
-                model_, effective_response_format = _get_bound_model(request)
-                messages = request.messages
-                if request.system_prompt:
-                    messages = [SystemMessage(request.system_prompt), *messages]
-
-                output = model_.invoke(messages)
-                return {
-                    "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-                    "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-                    **_handle_model_output(output, effective_response_format),
-                }
-            except Exception as error:
-                # Try retry_model_request on each middleware
-                for m in middleware_w_retry:
-                    if m.__class__.retry_model_request is not AgentMiddleware.retry_model_request:
-                        if retry_request := m.retry_model_request(
-                            error, request, state, runtime, attempt
-                        ):
-                            # Break on first middleware that wants to retry
-                            request = retry_request
-                            break
-                    else:
-                        msg = (
-                            f"No synchronous function provided for "
-                            f'{m.__class__.__name__}.aretry_model_request".'
-                            "\nEither initialize with a synchronous function or invoke"
-                            " via the async API (ainvoke, astream, etc.)"
-                        )
-                        raise TypeError(msg)
-                else:
-                    raise
+        # Extract state updates from ModelResponse
+        state_updates = {"messages": response.result}
+        if response.structured_response is not None:
+            state_updates["structured_response"] = response.structured_response
+
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+            **state_updates,
+        }
 
-        # If we exit the loop, max attempts exceeded
-        msg = f"Maximum retry attempts ({max_attempts}) exceeded"
-        raise RuntimeError(msg)
+    async def _execute_model_async(request: ModelRequest) -> ModelResponse:
+        """Execute model asynchronously and return response.
+
+        This is the core async model execution logic wrapped by wrap_model_call handlers.
+        Raises any exceptions that occur during model invocation.
+        """
+        # Get the bound model (with auto-detection if needed)
+        model_, effective_response_format = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
+        output = await model_.ainvoke(messages)
+
+        # Handle model output to get messages and structured_response
+        handled_output = _handle_model_output(output, effective_response_format)
+        messages_list = handled_output["messages"]
+        structured_response = handled_output.get("structured_response")
+
+        return ModelResponse(
+            result=messages_list,
+            structured_response=structured_response,
+        )
 
     async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
@@ -683,45 +976,27 @@ def create_agent(  # noqa: PLR0915
             response_format=initial_response_format,
             messages=state["messages"],
             tool_choice=None,
+            state=state,
+            runtime=runtime,
         )
 
-        # Apply modify_model_request middleware in sequence
-        for m in middleware_w_modify_model_request:
-            await m.amodify_model_request(request, state, runtime)
+        if awrap_model_call_handler is None:
+            # No async handlers - execute directly
+            response = await _execute_model_async(request)
+        else:
+            # Call composed async handler with base handler
+            response = await awrap_model_call_handler(request, _execute_model_async)
 
-        # Retry loop for model invocation with error handling
-        # Hard limit of 100 attempts to prevent infinite loops from buggy middleware
-        max_attempts = 100
-        for attempt in range(1, max_attempts + 1):
-            try:
-                # Get the bound model (with auto-detection if needed)
-                model_, effective_response_format = _get_bound_model(request)
-                messages = request.messages
-                if request.system_prompt:
-                    messages = [SystemMessage(request.system_prompt), *messages]
-
-                output = await model_.ainvoke(messages)
-                return {
-                    "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-                    "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-                    **_handle_model_output(output, effective_response_format),
-                }
-            except Exception as error:
-                # Try retry_model_request on each middleware
-                for m in middleware_w_retry:
-                    if retry_request := await m.aretry_model_request(
-                        error, request, state, runtime, attempt
-                    ):
-                        # Break on first middleware that wants to retry
-                        request = retry_request
-                        break
-                else:
-                    # If no middleware wants to retry, re-raise the error
-                    raise
+        # Extract state updates from ModelResponse
+        state_updates = {"messages": response.result}
+        if response.structured_response is not None:
+            state_updates["structured_response"] = response.structured_response
 
-        # If we exit the loop, max attempts exceeded
-        msg = f"Maximum retry attempts ({max_attempts}) exceeded"
-        raise RuntimeError(msg)
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+            **state_updates,
+        }
 
     # Use sync or async based on model capabilities
     graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
@@ -842,22 +1117,40 @@ def create_agent(  # noqa: PLR0915
         graph.add_conditional_edges(
             "tools",
             _make_tools_to_model_edge(
-                tool_node, loop_entry_node, structured_output_tools, exit_node
+                tool_node=tool_node,
+                model_destination=loop_entry_node,
+                structured_output_tools=structured_output_tools,
+                end_destination=exit_node,
             ),
             [loop_entry_node, exit_node],
        )
 
+        # Base destinations are tools and exit_node.
+        # We add the loop_entry node to the edge destinations if:
+        # - there are after_model hooks -- allows jump_to back to the model after
+        #   potentially artificially injected tool messages, e.g. HITL
+        # - there is a response format -- to allow jumping back to the model to
+        #   regenerate structured output tool calls
+        model_to_tools_destinations = ["tools", exit_node]
+        if response_format or loop_exit_node != "model":
+            model_to_tools_destinations.append(loop_entry_node)
+
         graph.add_conditional_edges(
             loop_exit_node,
             _make_model_to_tools_edge(
-                loop_entry_node, structured_output_tools, tool_node, exit_node
+                model_destination=loop_entry_node,
+                structured_output_tools=structured_output_tools,
+                end_destination=exit_node,
             ),
-            [loop_entry_node, "tools", exit_node],
+            model_to_tools_destinations,
         )
     elif len(structured_output_tools) > 0:
         graph.add_conditional_edges(
             loop_exit_node,
-            _make_model_to_model_edge(loop_entry_node, exit_node),
+            _make_model_to_model_edge(
+                model_destination=loop_entry_node,
+                end_destination=exit_node,
+            ),
             [loop_entry_node, exit_node],
         )
     elif loop_exit_node == "model":
@@ -867,9 +1160,10 @@ def create_agent(  # noqa: PLR0915
     else:
         _add_middleware_edge(
             graph,
-            f"{middleware_w_after_model[0].name}.after_model",
-            exit_node,
-            loop_entry_node,
+            name=f"{middleware_w_after_model[0].name}.after_model",
+            default_destination=exit_node,
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"),
         )
 
@@ -878,17 +1172,19 @@ def create_agent(  # noqa: PLR0915
         for m1, m2 in itertools.pairwise(middleware_w_before_agent):
             _add_middleware_edge(
                 graph,
-                f"{m1.name}.before_agent",
-                f"{m2.name}.before_agent",
-                loop_entry_node,
+                name=f"{m1.name}.before_agent",
+                default_destination=f"{m2.name}.before_agent",
+                model_destination=loop_entry_node,
+                end_destination=exit_node,
                 can_jump_to=_get_can_jump_to(m1, "before_agent"),
             )
         # Connect last before_agent to loop_entry_node (before_model or model)
         _add_middleware_edge(
             graph,
-            f"{middleware_w_before_agent[-1].name}.before_agent",
-            loop_entry_node,
-            loop_entry_node,
+            name=f"{middleware_w_before_agent[-1].name}.before_agent",
+            default_destination=loop_entry_node,
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
         )
 
@@ -897,17 +1193,19 @@ def create_agent(  # noqa: PLR0915
         for m1, m2 in itertools.pairwise(middleware_w_before_model):
             _add_middleware_edge(
                 graph,
-                f"{m1.name}.before_model",
-                f"{m2.name}.before_model",
-                loop_entry_node,
+                name=f"{m1.name}.before_model",
+                default_destination=f"{m2.name}.before_model",
+                model_destination=loop_entry_node,
+                end_destination=exit_node,
                 can_jump_to=_get_can_jump_to(m1, "before_model"),
             )
         # Go directly to model after the last before_model
         _add_middleware_edge(
             graph,
-            f"{middleware_w_before_model[-1].name}.before_model",
-            "model",
-            loop_entry_node,
+            name=f"{middleware_w_before_model[-1].name}.before_model",
+            default_destination="model",
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"),
         )
 
@@ -919,9 +1217,10 @@ def create_agent(  # noqa: PLR0915
             m2 = middleware_w_after_model[idx - 1]
             _add_middleware_edge(
                 graph,
-                f"{m1.name}.after_model",
-                f"{m2.name}.after_model",
-                loop_entry_node,
+                name=f"{m1.name}.after_model",
+                default_destination=f"{m2.name}.after_model",
+                model_destination=loop_entry_node,
+                end_destination=exit_node,
                 can_jump_to=_get_can_jump_to(m1, "after_model"),
             )
         # Note: Connection from after_model to after_agent/END is handled above
@@ -935,18 +1234,20 @@ def create_agent(  # noqa: PLR0915
             m2 = middleware_w_after_agent[idx - 1]
             _add_middleware_edge(
                 graph,
-                f"{m1.name}.after_agent",
-                f"{m2.name}.after_agent",
-                loop_entry_node,
+                name=f"{m1.name}.after_agent",
+                default_destination=f"{m2.name}.after_agent",
+                model_destination=loop_entry_node,
+                end_destination=exit_node,
                 can_jump_to=_get_can_jump_to(m1, "after_agent"),
             )
 
         # Connect the last after_agent to END
         _add_middleware_edge(
             graph,
-            f"{middleware_w_after_agent[0].name}.after_agent",
-            END,
-            loop_entry_node,
+            name=f"{middleware_w_after_agent[0].name}.after_agent",
+            default_destination=END,
+            model_destination=loop_entry_node,
+            end_destination=exit_node,
             can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
         )
 
@@ -961,11 +1262,16 @@ def create_agent(  # noqa: PLR0915
     )
 
 
-def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
+def _resolve_jump(
+    jump_to: JumpTo | None,
+    *,
+    model_destination: str,
+    end_destination: str,
+) -> str | None:
     if jump_to == "model":
-        return first_node
+        return model_destination
     if jump_to == "end":
-        return "__end__"
+        return end_destination
     if jump_to == "tools":
         return "tools"
     return None
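A toy restatement of the mapping, showing the behavioral point of this change: `"end"` now resolves to a configurable `end_destination` (for example, the first after_agent node) instead of the hard-coded `"__end__"`:

```python
# Toy equivalent of _resolve_jump's mapping; destination names are examples.
def resolve(jump_to, *, model_destination, end_destination):
    return {"model": model_destination, "end": end_destination, "tools": "tools"}.get(jump_to)


assert resolve("end", model_destination="model", end_destination="after_agent") == "after_agent"
assert resolve(None, model_destination="model", end_destination="__end__") is None
```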
@@ -988,17 +1294,21 @@ def _fetch_last_ai_and_tool_messages(
 
 
 def _make_model_to_tools_edge(
-    first_node: str,
+    *,
+    model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
-    tool_node: ToolNode,
-    exit_node: str,
-) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | list[Send] | None]:
     def model_to_tools(
-        state: dict[str, Any], runtime: Runtime[ContextT]
+        state: dict[str, Any],
     ) -> str | list[Send] | None:
         # 1. if there's an explicit jump_to in the state, use it
         if jump_to := state.get("jump_to"):
-            return _resolve_jump(jump_to, first_node)
+            return _resolve_jump(
+                jump_to,
+                model_destination=model_destination,
+                end_destination=end_destination,
+            )
 
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
         tool_message_ids = [m.tool_call_id for m in tool_messages]
@@ -1006,7 +1316,7 @@ def _make_model_to_tools_edge(
         # 2. if the model hasn't called any tools, exit the loop
         # this is the classic exit condition for an agent loop
         if len(last_ai_message.tool_calls) == 0:
-            return exit_node
+            return end_destination
 
         pending_tool_calls = [
             c
@@ -1016,53 +1326,64 @@ def _make_model_to_tools_edge(
         ]
 
         # 3. if there are pending tool calls, jump to the tool node
         if pending_tool_calls:
-            pending_tool_calls = [
-                tool_node.inject_tool_args(call, state, runtime.store)
-                for call in pending_tool_calls
+            return [
+                Send(
+                    "tools",
+                    ToolCallWithContext(
+                        __type="tool_call_with_context",
+                        tool_call=tool_call,
+                        state=state,
+                    ),
+                )
+                for tool_call in pending_tool_calls
             ]
-            return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
 
         # 4. if there is a structured response, exit the loop
         if "structured_response" in state:
-            return exit_node
+            return end_destination
 
         # 5. AIMessage has tool calls, but there are no pending tool calls
-        # which suggests the injection of artificial tool messages. jump to the first node
-        return first_node
+        # which suggests the injection of artificial tool messages. jump to the model node
+        return model_destination
 
     return model_to_tools
 
 
 def _make_model_to_model_edge(
-    first_node: str,
-    exit_node: str,
-) -> Callable[[dict[str, Any], Runtime[ContextT]], str | list[Send] | None]:
+    *,
+    model_destination: str,
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | list[Send] | None]:
     def model_to_model(
         state: dict[str, Any],
-        runtime: Runtime[ContextT],  # noqa: ARG001
     ) -> str | list[Send] | None:
         # 1. Priority: Check for explicit jump_to directive from middleware
         if jump_to := state.get("jump_to"):
-            return _resolve_jump(jump_to, first_node)
+            return _resolve_jump(
+                jump_to,
+                model_destination=model_destination,
+                end_destination=end_destination,
+            )
 
         # 2. Exit condition: A structured response was generated
         if "structured_response" in state:
-            return exit_node
+            return end_destination
 
         # 3. Default: Continue the loop, there may have been an issue
         # with structured output generation, so we need to retry
-        return first_node
+        return model_destination
 
     return model_to_model
 
 
 def _make_tools_to_model_edge(
+    *,
     tool_node: ToolNode,
-    next_node: str,
+    model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
-    exit_node: str,
-) -> Callable[[dict[str, Any], Runtime[ContextT]], str | None]:
-    def tools_to_model(state: dict[str, Any], runtime: Runtime[ContextT]) -> str | None:  # noqa: ARG001
+    end_destination: str,
+) -> Callable[[dict[str, Any]], str | None]:
+    def tools_to_model(state: dict[str, Any]) -> str | None:
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
 
         # 1. Exit condition: All executed tools have return_direct=True
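Each `Send` now carries the tool call together with the full agent state, replacing the bare single-element list that a12 sent to the tools node. An illustrative payload with toy values (`ToolCallWithContext` is shown here as a plain dict, an assumption based on the keyword fields visible in the diff):

```python
# Illustrative Send payload assembled from the fields visible above.
from langgraph.types import Send

send = Send(
    "tools",
    {
        "__type": "tool_call_with_context",
        "tool_call": {"name": "get_weather", "args": {"city": "SF"}, "id": "call_1"},
        "state": {"messages": []},  # the full agent state in practice
    },
)
```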
@@ -1071,25 +1392,27 @@ def _make_tools_to_model_edge(
             for c in last_ai_message.tool_calls
             if c["name"] in tool_node.tools_by_name
         ):
-            return exit_node
+            return end_destination
 
         # 2. Exit condition: A structured output tool was executed
         if any(t.name in structured_output_tools for t in tool_messages):
-            return exit_node
+            return end_destination
 
         # 3. Default: Continue the loop
         # Tool execution completed successfully, route back to the model
         # so it can process the tool results and decide the next action.
-        return next_node
+        return model_destination
 
     return tools_to_model
 
 
 def _add_middleware_edge(
     graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
+    *,
     name: str,
     default_destination: str,
     model_destination: str,
+    end_destination: str,
     can_jump_to: list[JumpTo] | None,
 ) -> None:
     """Add an edge to the graph for a middleware node.
@@ -1099,17 +1422,25 @@ def _add_middleware_edge(
         name: The name of the middleware node.
         default_destination: The default destination for the edge.
         model_destination: The destination for the edge to the model.
+        end_destination: The destination for the edge to the end.
         can_jump_to: The conditionally jumpable destinations for the edge.
     """
     if can_jump_to:
 
         def jump_edge(state: dict[str, Any]) -> str:
-            return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
+            return (
+                _resolve_jump(
+                    state.get("jump_to"),
+                    model_destination=model_destination,
+                    end_destination=end_destination,
+                )
+                or default_destination
+            )
 
         destinations = [default_destination]
 
         if "end" in can_jump_to:
-            destinations.append(END)
+            destinations.append(end_destination)
         if "tools" in can_jump_to:
             destinations.append("tools")
         if "model" in can_jump_to and name != model_destination: