langchain 1.0.0a4__py3-none-any.whl → 1.0.0a6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.


@@ -48,7 +48,7 @@ Respond ONLY with the extracted context. Do not include any additional informati
 <messages>
 Messages to summarize:
 {messages}
-</messages>"""
+</messages>"""  # noqa: E501
 
 SUMMARY_PREFIX = "## Previous conversation summary:"
 
@@ -98,7 +98,7 @@ class SummarizationMiddleware(AgentMiddleware):
         self.summary_prompt = summary_prompt
         self.summary_prefix = summary_prefix
 
-    def before_model(self, state: AgentState) -> dict[str, Any] | None:
+    def before_model(self, state: AgentState) -> dict[str, Any] | None:  # type: ignore[override]
         """Process messages before model invocation, potentially triggering summarization."""
         messages = state["messages"]
         self._ensure_message_ids(messages)
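The `# type: ignore[override]` appears to be needed because the base `AgentMiddleware.before_model` hook gains a `runtime` parameter in this release (see the AgentMiddleware hunks below) while `SummarizationMiddleware` keeps the old one-argument form. A minimal sketch of the typing mismatch, with illustrative class names:

class Base:
    def before_model(self, state, runtime): ...


class Sub(Base):
    def before_model(self, state): ...  # type checkers flag the narrowed override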
@@ -8,15 +8,27 @@ from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
 # needed as top level import for pydantic schema generation on AgentState
 from langchain_core.messages import AnyMessage  # noqa: TC002
 from langgraph.channels.ephemeral_value import EphemeralValue
-from langgraph.graph.message import Messages, add_messages
+from langgraph.graph.message import add_messages
+from langgraph.runtime import Runtime
+from langgraph.typing import ContextT
 from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
 if TYPE_CHECKING:
     from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.tools import BaseTool
+    from langgraph.runtime import Runtime
 
     from langchain.agents.structured_output import ResponseFormat
 
+__all__ = [
+    "AgentMiddleware",
+    "AgentState",
+    "ContextT",
+    "ModelRequest",
+    "OmitFromSchema",
+    "PublicAgentState",
+]
+
 JumpTo = Literal["tools", "model", "__end__"]
 """Destination to jump to when a middleware node returns."""
 
@@ -36,29 +48,53 @@ class ModelRequest:
     model_settings: dict[str, Any] = field(default_factory=dict)
 
 
+@dataclass
+class OmitFromSchema:
+    """Annotation used to mark state attributes as omitted from input or output schemas."""
+
+    input: bool = True
+    """Whether to omit the attribute from the input schema."""
+
+    output: bool = True
+    """Whether to omit the attribute from the output schema."""
+
+
+OmitFromInput = OmitFromSchema(input=True, output=False)
+"""Annotation used to mark state attributes as omitted from input schema."""
+
+OmitFromOutput = OmitFromSchema(input=False, output=True)
+"""Annotation used to mark state attributes as omitted from output schema."""
+
+PrivateStateAttr = OmitFromSchema(input=True, output=True)
+"""Annotation used to mark state attributes as purely internal for a given middleware."""
+
+
 class AgentState(TypedDict, Generic[ResponseT]):
     """State schema for the agent."""
 
     messages: Required[Annotated[list[AnyMessage], add_messages]]
-    model_request: NotRequired[Annotated[ModelRequest | None, EphemeralValue]]
-    jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue]]
+    jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue, PrivateStateAttr]]
     response: NotRequired[ResponseT]
 
 
 class PublicAgentState(TypedDict, Generic[ResponseT]):
-    """Input / output schema for the agent."""
+    """Public state schema for the agent.
 
-    messages: Required[Messages]
+    Just used for typing purposes.
+    """
+
+    messages: Required[Annotated[list[AnyMessage], add_messages]]
     response: NotRequired[ResponseT]
 
 
-StateT = TypeVar("StateT", bound=AgentState)
+StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
 
 
-class AgentMiddleware(Generic[StateT]):
+class AgentMiddleware(Generic[StateT, ContextT]):
     """Base middleware class for an agent.
 
-    Subclass this and implement any of the defined methods to customize agent behavior between steps in the main agent loop.
+    Subclass this and implement any of the defined methods to customize agent behavior
+    between steps in the main agent loop.
     """
 
     state_schema: type[StateT] = cast("type[StateT]", AgentState)
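The new `OmitFromSchema` markers compose with `Annotated`, exactly as `jump_to` is tagged with `PrivateStateAttr` above. A sketch of a custom middleware state schema using them (the class and field names are hypothetical; the import path follows the `langchain.agents.middleware.types` module these hunks modify):

from typing import Annotated

from typing_extensions import NotRequired

from langchain.agents.middleware.types import (
    AgentState,
    OmitFromInput,
    PrivateStateAttr,
)


class RetryState(AgentState):
    # Surfaced in outputs but not accepted as caller input.
    attempts: NotRequired[Annotated[int, OmitFromInput]]
    # Internal to the middleware: hidden from both input and output schemas.
    last_error: NotRequired[Annotated[str, PrivateStateAttr]]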
@@ -67,12 +103,17 @@ class AgentMiddleware(Generic[StateT]):
     tools: list[BaseTool]
     """Additional tools registered by the middleware."""
 
-    def before_model(self, state: StateT) -> dict[str, Any] | None:
+    def before_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run before the model is called."""
 
-    def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest:  # noqa: ARG002
+    def modify_model_request(
+        self,
+        request: ModelRequest,
+        state: StateT,  # noqa: ARG002
+        runtime: Runtime[ContextT],  # noqa: ARG002
+    ) -> ModelRequest:
         """Logic to modify request kwargs before the model is called."""
         return request
 
-    def after_model(self, state: StateT) -> dict[str, Any] | None:
+    def after_model(self, state: StateT, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
         """Logic to run after the model is called."""
@@ -2,7 +2,8 @@
 
 import itertools
 from collections.abc import Callable, Sequence
-from typing import Any, Union
+from inspect import signature
+from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
 
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
@@ -10,18 +11,19 @@ from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.constants import END, START
 from langgraph.graph.state import StateGraph
+from langgraph.runtime import Runtime
+from langgraph.types import Send
 from langgraph.typing import ContextT
-from typing_extensions import TypedDict, TypeVar
+from typing_extensions import NotRequired, Required, TypedDict, TypeVar
 
 from langchain.agents.middleware.types import (
     AgentMiddleware,
     AgentState,
     JumpTo,
     ModelRequest,
+    OmitFromSchema,
     PublicAgentState,
 )
-
-# Import structured output classes from the old implementation
 from langchain.agents.structured_output import (
     MultipleStructuredOutputsError,
     OutputToolBinding,
@@ -37,29 +39,52 @@ from langchain.chat_models import init_chat_model
 STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
 
 
-def _merge_state_schemas(schemas: list[type]) -> type:
-    """Merge multiple TypedDict schemas into a single schema with all fields."""
-    if not schemas:
-        return AgentState
+def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
+    """Resolve schema by merging schemas and optionally respecting OmitFromSchema annotations.
 
+    Args:
+        schemas: List of schema types to merge
+        schema_name: Name for the generated TypedDict
+        omit_flag: If specified, omit fields with this flag set ('input' or 'output')
+    """
     all_annotations = {}
 
     for schema in schemas:
-        all_annotations.update(schema.__annotations__)
+        hints = get_type_hints(schema, include_extras=True)
+
+        for field_name, field_type in hints.items():
+            should_omit = False
 
-    return TypedDict("MergedState", all_annotations)  # type: ignore[operator]
+            if omit_flag:
+                # Check for omission in the annotation metadata
+                metadata = _extract_metadata(field_type)
+                for meta in metadata:
+                    if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True:
+                        should_omit = True
+                        break
 
+            if not should_omit:
+                all_annotations[field_name] = field_type
 
-def _filter_state_for_schema(state: dict[str, Any], schema: type) -> dict[str, Any]:
-    """Filter state to only include fields defined in the given schema."""
-    if not hasattr(schema, "__annotations__"):
-        return state
+    return TypedDict(schema_name, all_annotations)  # type: ignore[operator]
 
-    schema_fields = set(schema.__annotations__.keys())
-    return {k: v for k, v in state.items() if k in schema_fields}
 
+def _extract_metadata(type_: type) -> list:
+    """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
+    # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
+    if get_origin(type_) in (Required, NotRequired):
+        inner_type = get_args(type_)[0]
+        if get_origin(inner_type) is Annotated:
+            return list(get_args(inner_type)[1:])
 
-def _supports_native_structured_output(model: Union[str, BaseChatModel]) -> bool:
+    # Handle direct Annotated[...]
+    elif get_origin(type_) is Annotated:
+        return list(get_args(type_)[1:])
+
+    return []
+
+
+def _supports_native_structured_output(model: str | BaseChatModel) -> bool:
     """Check if a model supports native structured output."""
     model_name: str | None = None
     if isinstance(model, str):
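`_extract_metadata` relies on the `Required`/`NotRequired` wrappers surviving `get_type_hints(..., include_extras=True)`. A standalone sketch of that introspection pattern (the `Example` class is illustrative):

from typing import Annotated, get_args, get_origin, get_type_hints

from typing_extensions import NotRequired, Required, TypedDict


class Example(TypedDict):
    a: Required[Annotated[int, "meta-a"]]
    b: NotRequired[Annotated[str, "meta-b"]]


for name, tp in get_type_hints(Example, include_extras=True).items():
    if get_origin(tp) in (Required, NotRequired):  # unwrap the (Not)Required layer first
        tp = get_args(tp)[0]
    if get_origin(tp) is Annotated:
        print(name, list(get_args(tp)[1:]))  # a ['meta-a'] / b ['meta-b']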
@@ -113,7 +138,7 @@ def create_agent(  # noqa: PLR0915
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | ToolNode | None = None,
     system_prompt: str | None = None,
-    middleware: Sequence[AgentMiddleware] = (),
+    middleware: Sequence[AgentMiddleware[AgentState[ResponseT], ContextT]] = (),
     response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
     context_schema: type[ContextT] | None = None,
 ) -> StateGraph[
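The generic parameters now flow from the middleware sequence into the returned graph. A hedged usage sketch (the import path and model id are assumptions, and `LoggingMiddleware` is the hypothetical class sketched earlier):

from langchain.agents import create_agent  # assumed public re-export

agent = create_agent(
    model="openai:gpt-4o-mini",  # placeholder model id
    tools=[],
    system_prompt="You are a helpful assistant.",
    middleware=[LoggingMiddleware()],
).compile()  # create_agent returns a StateGraph, per the annotation above

result = agent.invoke({"messages": [("user", "hi")]})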
@@ -198,46 +223,30 @@ def create_agent(  # noqa: PLR0915
         m for m in middleware if m.__class__.after_model is not AgentMiddleware.after_model
     ]
 
-    # Collect all middleware state schemas and create merged schema
-    merged_state_schema: type[AgentState] = _merge_state_schemas(
-        [m.state_schema for m in middleware]
-    )
+    state_schemas = {m.state_schema for m in middleware}
+    state_schemas.add(AgentState)
+
+    state_schema = _resolve_schema(state_schemas, "StateSchema", None)
+    input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
+    output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")
 
     # create graph, add nodes
-    graph = StateGraph(
-        merged_state_schema,
-        input_schema=PublicAgentState,
-        output_schema=PublicAgentState,
+    graph: StateGraph[
+        AgentState[ResponseT], ContextT, PublicAgentState[ResponseT], PublicAgentState[ResponseT]
+    ] = StateGraph(
+        state_schema=state_schema,
+        input_schema=input_schema,
+        output_schema=output_schema,
         context_schema=context_schema,
     )
 
-    def _prepare_model_request(state: dict[str, Any]) -> tuple[ModelRequest, list[AnyMessage]]:
-        """Prepare model request and messages."""
-        request = state.get("model_request") or ModelRequest(
-            model=model,
-            tools=default_tools,
-            system_prompt=system_prompt,
-            response_format=response_format,
-            messages=state["messages"],
-            tool_choice=None,
-        )
-
-        # prepare messages
-        messages = request.messages
-        if request.system_prompt:
-            messages = [SystemMessage(request.system_prompt), *messages]
-
-        return request, messages
-
-    def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str, Any]:
+    def _handle_model_output(output: AIMessage) -> dict[str, Any]:
         """Handle model output including structured responses."""
         # Handle structured output with native strategy
         if isinstance(response_format, ProviderStrategy):
             if not output.tool_calls and native_output_binding:
                 structured_response = native_output_binding.parse(output)
                 return {"messages": [output], "response": structured_response}
-            if state.get("response") is not None:
-                return {"messages": [output], "response": None}
             return {"messages": [output]}
 
         # Handle structured output with tools strategy
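The three `_resolve_schema` calls above give the graph one full internal state plus narrowed input and output views. A runnable re-implementation of the filtering idea (`Demo` and `visible_fields` are illustrative, not part of the package):

from typing import Annotated, get_args, get_origin, get_type_hints

from typing_extensions import NotRequired, Required, TypedDict

from langchain.agents.middleware.types import OmitFromInput, OmitFromSchema, PrivateStateAttr


class Demo(TypedDict):
    messages: list
    attempts: NotRequired[Annotated[int, OmitFromInput]]
    scratch: NotRequired[Annotated[str, PrivateStateAttr]]


def visible_fields(schema: type, omit_flag: str) -> set[str]:
    """Keep fields whose OmitFromSchema metadata does not set omit_flag."""
    keep = set()
    for name, tp in get_type_hints(schema, include_extras=True).items():
        if get_origin(tp) in (Required, NotRequired):
            tp = get_args(tp)[0]
        meta = get_args(tp)[1:] if get_origin(tp) is Annotated else ()
        if not any(isinstance(m, OmitFromSchema) and getattr(m, omit_flag) for m in meta):
            keep.add(name)
    return keep


print(visible_fields(Demo, "input"))   # {'messages'} - attempts and scratch are hidden
print(visible_fields(Demo, "output"))  # {'messages', 'attempts'} - only scratch is hidden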
@@ -315,9 +324,6 @@ def create_agent(  # noqa: PLR0915
             ],
         }
 
-        # Standard response handling
-        if state.get("response") is not None:
-            return {"messages": [output], "response": None}
         return {"messages": [output]}
 
     def _get_bound_model(request: ModelRequest) -> Runnable:
@@ -340,37 +346,67 @@ def create_agent(  # noqa: PLR0915
         )
         return request.model.bind(**request.model_settings)
 
-    def model_request(state: dict[str, Any]) -> dict[str, Any]:
+    model_request_signatures: list[
+        tuple[bool, AgentMiddleware[AgentState[ResponseT], ContextT]]
+    ] = [
+        ("runtime" in signature(m.modify_model_request).parameters, m)
+        for m in middleware_w_modify_model_request
+    ]
+
+    def model_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
-        # Start with the base model request
-        request, messages = _prepare_model_request(state)
+        request = ModelRequest(
+            model=model,
+            tools=default_tools,
+            system_prompt=system_prompt,
+            response_format=response_format,
+            messages=state["messages"],
+            tool_choice=None,
+        )
 
         # Apply modify_model_request middleware in sequence
-        for m in middleware_w_modify_model_request:
-            # Filter state to only include fields defined in this middleware's schema
-            filtered_state = _filter_state_for_schema(state, m.state_schema)
-            request = m.modify_model_request(request, filtered_state)
+        for use_runtime, m in model_request_signatures:
+            if use_runtime:
+                m.modify_model_request(request, state, runtime)
+            else:
+                m.modify_model_request(request, state)  # type: ignore[call-arg]
 
-        # Get the bound model with the final request
+        # Get the final model and messages
         model_ = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
         output = model_.invoke(messages)
-        return _handle_model_output(state, output)
+        return _handle_model_output(output)
 
-    async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
+    async def amodel_request(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
         # Start with the base model request
-        request, messages = _prepare_model_request(state)
+        request = ModelRequest(
+            model=model,
+            tools=default_tools,
+            system_prompt=system_prompt,
+            response_format=response_format,
+            messages=state["messages"],
+            tool_choice=None,
+        )
 
         # Apply modify_model_request middleware in sequence
-        for m in middleware_w_modify_model_request:
-            # Filter state to only include fields defined in this middleware's schema
-            filtered_state = _filter_state_for_schema(state, m.state_schema)
-            request = m.modify_model_request(request, filtered_state)
+        for use_runtime, m in model_request_signatures:
+            if use_runtime:
+                m.modify_model_request(request, state, runtime)
+            else:
+                m.modify_model_request(request, state)  # type: ignore[call-arg]
 
-        # Get the bound model with the final request
+        # Get the final model and messages
        model_ = _get_bound_model(request)
+        messages = request.messages
+        if request.system_prompt:
+            messages = [SystemMessage(request.system_prompt), *messages]
+
         output = await model_.ainvoke(messages)
-        return _handle_model_output(state, output)
+        return _handle_model_output(output)
 
     # Use sync or async based on model capabilities
     from langgraph._internal._runnable import RunnableCallable
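The `model_request_signatures` list is the backward-compatibility shim: each middleware's `modify_model_request` is inspected once at graph-build time, so older two-argument overrides keep working without per-call reflection. The check itself is plain stdlib:

from inspect import signature


def old_hook(request, state):  # pre-runtime, two-argument override
    return request


def new_hook(request, state, runtime):  # new runtime-aware override
    return request


print("runtime" in signature(old_hook).parameters)  # False
print("runtime" in signature(new_hook).parameters)  # True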
@@ -385,16 +421,12 @@ def create_agent(  # noqa: PLR0915
     for m in middleware:
         if m.__class__.before_model is not AgentMiddleware.before_model:
             graph.add_node(
-                f"{m.__class__.__name__}.before_model",
-                m.before_model,
-                input_schema=m.state_schema,
+                f"{m.__class__.__name__}.before_model", m.before_model, input_schema=state_schema
             )
 
         if m.__class__.after_model is not AgentMiddleware.after_model:
             graph.add_node(
-                f"{m.__class__.__name__}.after_model",
-                m.after_model,
-                input_schema=m.state_schema,
+                f"{m.__class__.__name__}.after_model", m.after_model, input_schema=state_schema
             )
 
     # add start edge
@@ -414,12 +446,12 @@ def create_agent(  # noqa: PLR0915
     if tool_node is not None:
         graph.add_conditional_edges(
             "tools",
-            _make_tools_to_model_edge(tool_node, first_node),
+            _make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
             [first_node, END],
         )
         graph.add_conditional_edges(
             last_node,
-            _make_model_to_tools_edge(first_node, structured_output_tools),
+            _make_model_to_tools_edge(first_node, structured_output_tools, tool_node),
             [first_node, "tools", END],
         )
     elif last_node == "model_request":
@@ -478,27 +510,48 @@ def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
     return None
 
 
+def _fetch_last_ai_and_tool_messages(
+    messages: list[AnyMessage],
+) -> tuple[AIMessage, list[ToolMessage]]:
+    last_ai_index: int
+    last_ai_message: AIMessage
+
+    for i in range(len(messages) - 1, -1, -1):
+        if isinstance(messages[i], AIMessage):
+            last_ai_index = i
+            last_ai_message = cast("AIMessage", messages[i])
+            break
+
+    tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
+    return last_ai_message, tool_messages
+
+
 def _make_model_to_tools_edge(
-    first_node: str, structured_output_tools: dict[str, OutputToolBinding]
-) -> Callable[[AgentState], str | None]:
-    def model_to_tools(state: AgentState) -> str | None:
+    first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
+) -> Callable[[AgentState], str | list[Send] | None]:
+    def model_to_tools(state: AgentState) -> str | list[Send] | None:
         if jump_to := state.get("jump_to"):
             return _resolve_jump(jump_to, first_node)
 
-        message = state["messages"][-1]
-
-        # Check if this is a ToolMessage from structured output - if so, end
-        if isinstance(message, ToolMessage) and message.name in structured_output_tools:
-            return END
-
-        # Check for tool calls
-        if isinstance(message, AIMessage) and message.tool_calls:
-            # If all tool calls are for structured output, don't go to tools
-            non_structured_calls = [
-                tc for tc in message.tool_calls if tc["name"] not in structured_output_tools
+        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
+        tool_message_ids = [m.tool_call_id for m in tool_messages]
+
+        pending_tool_calls = [
+            c
+            for c in last_ai_message.tool_calls
+            if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
+        ]
+
+        if pending_tool_calls:
+            # imo we should not be injecting state, store here,
+            # this should be done by the tool node itself ideally but this is a consequence
+            # of using Send w/ tool calls directly which allows more intuitive interrupt behavior
+            # largely internal so can be fixed later
+            pending_tool_calls = [
+                tool_node.inject_tool_args(call, state, None)  # type: ignore[arg-type]
+                for call in pending_tool_calls
             ]
-            if non_structured_calls:
-                return "tools"
+            return [Send("tools", [tool_call]) for tool_call in pending_tool_calls]
 
         return END
 
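Routing now fans tool calls out with `Send` instead of a single hop to "tools": each pending call (an `AIMessage` tool call with no matching `ToolMessage` id yet) becomes its own task, which is what enables per-call interrupt and resume. A minimal sketch of the dispatch shape (the tool-call dicts are illustrative):

from langgraph.types import Send

tool_calls = [
    {"name": "search", "args": {"query": "weather"}, "id": "call_1"},
    {"name": "calculator", "args": {"expression": "2 + 2"}, "id": "call_2"},
]

# One Send per call: each tool call is executed as an independent "tools" task.
dispatch = [Send("tools", [call]) for call in tool_calls]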
@@ -506,17 +559,21 @@ def _make_model_to_tools_edge(
 
 
 def _make_tools_to_model_edge(
-    tool_node: ToolNode, next_node: str
+    tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
 ) -> Callable[[AgentState], str | None]:
     def tools_to_model(state: AgentState) -> str | None:
-        ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1]
+        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
+
         if all(
             tool_node.tools_by_name[c["name"]].return_direct
-            for c in ai_message.tool_calls
+            for c in last_ai_message.tool_calls
             if c["name"] in tool_node.tools_by_name
         ):
             return END
 
+        if any(t.name in structured_output_tools for t in tool_messages):
+            return END
+
         return next_node
 
     return tools_to_model
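`tools_to_model` now ends the run when every executed tool was declared `return_direct`, or when a structured-output tool has already produced its `ToolMessage`. For reference, `return_direct` is set on the tool itself (a hypothetical tool shown below):

from langchain_core.tools import tool


@tool(return_direct=True)
def final_answer(answer: str) -> str:
    """Return the answer straight to the user instead of looping back to the model."""
    return answer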