langchain 1.0.0a2__py3-none-any.whl → 1.0.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,554 @@
1
+ """Middleware agent implementation."""
2
+
3
+ import itertools
4
+ from collections.abc import Callable, Sequence
5
+ from typing import Any, Union
6
+
7
+ from langchain_core.language_models.chat_models import BaseChatModel
8
+ from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
9
+ from langchain_core.runnables import Runnable
10
+ from langchain_core.tools import BaseTool
11
+ from langgraph.constants import END, START
12
+ from langgraph.graph.state import StateGraph
13
+ from langgraph.typing import ContextT
14
+ from typing_extensions import TypedDict, TypeVar
15
+
16
+ from langchain.agents.middleware.types import (
17
+ AgentMiddleware,
18
+ AgentState,
19
+ JumpTo,
20
+ ModelRequest,
21
+ PublicAgentState,
22
+ )
23
+
24
+ # Import structured output classes from the old implementation
25
+ from langchain.agents.structured_output import (
26
+ MultipleStructuredOutputsError,
27
+ OutputToolBinding,
28
+ ProviderStrategy,
29
+ ProviderStrategyBinding,
30
+ ResponseFormat,
31
+ StructuredOutputValidationError,
32
+ ToolStrategy,
33
+ )
34
+ from langchain.agents.tool_node import ToolNode
35
+ from langchain.chat_models import init_chat_model
36
+
37
+ STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
38
+
39
+
40
+ def _merge_state_schemas(schemas: list[type]) -> type:
41
+ """Merge multiple TypedDict schemas into a single schema with all fields."""
42
+ if not schemas:
43
+ return AgentState
44
+
45
+ all_annotations = {}
46
+
47
+ for schema in schemas:
48
+ all_annotations.update(schema.__annotations__)
49
+
50
+ return TypedDict("MergedState", all_annotations) # type: ignore[operator]
51
+
52
+
53
+ def _filter_state_for_schema(state: dict[str, Any], schema: type) -> dict[str, Any]:
54
+ """Filter state to only include fields defined in the given schema."""
55
+ if not hasattr(schema, "__annotations__"):
56
+ return state
57
+
58
+ schema_fields = set(schema.__annotations__.keys())
59
+ return {k: v for k, v in state.items() if k in schema_fields}
60
+
61
+
62
def _supports_native_structured_output(model: Union[str, BaseChatModel]) -> bool:
    """Return whether the model is known to support provider-native structured output."""
    if isinstance(model, str):
        model_name: str | None = model
    elif isinstance(model, BaseChatModel):
        model_name = getattr(model, "model_name", None)
    else:
        model_name = None

    # No identifiable model name means we cannot claim native support.
    if not model_name:
        return False

    if "grok" in model_name.lower():
        return True

    # Known OpenAI families; matched case-sensitively, mirroring provider ids.
    markers = ("gpt-5", "gpt-4.1", "gpt-oss", "o3-pro", "o3-mini")
    return any(marker in model_name for marker in markers)
76
+
77
+
78
def _handle_structured_output_error(
    exception: Exception,
    response_format: ResponseFormat,
) -> tuple[bool, str]:
    """Map a structured-output failure onto a retry decision.

    Returns ``(should_retry, retry_tool_message)``. Only ``ToolStrategy``
    carries a ``handle_errors`` policy; every other strategy never retries.
    """
    if not isinstance(response_format, ToolStrategy):
        return False, ""

    policy = response_format.handle_errors

    if policy is True:
        # Retry on any error with the stock message.
        return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
    if isinstance(policy, str):
        # Retry on any error with the user-provided message.
        return True, policy
    if isinstance(policy, type) and issubclass(policy, Exception):
        # Retry only when the failure matches the configured exception class.
        if isinstance(exception, policy):
            return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
        return False, ""
    if isinstance(policy, tuple):
        # Retry only when the failure matches one of the configured classes.
        if any(isinstance(exception, candidate) for candidate in policy):
            return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
        return False, ""
    if callable(policy):
        # Custom handler produces the retry message.
        # type narrowing not working appropriately w/ callable check, can fix later
        return True, policy(exception)  # type: ignore[return-value,call-arg]
    # Covers ``policy is False`` and any unrecognized value: no retry.
    return False, ""
