langchain 1.0.0a7__py3-none-any.whl → 1.0.0a8__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

langchain/agents/middleware/__init__.py

@@ -1,6 +1,5 @@
 """Middleware plugins for agents."""
 
-from .dynamic_system_prompt import DynamicSystemPromptMiddleware
 from .human_in_the_loop import HumanInTheLoopMiddleware
 from .prompt_caching import AnthropicPromptCachingMiddleware
 from .summarization import SummarizationMiddleware
@@ -11,7 +10,6 @@ __all__ = [
     "AgentState",
     # should move to langchain-anthropic if we decide to keep it
     "AnthropicPromptCachingMiddleware",
-    "DynamicSystemPromptMiddleware",
    "HumanInTheLoopMiddleware",
    "ModelRequest",
    "SummarizationMiddleware",
langchain/agents/middleware/human_in_the_loop.py

@@ -112,14 +112,14 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
 
     def __init__(
         self,
-        tool_configs: dict[str, bool | ToolConfig],
+        interrupt_on: dict[str, bool | ToolConfig],
         *,
         description_prefix: str = "Tool execution requires approval",
     ) -> None:
         """Initialize the human in the loop middleware.
 
         Args:
-            tool_configs: Mapping of tool name to allowed actions.
+            interrupt_on: Mapping of tool name to allowed actions.
                 If a tool doesn't have an entry, it's auto-approved by default.
                 * `True` indicates all actions are allowed: accept, edit, and respond.
                 * `False` indicates that the tool is auto-approved.
@@ -130,7 +130,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         """
         super().__init__()
         resolved_tool_configs: dict[str, ToolConfig] = {}
-        for tool_name, tool_config in tool_configs.items():
+        for tool_name, tool_config in interrupt_on.items():
             if isinstance(tool_config, bool):
                 if tool_config is True:
                     resolved_tool_configs[tool_name] = ToolConfig(
@@ -138,13 +138,15 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
                         allow_edit=True,
                         allow_respond=True,
                     )
-            else:
+            elif any(
+                tool_config.get(x, False) for x in ["allow_accept", "allow_edit", "allow_respond"]
+            ):
                 resolved_tool_configs[tool_name] = tool_config
-        self.tool_configs = resolved_tool_configs
+        self.interrupt_on = resolved_tool_configs
         self.description_prefix = description_prefix
 
     def after_model(self, state: AgentState) -> dict[str, Any] | None:  # type: ignore[override]
-        """Trigger HITL flows for relevant tool calls after an AIMessage."""
+        """Trigger interrupt flows for relevant tool calls after an AIMessage."""
         messages = state["messages"]
         if not messages:
             return None
@@ -154,16 +156,16 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
             return None
 
         # Separate tool calls that need interrupts from those that don't
-        hitl_tool_calls: list[ToolCall] = []
+        interrupt_tool_calls: list[ToolCall] = []
         auto_approved_tool_calls = []
 
         for tool_call in last_ai_msg.tool_calls:
-            hitl_tool_calls.append(tool_call) if tool_call[
+            interrupt_tool_calls.append(tool_call) if tool_call[
                 "name"
-            ] in self.tool_configs else auto_approved_tool_calls.append(tool_call)
+            ] in self.interrupt_on else auto_approved_tool_calls.append(tool_call)
 
         # If no interrupts needed, return early
-        if not hitl_tool_calls:
+        if not interrupt_tool_calls:
             return None
 
         # Process all tool calls that require interrupts
@@ -171,11 +173,11 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
         artificial_tool_messages: list[ToolMessage] = []
 
         # Create interrupt requests for all tools that need approval
-        hitl_requests: list[HumanInTheLoopRequest] = []
-        for tool_call in hitl_tool_calls:
+        interrupt_requests: list[HumanInTheLoopRequest] = []
+        for tool_call in interrupt_tool_calls:
             tool_name = tool_call["name"]
             tool_args = tool_call["args"]
-            config = self.tool_configs[tool_name]
+            config = self.interrupt_on[tool_name]
             description = (
                 config.get("description")
                 or f"{self.description_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
@@ -189,21 +191,23 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
                 "config": config,
                 "description": description,
             }
-            hitl_requests.append(request)
+            interrupt_requests.append(request)
 
-        responses: list[HumanInTheLoopResponse] = interrupt(hitl_requests)
+        responses: list[HumanInTheLoopResponse] = interrupt(interrupt_requests)
 
         # Validate that the number of responses matches the number of interrupt tool calls
-        if (responses_len := len(responses)) != (hitl_tool_calls_len := len(hitl_tool_calls)):
+        if (responses_len := len(responses)) != (
+            interrupt_tool_calls_len := len(interrupt_tool_calls)
+        ):
             msg = (
                 f"Number of human responses ({responses_len}) does not match "
-                f"number of hanging tool calls ({hitl_tool_calls_len})."
+                f"number of hanging tool calls ({interrupt_tool_calls_len})."
             )
             raise ValueError(msg)
 
         for i, response in enumerate(responses):
-            tool_call = hitl_tool_calls[i]
-            config = self.tool_configs[tool_call["name"]]
+            tool_call = interrupt_tool_calls[i]
+            config = self.interrupt_on[tool_call["name"]]
 
             if response["type"] == "accept" and config.get("allow_accept"):
                 approved_tool_calls.append(tool_call)
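
For callers, only the keyword changed; a minimal usage sketch of the renamed API (the tool names are illustrative):

```python
from langchain.agents.middleware.human_in_the_loop import HumanInTheLoopMiddleware

middleware = HumanInTheLoopMiddleware(
    interrupt_on={  # formerly `tool_configs`
        "delete_file": True,  # all actions allowed: accept, edit, respond
        "read_file": False,  # auto-approved, no interrupt
        # New in a8: a ToolConfig only registers an interrupt if it enables at least one action
        "send_email": {"allow_accept": True, "allow_respond": True},
    },
)
```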
langchain/agents/middleware/prompt_caching.py

@@ -1,6 +1,7 @@
 """Anthropic prompt caching middleware."""
 
 from typing import Literal
+from warnings import warn
 
 from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
 
@@ -19,6 +20,7 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
         type: Literal["ephemeral"] = "ephemeral",
         ttl: Literal["5m", "1h"] = "5m",
         min_messages_to_cache: int = 0,
+        unsupported_model_behavior: Literal["ignore", "warn", "raise"] = "warn",
     ) -> None:
         """Initialize the middleware with cache control settings.
 
@@ -27,10 +29,15 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
             ttl: The time to live for the cache, only "5m" and "1h" are supported.
             min_messages_to_cache: The minimum number of messages until the cache is used,
                 default is 0.
+            unsupported_model_behavior: The behavior to take when an unsupported model is used.
+                "ignore" will ignore the unsupported model and continue without caching.
+                "warn" will warn the user and continue without caching.
+                "raise" will raise an error and stop the agent.
         """
         self.type = type
         self.ttl = ttl
         self.min_messages_to_cache = min_messages_to_cache
+        self.unsupported_model_behavior = unsupported_model_behavior
 
     def modify_model_request(  # type: ignore[override]
         self,
@@ -40,19 +47,29 @@ class AnthropicPromptCachingMiddleware(AgentMiddleware):
         try:
             from langchain_anthropic import ChatAnthropic
         except ImportError:
+            ChatAnthropic = None  # noqa: N806
+
+        msg: str | None = None
+
+        if ChatAnthropic is None:
             msg = (
                 "AnthropicPromptCachingMiddleware caching middleware only supports "
-                "Anthropic models."
+                "Anthropic models. "
                 "Please install langchain-anthropic."
             )
-            raise ValueError(msg)
-
-        if not isinstance(request.model, ChatAnthropic):
+        elif not isinstance(request.model, ChatAnthropic):
             msg = (
                 "AnthropicPromptCachingMiddleware caching middleware only supports "
                 f"Anthropic models, not instances of {type(request.model)}"
             )
-            raise ValueError(msg)
+
+        if msg is not None:
+            if self.unsupported_model_behavior == "raise":
+                raise ValueError(msg)
+            if self.unsupported_model_behavior == "warn":
+                warn(msg, stacklevel=3)
+            else:
+                return request
 
         messages_count = (
             len(request.messages) + 1 if request.system_prompt else len(request.messages)
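
A usage sketch of the new flag; the default `"warn"` preserves execution where a7 always raised:

```python
from langchain.agents.middleware.prompt_caching import AnthropicPromptCachingMiddleware

# In a7, a non-Anthropic model (or a missing langchain-anthropic install)
# always raised ValueError; in a8 the behavior is configurable.
middleware = AnthropicPromptCachingMiddleware(
    ttl="5m",
    unsupported_model_behavior="ignore",  # or "warn" (default) or "raise"
)
```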
langchain/agents/middleware/types.py

@@ -3,7 +3,20 @@
 from __future__ import annotations
 
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
+from inspect import signature
+from typing import (
+    TYPE_CHECKING,
+    Annotated,
+    Any,
+    ClassVar,
+    Generic,
+    Literal,
+    Protocol,
+    TypeAlias,
+    TypeGuard,
+    cast,
+    overload,
+)
 
 # needed as top level import for pydantic schema generation on AgentState
 from langchain_core.messages import AnyMessage  # noqa: TC002
@@ -14,9 +27,12 @@ from langgraph.typing import ContextT
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.tools import BaseTool
     from langgraph.runtime import Runtime
+    from langgraph.types import Command
 
     from langchain.agents.structured_output import ResponseFormat
 
@@ -88,6 +104,7 @@ class PublicAgentState(TypedDict, Generic[ResponseT]):
 
 
 StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
+StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
 
 
 class AgentMiddleware(Generic[StateT, ContextT]):
@@ -103,6 +120,12 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     tools: list[BaseTool]
     """Additional tools registered by the middleware."""
 
+    before_model_jump_to: ClassVar[list[JumpTo]] = []
+    """Valid jump destinations for before_model hook. Used to establish conditional edges."""
+
+    after_model_jump_to: ClassVar[list[JumpTo]] = []
+    """Valid jump destinations for after_model hook. Used to establish conditional edges."""
+
     def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run before the model is called."""
 
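
Subclasses now declare their jump targets up front so the graph builder can wire only the conditional edges a hook can actually take (see the `middleware_agent.py` changes below). A sketch, with an illustrative condition:

```python
from typing import Any

from langchain.agents.middleware.types import AgentMiddleware, AgentState


class ExitEarlyMiddleware(AgentMiddleware):
    # Declared up front; create_agent uses this to add a conditional edge to END.
    before_model_jump_to = ["__end__"]

    def before_model(self, state: AgentState, runtime: Any) -> dict[str, Any] | None:
        if len(state["messages"]) > 100:  # illustrative guard
            return {"jump_to": "__end__"}
        return None
```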
@@ -117,3 +140,404 @@ class AgentMiddleware(Generic[StateT, ContextT]):
 
     def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the model is called."""
+
+
+class _CallableWithState(Protocol[StateT_contra]):
+    """Callable with AgentState as argument."""
+
+    def __call__(self, state: StateT_contra) -> dict[str, Any] | Command | None:
+        """Perform some logic with the state."""
+        ...
+
+
+class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
+    """Callable with AgentState and Runtime as arguments."""
+
+    def __call__(
+        self, state: StateT_contra, runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | Command | None:
+        """Perform some logic with the state and runtime."""
+        ...
+
+
+class _CallableWithModelRequestAndState(Protocol[StateT_contra]):
+    """Callable with ModelRequest and AgentState as arguments."""
+
+    def __call__(self, request: ModelRequest, state: StateT_contra) -> ModelRequest:
+        """Perform some logic with the model request and state."""
+        ...
+
+
+class _CallableWithModelRequestAndStateAndRuntime(Protocol[StateT_contra, ContextT]):
+    """Callable with ModelRequest, AgentState, and Runtime as arguments."""
+
+    def __call__(
+        self, request: ModelRequest, state: StateT_contra, runtime: Runtime[ContextT]
+    ) -> ModelRequest:
+        """Perform some logic with the model request, state, and runtime."""
+        ...
+
+
+_NodeSignature: TypeAlias = (
+    _CallableWithState[StateT] | _CallableWithStateAndRuntime[StateT, ContextT]
+)
+_ModelRequestSignature: TypeAlias = (
+    _CallableWithModelRequestAndState[StateT]
+    | _CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]
+)
+
+
+def is_callable_with_runtime(
+    func: _NodeSignature[StateT, ContextT],
+) -> TypeGuard[_CallableWithStateAndRuntime[StateT, ContextT]]:
+    return "runtime" in signature(func).parameters
+
+
+def is_callable_with_runtime_and_request(
+    func: _ModelRequestSignature[StateT, ContextT],
+) -> TypeGuard[_CallableWithModelRequestAndStateAndRuntime[StateT, ContextT]]:
+    return "runtime" in signature(func).parameters
+
+
+@overload
+def before_model(
+    func: _NodeSignature[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def before_model(
+    func: None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
+
+
+def before_model(
+    func: _NodeSignature[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Decorator used to dynamically create a middleware with the before_model hook.
+
+    Args:
+        func: The function to be decorated. Can accept either:
+            - `state: StateT` - Just the agent state
+            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
+        state_schema: Optional custom state schema type. If not provided, uses the default
+            AgentState schema.
+        tools: Optional list of additional tools to register with this middleware.
+        jump_to: Optional list of valid jump destinations for conditional edges.
+            Valid values are: "tools", "model", "__end__"
+        name: Optional name for the generated middleware class. If not provided,
+            uses the decorated function's name.
+
+    Returns:
+        Either an AgentMiddleware instance (if func is provided directly) or a decorator function
+        that can be applied to the function it wraps.
+
+    The decorated function should return:
+        - `dict[str, Any]` - State updates to merge into the agent state
+        - `Command` - A command to control flow (e.g., jump to different node)
+        - `None` - No state updates or flow control
+
+    Examples:
+        Basic usage with state only:
+        ```python
+        @before_model
+        def log_before_model(state: AgentState) -> None:
+            print(f"About to call model with {len(state['messages'])} messages")
+        ```
+
+        Advanced usage with runtime and conditional jumping:
+        ```python
+        @before_model(jump_to=["__end__"])
+        def conditional_before_model(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+            if some_condition(state):
+                return {"jump_to": "__end__"}
+            return None
+        ```
+
+        With custom state schema:
+        ```python
+        @before_model(
+            state_schema=MyCustomState,
+        )
+        def custom_before_model(state: MyCustomState) -> dict[str, Any]:
+            return {"custom_field": "updated_value"}
+        ```
+    """
+
+    def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
+        if is_callable_with_runtime(func):
+
+            def wrapped_with_runtime(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                state: StateT,
+                runtime: Runtime[ContextT],
+            ) -> dict[str, Any] | Command | None:
+                return func(state, runtime)
+
+            wrapped = wrapped_with_runtime
+        else:
+
+            def wrapped_without_runtime(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                state: StateT,
+            ) -> dict[str, Any] | Command | None:
+                return func(state)  # type: ignore[call-arg]
+
+            wrapped = wrapped_without_runtime  # type: ignore[assignment]
+
+        # Use function name as default if no name provided
+        middleware_name = name or cast("str", getattr(func, "__name__", "BeforeModelMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": state_schema or AgentState,
+                "tools": tools or [],
+                "before_model_jump_to": jump_to or [],
+                "before_model": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
+
+
+@overload
+def modify_model_request(
+    func: _ModelRequestSignature[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def modify_model_request(
+    func: None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
+
+
+def modify_model_request(
+    func: _ModelRequestSignature[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[[_ModelRequestSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
+    | AgentMiddleware[StateT, ContextT]
+):
+    r"""Decorator used to dynamically create a middleware with the modify_model_request hook.
+
+    Args:
+        func: The function to be decorated. Can accept either:
+            - `request: ModelRequest, state: StateT` - Model request and agent state
+            - `request: ModelRequest, state: StateT, runtime: Runtime[ContextT]` -
+              Model request, state, and runtime context
+        state_schema: Optional custom state schema type. If not provided, uses the default
+            AgentState schema.
+        tools: Optional list of additional tools to register with this middleware.
+        name: Optional name for the generated middleware class. If not provided,
+            uses the decorated function's name.
+
+    Returns:
+        Either an AgentMiddleware instance (if func is provided) or a decorator function
+        that can be applied to a function.
+
+    The decorated function should return:
+        - `ModelRequest` - The modified model request to be sent to the language model
+
+    Examples:
+        Basic usage to modify system prompt:
+        ```python
+        @modify_model_request
+        def add_context_to_prompt(request: ModelRequest, state: AgentState) -> ModelRequest:
+            if request.system_prompt:
+                request.system_prompt += "\n\nAdditional context: ..."
+            else:
+                request.system_prompt = "Additional context: ..."
+            return request
+        ```
+
+        Advanced usage with runtime and custom model settings:
+        ```python
+        @modify_model_request
+        def dynamic_model_settings(
+            request: ModelRequest, state: AgentState, runtime: Runtime
+        ) -> ModelRequest:
+            # Use a different model based on user subscription tier
+            if runtime.context.get("subscription_tier") == "premium":
+                request.model = "gpt-4o"
+            else:
+                request.model = "gpt-4o-mini"
+
+            return request
+        ```
+    """
+
+    def decorator(
+        func: _ModelRequestSignature[StateT, ContextT],
+    ) -> AgentMiddleware[StateT, ContextT]:
+        if is_callable_with_runtime_and_request(func):
+
+            def wrapped_with_runtime(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                request: ModelRequest,
+                state: StateT,
+                runtime: Runtime[ContextT],
+            ) -> ModelRequest:
+                return func(request, state, runtime)
+
+            wrapped = wrapped_with_runtime
+        else:
+
+            def wrapped_without_runtime(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                request: ModelRequest,
+                state: StateT,
+            ) -> ModelRequest:
+                return func(request, state)  # type: ignore[call-arg]
+
+            wrapped = wrapped_without_runtime  # type: ignore[assignment]
+
+        # Use function name as default if no name provided
+        middleware_name = name or cast(
+            "str", getattr(func, "__name__", "ModifyModelRequestMiddleware")
+        )
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": state_schema or AgentState,
+                "tools": tools or [],
+                "modify_model_request": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
+
+
+@overload
+def after_model(
+    func: _NodeSignature[StateT, ContextT],
+) -> AgentMiddleware[StateT, ContextT]: ...
+
+
+@overload
+def after_model(
+    func: None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]: ...
+
+
+def after_model(
+    func: _NodeSignature[StateT, ContextT] | None = None,
+    *,
+    state_schema: type[StateT] | None = None,
+    tools: list[BaseTool] | None = None,
+    jump_to: list[JumpTo] | None = None,
+    name: str | None = None,
+) -> (
+    Callable[[_NodeSignature[StateT, ContextT]], AgentMiddleware[StateT, ContextT]]
+    | AgentMiddleware[StateT, ContextT]
+):
+    """Decorator used to dynamically create a middleware with the after_model hook.
+
+    Args:
+        func: The function to be decorated. Can accept either:
+            - `state: StateT` - Just the agent state (includes model response)
+            - `state: StateT, runtime: Runtime[ContextT]` - State and runtime context
+        state_schema: Optional custom state schema type. If not provided, uses the default
+            AgentState schema.
+        tools: Optional list of additional tools to register with this middleware.
+        jump_to: Optional list of valid jump destinations for conditional edges.
+            Valid values are: "tools", "model", "__end__"
+        name: Optional name for the generated middleware class. If not provided,
+            uses the decorated function's name.
+
+    Returns:
+        Either an AgentMiddleware instance (if func is provided) or a decorator function
+        that can be applied to a function.
+
+    The decorated function should return:
+        - `dict[str, Any]` - State updates to merge into the agent state
+        - `Command` - A command to control flow (e.g., jump to different node)
+        - `None` - No state updates or flow control
+
+    Examples:
+        Basic usage for logging model responses:
+        ```python
+        @after_model
+        def log_latest_message(state: AgentState) -> None:
+            print(state["messages"][-1].content)
+        ```
+
+        With custom state schema:
+        ```python
+        @after_model(state_schema=MyCustomState, name="MyAfterModelMiddleware")
+        def custom_after_model(state: MyCustomState) -> dict[str, Any]:
+            return {"custom_field": "updated_after_model"}
+        ```
+    """
+
+    def decorator(func: _NodeSignature[StateT, ContextT]) -> AgentMiddleware[StateT, ContextT]:
+        if is_callable_with_runtime(func):
+
+            def wrapped_with_runtime(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                state: StateT,
+                runtime: Runtime[ContextT],
+            ) -> dict[str, Any] | Command | None:
+                return func(state, runtime)
+
+            wrapped = wrapped_with_runtime
+        else:
+
+            def wrapped_without_runtime(
+                self: AgentMiddleware[StateT, ContextT],  # noqa: ARG001
+                state: StateT,
+            ) -> dict[str, Any] | Command | None:
+                return func(state)  # type: ignore[call-arg]
+
+            wrapped = wrapped_without_runtime  # type: ignore[assignment]
+
+        # Use function name as default if no name provided
+        middleware_name = name or cast("str", getattr(func, "__name__", "AfterModelMiddleware"))
+
+        return type(
+            middleware_name,
+            (AgentMiddleware,),
+            {
+                "state_schema": state_schema or AgentState,
+                "tools": tools or [],
+                "after_model_jump_to": jump_to or [],
+                "after_model": wrapped,
+            },
+        )()
+
+    if func is not None:
+        return decorator(func)
+    return decorator
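
The decorators return ready-made `AgentMiddleware` instances, so a plain function can now stand in for a subclass. A sketch combining the decorator form with the new `jump_to` declaration (the sentinel check is illustrative):

```python
from typing import Any

from langchain.agents.middleware.types import AgentState, before_model


@before_model(jump_to=["__end__"], name="StopOnSentinel")
def stop_on_sentinel(state: AgentState) -> dict[str, Any] | None:
    # End the run when the latest message carries a sentinel string.
    if state["messages"] and "DONE" in str(state["messages"][-1].content):
        return {"jump_to": "__end__"}
    return None


# stop_on_sentinel is an AgentMiddleware instance and can be passed anywhere a
# hand-written subclass was accepted, e.g. the middleware list of create_agent.
```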
langchain/agents/middleware_agent.py

@@ -464,7 +464,7 @@ def create_agent(  # noqa: PLR0915
             f"{middleware_w_after[0].__class__.__name__}.after_model",
             END,
             first_node,
-            tools_available=tool_node is not None,
+            jump_to=middleware_w_after[0].after_model_jump_to,
         )
 
     # Add middleware edges (same as before)
@@ -475,7 +475,7 @@ def create_agent(  # noqa: PLR0915
                 f"{m1.__class__.__name__}.before_model",
                 f"{m2.__class__.__name__}.before_model",
                 first_node,
-                tools_available=tool_node is not None,
+                jump_to=m1.before_model_jump_to,
             )
         # Go directly to model_request after the last before_model
        _add_middleware_edge(
@@ -483,7 +483,7 @@ def create_agent(  # noqa: PLR0915
             f"{middleware_w_before[-1].__class__.__name__}.before_model",
             "model_request",
             first_node,
-            tools_available=tool_node is not None,
+            jump_to=middleware_w_before[-1].before_model_jump_to,
         )
 
     if middleware_w_after:
@@ -496,7 +496,7 @@ def create_agent(  # noqa: PLR0915
                 f"{m1.__class__.__name__}.after_model",
                 f"{m2.__class__.__name__}.after_model",
                 first_node,
-                tools_available=tool_node is not None,
+                jump_to=m1.after_model_jump_to,
             )
 
     return graph
@@ -528,8 +528,8 @@ def _fetch_last_ai_and_tool_messages(
 
 def _make_model_to_tools_edge(
     first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
-) -> Callable[[AgentState], str | list[Send] | None]:
-    def model_to_tools(state: AgentState) -> str | list[Send] | None:
+) -> Callable[[dict[str, Any]], str | list[Send] | None]:
+    def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
         if jump_to := state.get("jump_to"):
             return _resolve_jump(jump_to, first_node)
 
@@ -548,8 +548,7 @@ def _make_model_to_tools_edge(
         # of using Send w/ tool calls directly which allows more intuitive interrupt behavior
         # largely internal so can be fixed later
         pending_tool_calls = [
-            tool_node.inject_tool_args(call, state, None)  # type: ignore[arg-type]
-            for call in pending_tool_calls
+            tool_node.inject_tool_args(call, state, None) for call in pending_tool_calls
         ]
         return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
 
@@ -560,8 +559,8 @@ def _make_model_to_tools_edge(
 
 def _make_tools_to_model_edge(
     tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
-) -> Callable[[AgentState], str | None]:
-    def tools_to_model(state: AgentState) -> str | None:
+) -> Callable[[dict[str, Any]], str | None]:
+    def tools_to_model(state: dict[str, Any]) -> str | None:
         last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
 
         if all(
@@ -584,7 +583,7 @@ def _add_middleware_edge(
     name: str,
     default_destination: str,
     model_destination: str,
-    tools_available: bool,  # noqa: FBT001
+    jump_to: list[JumpTo] | None,
 ) -> None:
     """Add an edge to the graph for a middleware node.
 
@@ -594,18 +593,23 @@ def _add_middleware_edge(
         name: The name of the middleware node.
         default_destination: The default destination for the edge.
         model_destination: The destination for the edge to the model.
-        tools_available: Whether tools are available for the edge to potentially route to.
+        jump_to: The conditionally jumpable destinations for the edge.
     """
+    if jump_to:
+
+        def jump_edge(state: dict[str, Any]) -> str:
+            return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
 
-    def jump_edge(state: AgentState) -> str:
-        return _resolve_jump(state.get("jump_to"), model_destination) or default_destination
+        destinations = [default_destination]
 
-    destinations = [default_destination]
-    if default_destination != END:
-        destinations.append(END)
-    if tools_available:
-        destinations.append("tools")
-    if name != model_destination:
-        destinations.append(model_destination)
+        if "__end__" in jump_to:
+            destinations.append(END)
+        if "tools" in jump_to:
+            destinations.append("tools")
+        if "model" in jump_to and name != model_destination:
+            destinations.append(model_destination)
 
-    graph.add_conditional_edges(name, jump_edge, destinations)
+        graph.add_conditional_edges(name, jump_edge, destinations)
+
+    else:
+        graph.add_edge(name, default_destination)
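
An illustration (not library code) of what the new wiring produces: with declared jump targets the node gets a conditional edge set, while the empty default yields a single static edge, so graphs no longer carry conditional branches a middleware can never take:

```python
# Mirrors the destination-building logic above for jump_to=["tools", "__end__"]
# and default_destination="model_request".
destinations = ["model_request"]
jump_to = ["tools", "__end__"]
if "__end__" in jump_to:
    destinations.append("__end__")  # END
if "tools" in jump_to:
    destinations.append("tools")
print(destinations)  # ['model_request', '__end__', 'tools']
```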
langchain-1.0.0a7.dist-info/METADATA → langchain-1.0.0a8.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langchain
-Version: 1.0.0a7
+Version: 1.0.0a8
 Summary: Building applications with LLMs through composability
 License: MIT
 Project-URL: Source Code, https://github.com/langchain-ai/langchain/tree/master/libs/langchain
langchain-1.0.0a7.dist-info/RECORD → langchain-1.0.0a8.dist-info/RECORD

@@ -1,7 +1,7 @@
-langchain-1.0.0a7.dist-info/METADATA,sha256=avxuGLMGlcJmBGdtOt03ZUAml3qw6l6Hfo8Lkvp1Q2g,6259
-langchain-1.0.0a7.dist-info/WHEEL,sha256=9P2ygRxDrTJz3gsagc0Z96ukrxjr-LFBGOgv3AuKlCA,90
-langchain-1.0.0a7.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
-langchain-1.0.0a7.dist-info/licenses/LICENSE,sha256=TsZ-TKbmch26hJssqCJhWXyGph7iFLvyFBYAa3stBHg,1067
+langchain-1.0.0a8.dist-info/METADATA,sha256=slSRaiJXGZdoNrLCLCYpJqqaPdM0keHqtv2aCBHLC7w,6259
+langchain-1.0.0a8.dist-info/WHEEL,sha256=9P2ygRxDrTJz3gsagc0Z96ukrxjr-LFBGOgv3AuKlCA,90
+langchain-1.0.0a8.dist-info/entry_points.txt,sha256=6OYgBcLyFCUgeqLgnvMyOJxPCWzgy7se4rLPKtNonMs,34
+langchain-1.0.0a8.dist-info/licenses/LICENSE,sha256=TsZ-TKbmch26hJssqCJhWXyGph7iFLvyFBYAa3stBHg,1067
 langchain/__init__.py,sha256=Z6r4MjNaC6DSyiMgFSRmi8EhizelVIOb_CVJJRAVjDc,604
 langchain/_internal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 langchain/_internal/_documents.py,sha256=z9wAPukoASOMw4WTFFBKCCZYsvsKbo-Cq6CeHjdq9eE,1045
@@ -12,13 +12,12 @@ langchain/_internal/_utils.py,sha256=lG8X9muiRAWtQjRPudq-1x-wHbk0J3spu_rYZckVdYs
 langchain/agents/__init__.py,sha256=NG2S3dic9L3i4sAD9mpgaTv6Dl4L3u45xxK6jn-I4W8,281
 langchain/agents/_internal/__init__.py,sha256=5nNBeaeQIvv9IOQjY4_aNW8pffWzMXQgi0b6Nx-WghM,37
 langchain/agents/_internal/_typing.py,sha256=JoWa-KL5uLNeq6yrm56wnIvhDeFnCt2fTzgUcj5zWy4,270
-langchain/agents/middleware/__init__.py,sha256=Q68coEBPtxTnb7LuDDSKvUrK0RvppqEwyyzpE8-RgiE,613
-langchain/agents/middleware/dynamic_system_prompt.py,sha256=uakW4wyVc9h52T2QO4BeKWmbc2UK31VqFGuMMvj9wX8,3267
-langchain/agents/middleware/human_in_the_loop.py,sha256=0XX0fvd6XezrbhAnhBRe3OgHQDnZeWK0LcorqWC0_BE,10101
-langchain/agents/middleware/prompt_caching.py,sha256=du_qrBr0_kwWhdO_xggtfrEN5FTcGLKu3oYQDnSS0Do,2263
+langchain/agents/middleware/__init__.py,sha256=-NzMTmD5ogpzlsqHGjv6SnTrfXqU3vTahGUoGDk299U,511
+langchain/agents/middleware/human_in_the_loop.py,sha256=_6THKNzp1dvYBwBLdnZ9PXsHJP3uedn4A60ZON4xlvI,10301
+langchain/agents/middleware/prompt_caching.py,sha256=QLoWdd9jUiXAytGqbXW0I_Mg8WgrgTBO8gOZ-s8Bx8g,3081
 langchain/agents/middleware/summarization.py,sha256=qqEqAuJXQ5rfewhFHftHLnrX8jhdMu9dPfz0akhzfuc,10281
-langchain/agents/middleware/types.py,sha256=DRsl0GjgWXbPlFTiiVnI8pMhzMJF3Y2VkE2zLMKQhaY,3826
-langchain/agents/middleware_agent.py,sha256=hfIt4LdtDjkZGs0ylo8xti67iecpPEsuklYCjJ20V8k,23746
+langchain/agents/middleware/types.py,sha256=a9B6Ihx12mNTroopL1SqHxsO51ZfSkxdtkPeZXw8EJc,18606
+langchain/agents/middleware_agent.py,sha256=fncjAFNsqZqEkDYSsBfq-goxN4GqNlQixirWRBguXhs,23847
 langchain/agents/react_agent.py,sha256=6ZNI2dp0hTL7hTm7ao-HkQ3hmVvBQuFu9pJz0PSK_eg,49712
 langchain/agents/structured_output.py,sha256=QWNafJx7au_jJawJgIfovnDoP8Z9mLxDZNvDX_1RRJ0,13327
 langchain/agents/tool_node.py,sha256=QabTfIi8nGrwfzaSOeWfyHos6sgXjFTdRXexQG7u2HE,46596
@@ -36,4 +35,4 @@ langchain/storage/exceptions.py,sha256=Fl_8tON3KmByBKwXtno5WSj0-c2RiZxnhw3gv5aS2
 langchain/storage/in_memory.py,sha256=ozrmu0EtaJJVSAzK_u7nzxWpr9OOscWkANHSg-qIVYQ,369
 langchain/text_splitter.py,sha256=yxWs4secpnkfK6VZDiNJNdlYOrRZ18RQZj1S3xNQ73A,1554
 langchain/tools/__init__.py,sha256=NYQzLxW2iI5Twu3voefVC-dJEI4Wgh7jC311CQEpvZs,252
-langchain-1.0.0a7.dist-info/RECORD,,
+langchain-1.0.0a8.dist-info/RECORD,,
langchain/agents/middleware/dynamic_system_prompt.py (removed)

@@ -1,105 +0,0 @@
-"""Dynamic System Prompt Middleware.
-
-Allows setting the system prompt dynamically right before each model invocation.
-Useful when the prompt depends on the current agent state or per-invocation context.
-"""
-
-from __future__ import annotations
-
-from inspect import signature
-from typing import TYPE_CHECKING, Protocol, TypeAlias, cast
-
-from langgraph.typing import ContextT
-
-from langchain.agents.middleware.types import (
-    AgentMiddleware,
-    AgentState,
-    ModelRequest,
-)
-
-if TYPE_CHECKING:
-    from langgraph.runtime import Runtime
-
-
-class DynamicSystemPromptWithoutRuntime(Protocol):
-    """Dynamic system prompt without runtime in call signature."""
-
-    def __call__(self, state: AgentState) -> str:
-        """Return the system prompt for the next model call."""
-        ...
-
-
-class DynamicSystemPromptWithRuntime(Protocol[ContextT]):
-    """Dynamic system prompt with runtime in call signature."""
-
-    def __call__(self, state: AgentState, runtime: Runtime[ContextT]) -> str:
-        """Return the system prompt for the next model call."""
-        ...
-
-
-DynamicSystemPrompt: TypeAlias = (
-    DynamicSystemPromptWithoutRuntime | DynamicSystemPromptWithRuntime[ContextT]
-)
-
-
-class DynamicSystemPromptMiddleware(AgentMiddleware):
-    """Dynamic System Prompt Middleware.
-
-    Allows setting the system prompt dynamically right before each model invocation.
-    Useful when the prompt depends on the current agent state or per-invocation context.
-
-    Example:
-        ```python
-        from langchain.agents.middleware import DynamicSystemPromptMiddleware
-
-
-        class Context(TypedDict):
-            user_name: str
-
-
-        def system_prompt(state: AgentState, runtime: Runtime[Context]) -> str:
-            user_name = runtime.context.get("user_name", "n/a")
-            return (
-                f"You are a helpful assistant. Always address the user by their name: {user_name}"
-            )
-
-
-        middleware = DynamicSystemPromptMiddleware(system_prompt)
-        ```
-    """
-
-    _accepts_runtime: bool
-
-    def __init__(
-        self,
-        dynamic_system_prompt: DynamicSystemPrompt[ContextT],
-    ) -> None:
-        """Initialize the dynamic system prompt middleware.
-
-        Args:
-            dynamic_system_prompt: Function that receives the current agent state
-                and optionally runtime with context, and returns the system prompt for
-                the next model call. Returns a string.
-        """
-        super().__init__()
-        self.dynamic_system_prompt = dynamic_system_prompt
-        self._accepts_runtime = "runtime" in signature(dynamic_system_prompt).parameters
-
-    def modify_model_request(
-        self,
-        request: ModelRequest,
-        state: AgentState,
-        runtime: Runtime[ContextT],
-    ) -> ModelRequest:
-        """Modify the model request to include the dynamic system prompt."""
-        if self._accepts_runtime:
-            system_prompt = cast(
-                "DynamicSystemPromptWithRuntime[ContextT]", self.dynamic_system_prompt
-            )(state, runtime)
-        else:
-            system_prompt = cast("DynamicSystemPromptWithoutRuntime", self.dynamic_system_prompt)(
-                state
-            )
-
-        request.system_prompt = system_prompt
-        return request