106
+
107
+
108
+ ResponseT = TypeVar("ResponseT")
109
+
110
+
111
def create_agent(  # noqa: PLR0915
    *,
    model: str | BaseChatModel,
    tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
    system_prompt: str | None = None,
    middleware: Sequence[AgentMiddleware] = (),
    response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
    context_schema: type[ContextT] | None = None,
) -> StateGraph[
    AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
]:
    """Create a middleware agent graph.

    Builds a ``StateGraph`` with a model-request node, an optional ``tools``
    node, and one node per middleware ``before_model`` / ``after_model`` hook.
    ``modify_model_request`` hooks are not graph nodes: they run inline inside
    the model-request handler, in middleware order.

    Args:
        model: Chat model instance, or an identifier resolved via
            ``init_chat_model``.
        tools: Tool list (may include provider builtin tool dicts), a
            prebuilt ``ToolNode``, or ``None``.
        system_prompt: Optional system prompt prepended to the messages.
        middleware: Middleware instances; at most one instance per class.
        response_format: A structured-output strategy, or a raw schema from
            which a strategy is auto-selected based on model capabilities.
        context_schema: Optional runtime context schema.

    Returns:
        An uncompiled ``StateGraph``; call ``.compile()`` to obtain a runnable.
    """
    # init chat model
    if isinstance(model, str):
        model = init_chat_model(model)

    # Handle tools being None or empty
    if tools is None:
        tools = []

    # Setup structured output
    structured_output_tools: dict[str, OutputToolBinding] = {}
    native_output_binding: ProviderStrategyBinding | None = None

    if response_format is not None:
        if not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
            # Auto-detect strategy based on model capabilities
            if _supports_native_structured_output(model):
                response_format = ProviderStrategy(schema=response_format)
            else:
                response_format = ToolStrategy(schema=response_format)

        if isinstance(response_format, ToolStrategy):
            # Setup tools strategy for structured output
            for response_schema in response_format.schema_specs:
                structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
                structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
        elif isinstance(response_format, ProviderStrategy):
            # Setup native strategy
            native_output_binding = ProviderStrategyBinding.from_schema_spec(
                response_format.schema_spec
            )
    # Tools contributed by middleware are merged into both the executable tool
    # set and the set the model is bound to.
    middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]

    # Setup tools
    tool_node: ToolNode | None = None
    if isinstance(tools, list):
        # Extract builtin provider tools (dict format)
        builtin_tools = [t for t in tools if isinstance(t, dict)]
        regular_tools = [t for t in tools if not isinstance(t, dict)]

        # Add structured output tools to regular tools
        structured_tools = [info.tool for info in structured_output_tools.values()]
        all_tools = middleware_tools + regular_tools + structured_tools

        # Only create ToolNode if we have tools
        tool_node = ToolNode(tools=all_tools) if all_tools else None
        default_tools = regular_tools + builtin_tools + structured_tools + middleware_tools
    elif isinstance(tools, ToolNode):
        # tools is ToolNode or None
        tool_node = tools
        if tool_node:
            default_tools = list(tool_node.tools_by_name.values()) + middleware_tools
            # Update tool node to know about tools provided by middleware
            all_tools = list(tool_node.tools_by_name.values()) + middleware_tools
            tool_node = ToolNode(all_tools)
            # Add structured output tools
            for info in structured_output_tools.values():
                default_tools.append(info.tool)
        else:
            # NOTE(review): unlike the sibling branches this appends the
            # OutputToolBinding objects themselves rather than ``info.tool`` —
            # looks inconsistent; confirm whether ``.tool`` was intended.
            default_tools = (
                list(structured_output_tools.values()) if structured_output_tools else []
            ) + middleware_tools

    # validate middleware
    # Uniqueness is enforced by class name, since node names are derived from it.
    assert len({m.__class__.__name__ for m in middleware}) == len(middleware), (  # noqa: S101
        "Please remove duplicate middleware instances."
    )
    # Only middleware that actually override a hook get a node / inline call.
    middleware_w_before = [
        m for m in middleware if m.__class__.before_model is not AgentMiddleware.before_model
    ]
    middleware_w_modify_model_request = [
        m
        for m in middleware
        if m.__class__.modify_model_request is not AgentMiddleware.modify_model_request
    ]
    middleware_w_after = [
        m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
    ]

    # Collect all middleware state schemas and create merged schema
    merged_state_schema: type[AgentState] = _merge_state_schemas(
        [m.state_schema for m in middleware]
    )

    # create graph, add nodes
    graph = StateGraph(
        merged_state_schema,
        input_schema=PublicAgentState,
        output_schema=PublicAgentState,
        context_schema=context_schema,
    )

    def _prepare_model_request(state: dict[str, Any]) -> tuple[ModelRequest, list[AnyMessage]]:
        """Prepare model request and messages.

        A ``model_request`` already present in state (presumably left by an
        earlier step — confirm against callers) takes precedence over the
        defaults captured from ``create_agent``'s arguments.
        """
        request = state.get("model_request") or ModelRequest(
            model=model,
            tools=default_tools,
            system_prompt=system_prompt,
            response_format=response_format,
            messages=state["messages"],
            tool_choice=None,
        )

        # prepare messages
        messages = request.messages
        if request.system_prompt:
            messages = [SystemMessage(request.system_prompt), *messages]

        return request, messages

    def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
        """Handle model output including structured responses."""
        # Handle structured output with native strategy
        if isinstance(response_format, ProviderStrategy):
            if not output.tool_calls and native_output_binding:
                structured_response = native_output_binding.parse(output)
                return {"messages": [output], "response": structured_response}
            # Clear a stale structured response from a previous turn.
            if state.get("response") is not None:
                return {"messages": [output], "response": None}
            return {"messages": [output]}

        # Handle structured output with tools strategy
        if (
            isinstance(response_format, ToolStrategy)
            and isinstance(output, AIMessage)
            and output.tool_calls
        ):
            structured_tool_calls = [
                tc for tc in output.tool_calls if tc["name"] in structured_output_tools
            ]

            if structured_tool_calls:
                exception: Exception | None = None
                if len(structured_tool_calls) > 1:
                    # Handle multiple structured outputs error
                    tool_names = [tc["name"] for tc in structured_tool_calls]
                    exception = MultipleStructuredOutputsError(tool_names)
                    should_retry, error_message = _handle_structured_output_error(
                        exception, response_format
                    )
                    if not should_retry:
                        raise exception

                    # Add error messages and retry
                    tool_messages = [
                        ToolMessage(
                            content=error_message,
                            tool_call_id=tc["id"],
                            name=tc["name"],
                        )
                        for tc in structured_tool_calls
                    ]
                    return {"messages": [output, *tool_messages]}

                # Handle single structured output
                tool_call = structured_tool_calls[0]
                try:
                    structured_tool_binding = structured_output_tools[tool_call["name"]]
                    structured_response = structured_tool_binding.parse(tool_call["args"])

                    tool_message_content = (
                        response_format.tool_message_content
                        if response_format.tool_message_content
                        else f"Returning structured response: {structured_response}"
                    )

                    return {
                        "messages": [
                            output,
                            ToolMessage(
                                content=tool_message_content,
                                tool_call_id=tool_call["id"],
                                name=tool_call["name"],
                            ),
                        ],
                        "response": structured_response,
                    }
                except Exception as exc:  # noqa: BLE001
                    exception = StructuredOutputValidationError(tool_call["name"], exc)
                    should_retry, error_message = _handle_structured_output_error(
                        exception, response_format
                    )
                    if not should_retry:
                        raise exception

                    return {
                        "messages": [
                            output,
                            ToolMessage(
                                content=error_message,
                                tool_call_id=tool_call["id"],
                                name=tool_call["name"],
                            ),
                        ],
                    }

        # Standard response handling
        if state.get("response") is not None:
            return {"messages": [output], "response": None}
        return {"messages": [output]}

    def _get_bound_model(request: ModelRequest) -> Runnable:
        """Get the model with appropriate tool bindings."""
        if isinstance(response_format, ProviderStrategy):
            # Use native structured output
            kwargs = response_format.to_model_kwargs()
            return request.model.bind_tools(
                request.tools, strict=True, **kwargs, **request.model_settings
            )
        if isinstance(response_format, ToolStrategy):
            # Force a tool call so the structured-output tool can be chosen.
            tool_choice = "any" if structured_output_tools else request.tool_choice
            return request.model.bind_tools(
                request.tools, tool_choice=tool_choice, **request.model_settings
            )
        # Standard model binding
        if request.tools:
            return request.model.bind_tools(
                request.tools, tool_choice=request.tool_choice, **request.model_settings
            )
        return request.model.bind(**request.model_settings)

    def model_request(state: dict[str, Any]) -> dict[str, Any]:
        """Sync model request handler with sequential middleware processing."""
        # Start with the base model request
        request, messages = _prepare_model_request(state)

        # Apply modify_model_request middleware in sequence
        for m in middleware_w_modify_model_request:
            # Filter state to only include fields defined in this middleware's schema
            filtered_state = _filter_state_for_schema(state, m.state_schema)
            request = m.modify_model_request(request, filtered_state)

        # Get the bound model with the final request
        model_ = _get_bound_model(request)
        output = model_.invoke(messages)
        return _handle_model_output(state, output)

    async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
        """Async model request handler with sequential middleware processing."""
        # Start with the base model request
        request, messages = _prepare_model_request(state)

        # Apply modify_model_request middleware in sequence
        for m in middleware_w_modify_model_request:
            # Filter state to only include fields defined in this middleware's schema
            filtered_state = _filter_state_for_schema(state, m.state_schema)
            request = m.modify_model_request(request, filtered_state)

        # Get the bound model with the final request
        model_ = _get_bound_model(request)
        output = await model_.ainvoke(messages)
        return _handle_model_output(state, output)

    # Use sync or async based on model capabilities
    # NOTE(review): local import of a private langgraph helper — presumably to
    # avoid an import cycle or keep the dependency soft; confirm.
    from langgraph._internal._runnable import RunnableCallable

    graph.add_node("model_request", RunnableCallable(model_request, amodel_request))

    # Only add tools node if we have tools
    if tool_node is not None:
        graph.add_node("tools", tool_node)

    # Add middleware nodes
    for m in middleware:
        if m.__class__.before_model is not AgentMiddleware.before_model:
            graph.add_node(
                f"{m.__class__.__name__}.before_model",
                m.before_model,
                input_schema=m.state_schema,
            )

        if m.__class__.after_model is not AgentMiddleware.after_model:
            graph.add_node(
                f"{m.__class__.__name__}.after_model",
                m.after_model,
                input_schema=m.state_schema,
            )

    # add start edge
    # before_model hooks run in registration order; after_model hooks run in
    # reverse order, so the chain ends at middleware_w_after[0] (== last_node).
    first_node = (
        f"{middleware_w_before[0].__class__.__name__}.before_model"
        if middleware_w_before
        else "model_request"
    )
    last_node = (
        f"{middleware_w_after[0].__class__.__name__}.after_model"
        if middleware_w_after
        else "model_request"
    )
    graph.add_edge(START, first_node)

    # add conditional edges only if tools exist
    if tool_node is not None:
        graph.add_conditional_edges(
            "tools",
            _make_tools_to_model_edge(tool_node, first_node),
            [first_node, END],
        )
        graph.add_conditional_edges(
            last_node,
            _make_model_to_tools_edge(first_node, structured_output_tools),
            [first_node, "tools", END],
        )
    elif last_node == "model_request":
        # If no tools, just go to END from model
        graph.add_edge(last_node, END)
    else:
        # If after_model, then need to check for jump_to
        _add_middleware_edge(
            graph,
            f"{middleware_w_after[0].__class__.__name__}.after_model",
            END,
            first_node,
            tools_available=tool_node is not None,
        )

    # Add middleware edges (same as before)
    if middleware_w_before:
        for m1, m2 in itertools.pairwise(middleware_w_before):
            _add_middleware_edge(
                graph,
                f"{m1.__class__.__name__}.before_model",
                f"{m2.__class__.__name__}.before_model",
                first_node,
                tools_available=tool_node is not None,
            )
        # Go directly to model_request after the last before_model
        _add_middleware_edge(
            graph,
            f"{middleware_w_before[-1].__class__.__name__}.before_model",
            "model_request",
            first_node,
            tools_available=tool_node is not None,
        )

    if middleware_w_after:
        # model_request feeds the LAST-registered after_model first; the chain
        # then walks backwards toward middleware_w_after[0].
        graph.add_edge("model_request", f"{middleware_w_after[-1].__class__.__name__}.after_model")
        for idx in range(len(middleware_w_after) - 1, 0, -1):
            m1 = middleware_w_after[idx]
            m2 = middleware_w_after[idx - 1]
            _add_middleware_edge(
                graph,
                f"{m1.__class__.__name__}.after_model",
                f"{m2.__class__.__name__}.after_model",
                first_node,
                tools_available=tool_node is not None,
            )

    return graph
471
+
472
+
473
def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
    """Translate a middleware ``jump_to`` directive into a concrete node name.

    ``"model"`` is an alias for the first node of the model path; any other
    truthy value is taken to already be a node name; falsy means "no jump".
    """
    if not jump_to:
        return None
    return first_node if jump_to == "model" else jump_to
479
+
480
+
481
def _make_model_to_tools_edge(
    first_node: str, structured_output_tools: dict[str, OutputToolBinding]
) -> Callable[[AgentState], str | None]:
    """Build the router used after the model (or the last after_model node)."""

    def model_to_tools(state: AgentState) -> str | None:
        # An explicit jump requested by middleware wins over normal routing.
        jump_to = state.get("jump_to")
        if jump_to:
            return _resolve_jump(jump_to, first_node)

        last_message = state["messages"][-1]

        # A ToolMessage produced by a structured-output tool means we're done.
        if isinstance(last_message, ToolMessage) and last_message.name in structured_output_tools:
            return END

        # Route to the tools node only for at least one regular tool call;
        # calls to structured-output tools are handled in the model node.
        if isinstance(last_message, AIMessage):
            has_regular_call = any(
                tc["name"] not in structured_output_tools for tc in last_message.tool_calls
            )
            if has_regular_call:
                return "tools"

        return END

    return model_to_tools
506
+
507
+
508
def _make_tools_to_model_edge(
    tool_node: ToolNode, next_node: str
) -> Callable[[AgentState], str | None]:
    """Build the router that decides where execution goes after the tools node."""

    def tools_to_model(state: AgentState) -> str | None:
        ai_messages = [m for m in state["messages"] if isinstance(m, AIMessage)]
        last_ai = ai_messages[-1]
        known_calls = [
            c for c in last_ai.tool_calls if c["name"] in tool_node.tools_by_name
        ]
        # Stop when every known call targets a return_direct tool (vacuously
        # true when none of the calls map to tools this node knows about).
        if all(tool_node.tools_by_name[c["name"]].return_direct for c in known_calls):
            return END

        return next_node

    return tools_to_model
523
+
524
+
525
def _add_middleware_edge(
    graph: StateGraph[AgentState, ContextT, PublicAgentState, PublicAgentState],
    name: str,
    default_destination: str,
    model_destination: str,
    tools_available: bool,  # noqa: FBT001
) -> None:
    """Add a conditional edge to the graph for a middleware node.

    The edge honors an optional ``jump_to`` directive left in state by the
    middleware; otherwise it falls through to ``default_destination``.

    Args:
        graph: The graph to add the edge to.
        name: The name of the middleware node the edge starts from.
        default_destination: Destination used when no jump is requested.
        model_destination: Node that a ``jump_to == "model"`` directive
            resolves to (the first node of the model path).
        tools_available: Whether a "tools" node exists for the edge to
            potentially route to.
    """

    def jump_edge(state: AgentState) -> str:
        # A falsy/absent jump_to resolves to None, so fall back to the default.
        return _resolve_jump(state.get("jump_to"), model_destination) or default_destination

    # Enumerate every node this conditional edge may legally route to.
    destinations = [default_destination]
    if default_destination != END:
        destinations.append(END)
    if tools_available:
        destinations.append("tools")
    if name != model_destination:
        # Allow jumping back to the start of the model path.
        destinations.append(model_destination)

    graph.add_conditional_edges(name, jump_edge, destinations)
@@ -1,3 +1,5 @@
1
+ """React agent implementation."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  import inspect
@@ -43,6 +45,7 @@ from langgraph.typing import ContextT, StateT
43
45
  from pydantic import BaseModel
44
46
  from typing_extensions import NotRequired, TypedDict, TypeVar
45
47
 
48
+ from langchain.agents.middleware_agent import create_agent as create_middleware_agent
46
49
  from langchain.agents.structured_output import (
47
50
  MultipleStructuredOutputsError,
48
51
  OutputToolBinding,
@@ -64,6 +67,7 @@ if TYPE_CHECKING:
64
67
  from langchain.agents._internal._typing import (
65
68
  SyncOrAsync,
66
69
  )
70
+ from langchain.agents.types import AgentMiddleware
67
71
 
68
72
  StructuredResponseT = TypeVar("StructuredResponseT", default=None)
69
73
 
@@ -898,7 +902,7 @@ def _supports_native_structured_output(
898
902
  )
899
903
 
900
904
 
901
- def create_react_agent( # noqa: D417
905
+ def create_agent( # noqa: D417
902
906
  model: Union[
903
907
  str,
904
908
  BaseChatModel,
@@ -906,6 +910,7 @@ def create_react_agent( # noqa: D417
906
910
  ],
907
911
  tools: Union[Sequence[Union[BaseTool, Callable, dict[str, Any]]], ToolNode],
908
912
  *,
913
+ middleware: Sequence[AgentMiddleware] = (),
909
914
  prompt: Prompt | None = None,
910
915
  response_format: Union[
911
916
  ToolStrategy[StructuredResponseT],
@@ -928,7 +933,7 @@ def create_react_agent( # noqa: D417
928
933
  ) -> CompiledStateGraph[StateT, ContextT]:
929
934
  """Creates an agent graph that calls tools in a loop until a stopping condition is met.
930
935
 
931
- For more details on using `create_react_agent`, visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.
936
+ For more details on using `create_agent`, visit [Agents](https://langchain-ai.github.io/langgraph/agents/overview/) documentation.
932
937
 
933
938
  Args:
934
939
  model: The language model for the agent. Supports static and dynamic
@@ -1096,13 +1101,13 @@ def create_react_agent( # noqa: D417
1096
1101
 
1097
1102
  Example:
1098
1103
  ```python
1099
- from langchain.agents import create_react_agent
1104
+ from langchain.agents import create_agent
1100
1105
 
1101
1106
  def check_weather(location: str) -> str:
1102
1107
  '''Return the weather forecast for the specified location.'''
1103
1108
  return f"It's always sunny in {location}"
1104
1109
 
1105
- graph = create_react_agent(
1110
+ graph = create_agent(
1106
1111
  "anthropic:claude-3-7-sonnet-latest",
1107
1112
  tools=[check_weather],
1108
1113
  prompt="You are a helpful assistant",
@@ -1112,6 +1117,29 @@ def create_react_agent( # noqa: D417
1112
1117
  print(chunk)
1113
1118
  ```
1114
1119
  """
1120
+ if middleware:
1121
+ assert isinstance(model, str | BaseChatModel) # noqa: S101
1122
+ assert isinstance(prompt, str | None) # noqa: S101
1123
+ assert not isinstance(response_format, tuple) # noqa: S101
1124
+ assert pre_model_hook is None # noqa: S101
1125
+ assert post_model_hook is None # noqa: S101
1126
+ assert state_schema is None # noqa: S101
1127
+ return create_middleware_agent( # type: ignore[return-value]
1128
+ model=model,
1129
+ tools=tools,
1130
+ system_prompt=prompt,
1131
+ middleware=middleware,
1132
+ response_format=response_format,
1133
+ context_schema=context_schema,
1134
+ ).compile(
1135
+ checkpointer=checkpointer,
1136
+ store=store,
1137
+ name=name,
1138
+ interrupt_after=interrupt_after,
1139
+ interrupt_before=interrupt_before,
1140
+ debug=debug,
1141
+ )
1142
+
1115
1143
  # Handle deprecated config_schema parameter
1116
1144
  if (config_schema := deprecated_kwargs.pop("config_schema", MISSING)) is not MISSING:
1117
1145
  warn(
@@ -1123,7 +1151,7 @@ def create_react_agent( # noqa: D417
1123
1151
  context_schema = config_schema
1124
1152
 
1125
1153
  if len(deprecated_kwargs) > 0:
1126
- msg = f"create_react_agent() got unexpected keyword arguments: {deprecated_kwargs}"
1154
+ msg = f"create_agent() got unexpected keyword arguments: {deprecated_kwargs}"
1127
1155
  raise TypeError(msg)
1128
1156
 
1129
1157
  if response_format and not isinstance(response_format, (ToolStrategy, ProviderStrategy)):
@@ -1171,5 +1199,5 @@ __all__ = [
1171
1199
  "AgentStatePydantic",
1172
1200
  "AgentStateWithStructuredResponse",
1173
1201
  "AgentStateWithStructuredResponsePydantic",
1174
- "create_react_agent",
1202
+ "create_agent",
1175
1203
  ]
@@ -1,3 +1,5 @@
1
+ """Chat models."""
2
+
1
3
  from langchain_core.language_models import BaseChatModel
2
4
 
3
5
  from langchain.chat_models.base import init_chat_model
@@ -1,3 +1,5 @@
1
+ """Factory functions for chat models."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  import warnings
@@ -1,3 +1,5 @@
1
+ """Document."""
2
+
1
3
  from langchain_core.documents import Document
2
4
 
3
5
  __all__ = [
@@ -1,3 +1,5 @@
1
+ """Embeddings."""
2
+
1
3
  from langchain_core.embeddings import Embeddings
2
4
 
3
5
  from langchain.embeddings.base import init_embeddings
@@ -1,3 +1,5 @@
1
+ """Factory functions for embeddings."""
2
+
1
3
  import functools
2
4
  from importlib import util
3
5
  from typing import Any, Union