langchain 1.1.0__tar.gz → 1.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langchain-1.1.0 → langchain-1.1.2}/PKG-INFO +1 -1
- {langchain-1.1.0 → langchain-1.1.2}/langchain/__init__.py +1 -1
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/context_editing.py +1 -1
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/human_in_the_loop.py +9 -7
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/summarization.py +121 -70
- {langchain-1.1.0 → langchain-1.1.2}/pyproject.toml +1 -1
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_summarization.py +124 -132
- {langchain-1.1.0 → langchain-1.1.2}/uv.lock +32 -1
- {langchain-1.1.0 → langchain-1.1.2}/.gitignore +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/LICENSE +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/Makefile +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/README.md +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/extended_testing_deps.txt +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/factory.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/_execution.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/_redaction.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/_retry.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/file_search.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/model_call_limit.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/model_fallback.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/model_retry.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/pii.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/shell_tool.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/todo.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/tool_call_limit.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/tool_emulator.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/tool_retry.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/tool_selection.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/middleware/types.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/agents/structured_output.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/chat_models/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/chat_models/base.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/embeddings/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/embeddings/base.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/messages/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/py.typed +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/rate_limiters/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/tools/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/langchain/tools/tool_node.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/scripts/check_imports.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/cassettes/test_inference_to_native_output[False].yaml.gz +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/cassettes/test_inference_to_native_output[True].yaml.gz +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/cassettes/test_inference_to_tool_output[False].yaml.gz +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/cassettes/test_inference_to_tool_output[True].yaml.gz +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/agents/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/agents/middleware/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/agents/middleware/test_shell_tool_integration.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/cache/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/cache/fake_embeddings.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/chat_models/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/chat_models/test_base.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/conftest.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/embeddings/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/embeddings/test_base.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/integration_tests/test_compile.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/__snapshots__/test_middleware_decorators.ambr +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/__snapshots__/test_middleware_framework.ambr +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/__snapshots__/test_return_direct_graph.ambr +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/any_str.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/compose-postgres.yml +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/compose-redis.yml +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/conftest.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/conftest_checkpointer.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/conftest_store.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/memory_assert.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/messages.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/__snapshots__/test_middleware_decorators.ambr +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/__snapshots__/test_middleware_diagram.ambr +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/__snapshots__/test_middleware_framework.ambr +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/__snapshots__/test_decorators.ambr +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/__snapshots__/test_diagram.ambr +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/__snapshots__/test_framework.ambr +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/test_composition.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/test_decorators.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/test_diagram.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/test_framework.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/test_overrides.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/test_sync_async_wrappers.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/test_tools.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/core/test_wrap_tool_call.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_context_editing.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_file_search.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_model_fallback.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_model_retry.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_pii.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_shell_execution_policies.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_structured_output_retry.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_todo.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_tool_emulator.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_tool_retry.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/middleware/implementations/test_tool_selection.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/model.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/specifications/responses.json +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/specifications/return_direct.json +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_create_agent_tool_validation.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_injected_runtime_create_agent.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_react_agent.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_response_format.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_response_format_integration.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_responses.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_responses_spec.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_return_direct_graph.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_return_direct_spec.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_state_schema.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/test_system_message.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/agents/utils.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/chat_models/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/chat_models/test_chat_models.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/conftest.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/embeddings/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/embeddings/test_base.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/embeddings/test_imports.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/stubs.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/test_dependencies.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/test_imports.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/test_pytest_config.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/tools/__init__.py +0 -0
- {langchain-1.1.0 → langchain-1.1.2}/tests/unit_tests/tools/test_imports.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: langchain
|
|
3
|
-
Version: 1.1.0
|
|
3
|
+
Version: 1.1.2
|
|
4
4
|
Summary: Building applications with LLMs through composability
|
|
5
5
|
Project-URL: Homepage, https://docs.langchain.com/
|
|
6
6
|
Project-URL: Documentation, https://reference.langchain.com/python/langchain/langchain/
|
|
@@ -189,7 +189,7 @@ class ContextEditingMiddleware(AgentMiddleware):
|
|
|
189
189
|
configured thresholds.
|
|
190
190
|
|
|
191
191
|
Currently the `ClearToolUsesEdit` strategy is supported, aligning with Anthropic's
|
|
192
|
-
`clear_tool_uses_20250919` behavior [(read more)](https://
|
|
192
|
+
`clear_tool_uses_20250919` behavior [(read more)](https://platform.claude.com/docs/en/agents-and-tools/tool-use/memory-tool).
|
|
193
193
|
"""
|
|
194
194
|
|
|
195
195
|
edits: list[ContextEdit]
|
|
@@ -7,7 +7,7 @@ from langgraph.runtime import Runtime
|
|
|
7
7
|
from langgraph.types import interrupt
|
|
8
8
|
from typing_extensions import NotRequired, TypedDict
|
|
9
9
|
|
|
10
|
-
from langchain.agents.middleware.types import AgentMiddleware, AgentState
|
|
10
|
+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, StateT
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class Action(TypedDict):
|
|
@@ -102,7 +102,7 @@ class HITLResponse(TypedDict):
|
|
|
102
102
|
class _DescriptionFactory(Protocol):
|
|
103
103
|
"""Callable that generates a description for a tool call."""
|
|
104
104
|
|
|
105
|
-
def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime) -> str:
|
|
105
|
+
def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str:
|
|
106
106
|
"""Generate a description for a tool call."""
|
|
107
107
|
...
|
|
108
108
|
|
|
@@ -138,7 +138,7 @@ class InterruptOnConfig(TypedDict):
|
|
|
138
138
|
def format_tool_description(
|
|
139
139
|
tool_call: ToolCall,
|
|
140
140
|
state: AgentState,
|
|
141
|
-
runtime: Runtime
|
|
141
|
+
runtime: Runtime[ContextT]
|
|
142
142
|
) -> str:
|
|
143
143
|
import json
|
|
144
144
|
return (
|
|
@@ -156,7 +156,7 @@ class InterruptOnConfig(TypedDict):
|
|
|
156
156
|
"""JSON schema for the args associated with the action, if edits are allowed."""
|
|
157
157
|
|
|
158
158
|
|
|
159
|
-
class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
159
|
+
class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
|
|
160
160
|
"""Human in the loop middleware."""
|
|
161
161
|
|
|
162
162
|
def __init__(
|
|
@@ -204,7 +204,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
204
204
|
tool_call: ToolCall,
|
|
205
205
|
config: InterruptOnConfig,
|
|
206
206
|
state: AgentState,
|
|
207
|
-
runtime: Runtime,
|
|
207
|
+
runtime: Runtime[ContextT],
|
|
208
208
|
) -> tuple[ActionRequest, ReviewConfig]:
|
|
209
209
|
"""Create an ActionRequest and ReviewConfig for a tool call."""
|
|
210
210
|
tool_name = tool_call["name"]
|
|
@@ -277,7 +277,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
277
277
|
)
|
|
278
278
|
raise ValueError(msg)
|
|
279
279
|
|
|
280
|
-
def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
|
280
|
+
def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
|
|
281
281
|
"""Trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
|
282
282
|
messages = state["messages"]
|
|
283
283
|
if not messages:
|
|
@@ -350,6 +350,8 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
|
|
|
350
350
|
|
|
351
351
|
return {"messages": [last_ai_msg, *artificial_tool_messages]}
|
|
352
352
|
|
|
353
|
-
async def aafter_model(
|
|
353
|
+
async def aafter_model(
|
|
354
|
+
self, state: AgentState, runtime: Runtime[ContextT]
|
|
355
|
+
) -> dict[str, Any] | None:
|
|
354
356
|
"""Async trigger interrupt flows for relevant tool calls after an `AIMessage`."""
|
|
355
357
|
return self.after_model(state, runtime)
|
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
import uuid
|
|
4
4
|
import warnings
|
|
5
5
|
from collections.abc import Callable, Iterable, Mapping
|
|
6
|
+
from functools import partial
|
|
6
7
|
from typing import Any, Literal, cast
|
|
7
8
|
|
|
8
9
|
from langchain_core.messages import (
|
|
9
|
-
AIMessage,
|
|
10
10
|
AnyMessage,
|
|
11
11
|
MessageLikeRepresentation,
|
|
12
12
|
RemoveMessage,
|
|
@@ -55,13 +55,76 @@ Messages to summarize:
|
|
|
55
55
|
_DEFAULT_MESSAGES_TO_KEEP = 20
|
|
56
56
|
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
|
|
57
57
|
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
|
|
58
|
-
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
|
|
59
58
|
|
|
60
59
|
ContextFraction = tuple[Literal["fraction"], float]
|
|
60
|
+
"""Fraction of model's maximum input tokens.
|
|
61
|
+
|
|
62
|
+
Example:
|
|
63
|
+
To specify 50% of the model's max input tokens:
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
("fraction", 0.5)
|
|
67
|
+
```
|
|
68
|
+
"""
|
|
69
|
+
|
|
61
70
|
ContextTokens = tuple[Literal["tokens"], int]
|
|
71
|
+
"""Absolute number of tokens.
|
|
72
|
+
|
|
73
|
+
Example:
|
|
74
|
+
To specify 3000 tokens:
|
|
75
|
+
|
|
76
|
+
```python
|
|
77
|
+
("tokens", 3000)
|
|
78
|
+
```
|
|
79
|
+
"""
|
|
80
|
+
|
|
62
81
|
ContextMessages = tuple[Literal["messages"], int]
|
|
82
|
+
"""Absolute number of messages.
|
|
83
|
+
|
|
84
|
+
Example:
|
|
85
|
+
To specify 50 messages:
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
("messages", 50)
|
|
89
|
+
```
|
|
90
|
+
"""
|
|
63
91
|
|
|
64
92
|
ContextSize = ContextFraction | ContextTokens | ContextMessages
|
|
93
|
+
"""Union type for context size specifications.
|
|
94
|
+
|
|
95
|
+
Can be either:
|
|
96
|
+
|
|
97
|
+
- [`ContextFraction`][langchain.agents.middleware.summarization.ContextFraction]: A
|
|
98
|
+
fraction of the model's maximum input tokens.
|
|
99
|
+
- [`ContextTokens`][langchain.agents.middleware.summarization.ContextTokens]: An absolute
|
|
100
|
+
number of tokens.
|
|
101
|
+
- [`ContextMessages`][langchain.agents.middleware.summarization.ContextMessages]: An
|
|
102
|
+
absolute number of messages.
|
|
103
|
+
|
|
104
|
+
Depending on use with `trigger` or `keep` parameters, this type indicates either
|
|
105
|
+
when to trigger summarization or how much context to retain.
|
|
106
|
+
|
|
107
|
+
Example:
|
|
108
|
+
```python
|
|
109
|
+
# ContextFraction
|
|
110
|
+
context_size: ContextSize = ("fraction", 0.5)
|
|
111
|
+
|
|
112
|
+
# ContextTokens
|
|
113
|
+
context_size: ContextSize = ("tokens", 3000)
|
|
114
|
+
|
|
115
|
+
# ContextMessages
|
|
116
|
+
context_size: ContextSize = ("messages", 50)
|
|
117
|
+
```
|
|
118
|
+
"""
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
|
|
122
|
+
"""Tune parameters of approximate token counter based on model type."""
|
|
123
|
+
if model._llm_type == "anthropic-chat":
|
|
124
|
+
# 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
|
|
125
|
+
# API: https://platform.claude.com/docs/en/build-with-claude/token-counting
|
|
126
|
+
return partial(count_tokens_approximately, chars_per_token=3.3)
|
|
127
|
+
return count_tokens_approximately
|
|
65
128
|
|
|
66
129
|
|
|
67
130
|
class SummarizationMiddleware(AgentMiddleware):
|
|
@@ -89,19 +152,48 @@ class SummarizationMiddleware(AgentMiddleware):
|
|
|
89
152
|
model: The language model to use for generating summaries.
|
|
90
153
|
trigger: One or more thresholds that trigger summarization.
|
|
91
154
|
|
|
92
|
-
Provide a single
|
|
93
|
-
summarization
|
|
155
|
+
Provide a single
|
|
156
|
+
[`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
|
|
157
|
+
tuple or a list of tuples, in which case summarization runs when any
|
|
158
|
+
threshold is met.
|
|
159
|
+
|
|
160
|
+
!!! example
|
|
161
|
+
|
|
162
|
+
```python
|
|
163
|
+
# Trigger summarization when 50 messages is reached
|
|
164
|
+
("messages", 50)
|
|
165
|
+
|
|
166
|
+
# Trigger summarization when 3000 tokens is reached
|
|
167
|
+
("tokens", 3000)
|
|
168
|
+
|
|
169
|
+
# Trigger summarization either when 80% of model's max input tokens
|
|
170
|
+
# is reached or when 100 messages is reached (whichever comes first)
|
|
171
|
+
[("fraction", 0.8), ("messages", 100)]
|
|
172
|
+
```
|
|
94
173
|
|
|
95
|
-
|
|
96
|
-
|
|
174
|
+
See [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
|
|
175
|
+
for more details.
|
|
97
176
|
keep: Context retention policy applied after summarization.
|
|
98
177
|
|
|
99
|
-
Provide a `ContextSize`
|
|
178
|
+
Provide a [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
|
|
179
|
+
tuple to specify how much history to preserve.
|
|
100
180
|
|
|
101
|
-
Defaults to keeping the most recent 20 messages.
|
|
181
|
+
Defaults to keeping the most recent `20` messages.
|
|
102
182
|
|
|
103
|
-
|
|
104
|
-
|
|
183
|
+
Does not support multiple values like `trigger`.
|
|
184
|
+
|
|
185
|
+
!!! example
|
|
186
|
+
|
|
187
|
+
```python
|
|
188
|
+
# Keep the most recent 20 messages
|
|
189
|
+
("messages", 20)
|
|
190
|
+
|
|
191
|
+
# Keep the most recent 3000 tokens
|
|
192
|
+
("tokens", 3000)
|
|
193
|
+
|
|
194
|
+
# Keep the most recent 30% of the model's max input tokens
|
|
195
|
+
("fraction", 0.3)
|
|
196
|
+
```
|
|
105
197
|
token_counter: Function to count tokens in messages.
|
|
106
198
|
summary_prompt: Prompt template for generating summaries.
|
|
107
199
|
trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
|
|
@@ -150,7 +242,10 @@ class SummarizationMiddleware(AgentMiddleware):
|
|
|
150
242
|
self._trigger_conditions = trigger_conditions
|
|
151
243
|
|
|
152
244
|
self.keep = self._validate_context_size(keep, "keep")
|
|
153
|
-
|
|
245
|
+
if token_counter is count_tokens_approximately:
|
|
246
|
+
self.token_counter = _get_approximate_token_counter(self.model)
|
|
247
|
+
else:
|
|
248
|
+
self.token_counter = token_counter
|
|
154
249
|
self.summary_prompt = summary_prompt
|
|
155
250
|
self.trim_tokens_to_summarize = trim_tokens_to_summarize
|
|
156
251
|
|
|
@@ -300,11 +395,8 @@ class SummarizationMiddleware(AgentMiddleware):
|
|
|
300
395
|
return 0
|
|
301
396
|
cutoff_candidate = len(messages) - 1
|
|
302
397
|
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
return i
|
|
306
|
-
|
|
307
|
-
return 0
|
|
398
|
+
# Advance past any ToolMessages to avoid splitting AI/Tool pairs
|
|
399
|
+
return self._find_safe_cutoff_point(messages, cutoff_candidate)
|
|
308
400
|
|
|
309
401
|
def _get_profile_limits(self) -> int | None:
|
|
310
402
|
"""Retrieve max input token limit from the model profile."""
|
|
@@ -366,67 +458,26 @@ class SummarizationMiddleware(AgentMiddleware):
|
|
|
366
458
|
|
|
367
459
|
Returns the index where messages can be safely cut without separating
|
|
368
460
|
related AI and Tool messages. Returns `0` if no safe cutoff is found.
|
|
461
|
+
|
|
462
|
+
This is aggressive with summarization - if the target cutoff lands in the
|
|
463
|
+
middle of tool messages, we advance past all of them (summarizing more).
|
|
369
464
|
"""
|
|
370
465
|
if len(messages) <= messages_to_keep:
|
|
371
466
|
return 0
|
|
372
467
|
|
|
373
468
|
target_cutoff = len(messages) - messages_to_keep
|
|
469
|
+
return self._find_safe_cutoff_point(messages, target_cutoff)
|
|
374
470
|
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
return i
|
|
378
|
-
|
|
379
|
-
return 0
|
|
380
|
-
|
|
381
|
-
def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool:
|
|
382
|
-
"""Check if cutting at index would separate AI/Tool message pairs."""
|
|
383
|
-
if cutoff_index >= len(messages):
|
|
384
|
-
return True
|
|
385
|
-
|
|
386
|
-
search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS)
|
|
387
|
-
search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS)
|
|
471
|
+
def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
|
|
472
|
+
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.
|
|
388
473
|
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
return True
|
|
398
|
-
|
|
399
|
-
def _has_tool_calls(self, message: AnyMessage) -> bool:
|
|
400
|
-
"""Check if message is an AI message with tool calls."""
|
|
401
|
-
return (
|
|
402
|
-
isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls # type: ignore[return-value]
|
|
403
|
-
)
|
|
404
|
-
|
|
405
|
-
def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]:
|
|
406
|
-
"""Extract tool call IDs from an AI message."""
|
|
407
|
-
tool_call_ids = set()
|
|
408
|
-
for tc in ai_message.tool_calls:
|
|
409
|
-
call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
|
|
410
|
-
if call_id is not None:
|
|
411
|
-
tool_call_ids.add(call_id)
|
|
412
|
-
return tool_call_ids
|
|
413
|
-
|
|
414
|
-
def _cutoff_separates_tool_pair(
|
|
415
|
-
self,
|
|
416
|
-
messages: list[AnyMessage],
|
|
417
|
-
ai_message_index: int,
|
|
418
|
-
cutoff_index: int,
|
|
419
|
-
tool_call_ids: set[str],
|
|
420
|
-
) -> bool:
|
|
421
|
-
"""Check if cutoff separates an AI message from its corresponding tool messages."""
|
|
422
|
-
for j in range(ai_message_index + 1, len(messages)):
|
|
423
|
-
message = messages[j]
|
|
424
|
-
if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids:
|
|
425
|
-
ai_before_cutoff = ai_message_index < cutoff_index
|
|
426
|
-
tool_before_cutoff = j < cutoff_index
|
|
427
|
-
if ai_before_cutoff != tool_before_cutoff:
|
|
428
|
-
return True
|
|
429
|
-
return False
|
|
474
|
+
If the message at cutoff_index is a ToolMessage, advance until we find
|
|
475
|
+
a non-ToolMessage. This ensures we never cut in the middle of parallel
|
|
476
|
+
tool call responses.
|
|
477
|
+
"""
|
|
478
|
+
while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
|
|
479
|
+
cutoff_index += 1
|
|
480
|
+
return cutoff_index
|
|
430
481
|
|
|
431
482
|
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
|
432
483
|
"""Generate summary for the given messages."""
|
|
@@ -121,46 +121,6 @@ def test_summarization_middleware_helper_methods() -> None:
|
|
|
121
121
|
assert "Here is a summary of the conversation to date:" in new_messages[0].content
|
|
122
122
|
assert summary in new_messages[0].content
|
|
123
123
|
|
|
124
|
-
# Test tool call detection
|
|
125
|
-
ai_message_no_tools = AIMessage(content="Hello")
|
|
126
|
-
assert not middleware._has_tool_calls(ai_message_no_tools)
|
|
127
|
-
|
|
128
|
-
ai_message_with_tools = AIMessage(
|
|
129
|
-
content="Hello", tool_calls=[{"name": "test", "args": {}, "id": "1"}]
|
|
130
|
-
)
|
|
131
|
-
assert middleware._has_tool_calls(ai_message_with_tools)
|
|
132
|
-
|
|
133
|
-
human_message = HumanMessage(content="Hello")
|
|
134
|
-
assert not middleware._has_tool_calls(human_message)
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
def test_summarization_middleware_tool_call_safety() -> None:
|
|
138
|
-
"""Test SummarizationMiddleware tool call safety logic."""
|
|
139
|
-
model = FakeToolCallingModel()
|
|
140
|
-
middleware = SummarizationMiddleware(
|
|
141
|
-
model=model, trigger=("tokens", 1000), keep=("messages", 3)
|
|
142
|
-
)
|
|
143
|
-
|
|
144
|
-
# Test safe cutoff point detection with tool calls
|
|
145
|
-
messages = [
|
|
146
|
-
HumanMessage(content="1"),
|
|
147
|
-
AIMessage(content="2", tool_calls=[{"name": "test", "args": {}, "id": "1"}]),
|
|
148
|
-
ToolMessage(content="3", tool_call_id="1"),
|
|
149
|
-
HumanMessage(content="4"),
|
|
150
|
-
]
|
|
151
|
-
|
|
152
|
-
# Safe cutoff (doesn't separate AI/Tool pair)
|
|
153
|
-
is_safe = middleware._is_safe_cutoff_point(messages, 0)
|
|
154
|
-
assert is_safe is True
|
|
155
|
-
|
|
156
|
-
# Unsafe cutoff (separates AI/Tool pair)
|
|
157
|
-
is_safe = middleware._is_safe_cutoff_point(messages, 2)
|
|
158
|
-
assert is_safe is False
|
|
159
|
-
|
|
160
|
-
# Test tool call ID extraction
|
|
161
|
-
ids = middleware._extract_tool_call_ids(messages[1])
|
|
162
|
-
assert ids == {"1"}
|
|
163
|
-
|
|
164
124
|
|
|
165
125
|
def test_summarization_middleware_summary_creation() -> None:
|
|
166
126
|
"""Test SummarizationMiddleware summary creation."""
|
|
@@ -315,8 +275,8 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
|
|
|
315
275
|
]
|
|
316
276
|
|
|
317
277
|
|
|
318
|
-
def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> None:
|
|
319
|
-
"""Ensure token retention
|
|
278
|
+
def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None:
|
|
279
|
+
"""Ensure token retention advances past tool messages for aggressive summarization."""
|
|
320
280
|
|
|
321
281
|
def token_counter(messages: list[AnyMessage]) -> int:
|
|
322
282
|
return sum(len(getattr(message, "content", "")) for message in messages)
|
|
@@ -328,6 +288,10 @@ def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> N
|
|
|
328
288
|
)
|
|
329
289
|
middleware.token_counter = token_counter
|
|
330
290
|
|
|
291
|
+
# Total tokens: 300 + 200 + 50 + 180 + 160 = 890
|
|
292
|
+
# Target keep: 500 tokens (50% of 1000)
|
|
293
|
+
# Binary search finds cutoff around index 2 (ToolMessage)
|
|
294
|
+
# We advance past it to index 3 (HumanMessage)
|
|
331
295
|
messages: list[AnyMessage] = [
|
|
332
296
|
HumanMessage(content="H" * 300),
|
|
333
297
|
AIMessage(
|
|
@@ -344,13 +308,14 @@ def test_summarization_middleware_token_retention_pct_respects_tool_pairs() -> N
|
|
|
344
308
|
assert result is not None
|
|
345
309
|
|
|
346
310
|
preserved_messages = result["messages"][2:]
|
|
347
|
-
|
|
311
|
+
# With aggressive summarization, we advance past the ToolMessage
|
|
312
|
+
# So we preserve messages from index 3 onward (the two HumanMessages)
|
|
313
|
+
assert preserved_messages == messages[3:]
|
|
348
314
|
|
|
315
|
+
# Verify preserved tokens are within budget
|
|
349
316
|
target_token_count = int(1000 * 0.5)
|
|
350
317
|
preserved_tokens = middleware.token_counter(preserved_messages)
|
|
351
|
-
|
|
352
|
-
# Tool pair retention can exceed the target token count but should keep the pair intact.
|
|
353
|
-
assert preserved_tokens > target_token_count
|
|
318
|
+
assert preserved_tokens <= target_token_count
|
|
354
319
|
|
|
355
320
|
|
|
356
321
|
def test_summarization_middleware_missing_profile() -> None:
|
|
@@ -692,95 +657,38 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
|
|
|
692
657
|
assert cutoff == 0
|
|
693
658
|
|
|
694
659
|
|
|
695
|
-
def
|
|
696
|
-
"""Test
|
|
660
|
+
def test_summarization_middleware_find_safe_cutoff_point() -> None:
|
|
661
|
+
"""Test _find_safe_cutoff_point finds safe cutoff past ToolMessages."""
|
|
697
662
|
model = FakeToolCallingModel()
|
|
698
|
-
middleware = SummarizationMiddleware(
|
|
699
|
-
|
|
700
|
-
# Test with dict-style tool calls
|
|
701
|
-
ai_message_dict = AIMessage(
|
|
702
|
-
content="test", tool_calls=[{"name": "tool1", "args": {}, "id": "id1"}]
|
|
703
|
-
)
|
|
704
|
-
ids = middleware._extract_tool_call_ids(ai_message_dict)
|
|
705
|
-
assert ids == {"id1"}
|
|
706
|
-
|
|
707
|
-
# Test with multiple tool calls
|
|
708
|
-
ai_message_multiple = AIMessage(
|
|
709
|
-
content="test",
|
|
710
|
-
tool_calls=[
|
|
711
|
-
{"name": "tool1", "args": {}, "id": "id1"},
|
|
712
|
-
{"name": "tool2", "args": {}, "id": "id2"},
|
|
713
|
-
],
|
|
663
|
+
middleware = SummarizationMiddleware(
|
|
664
|
+
model=model, trigger=("messages", 10), keep=("messages", 2)
|
|
714
665
|
)
|
|
715
|
-
ids = middleware._extract_tool_call_ids(ai_message_multiple)
|
|
716
|
-
assert ids == {"id1", "id2"}
|
|
717
|
-
|
|
718
|
-
# Test with empty tool calls list
|
|
719
|
-
ai_message_empty = AIMessage(content="test", tool_calls=[])
|
|
720
|
-
ids = middleware._extract_tool_call_ids(ai_message_empty)
|
|
721
|
-
assert len(ids) == 0
|
|
722
666
|
|
|
723
|
-
|
|
724
|
-
def test_summarization_middleware_complex_tool_pair_scenarios() -> None:
|
|
725
|
-
"""Test complex tool call pairing scenarios."""
|
|
726
|
-
model = FakeToolCallingModel()
|
|
727
|
-
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5), keep=("messages", 3))
|
|
728
|
-
|
|
729
|
-
# Test with multiple AI messages with tool calls
|
|
730
|
-
messages = [
|
|
667
|
+
messages: list[AnyMessage] = [
|
|
731
668
|
HumanMessage(content="msg1"),
|
|
732
|
-
AIMessage(content="
|
|
669
|
+
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
|
|
733
670
|
ToolMessage(content="result1", tool_call_id="call1"),
|
|
734
|
-
HumanMessage(content="msg2"),
|
|
735
|
-
AIMessage(content="ai2", tool_calls=[{"name": "tool2", "args": {}, "id": "call2"}]),
|
|
736
671
|
ToolMessage(content="result2", tool_call_id="call2"),
|
|
737
|
-
HumanMessage(content="
|
|
672
|
+
HumanMessage(content="msg2"),
|
|
738
673
|
]
|
|
739
674
|
|
|
740
|
-
#
|
|
741
|
-
assert
|
|
742
|
-
|
|
743
|
-
# Test cutoff at index 3 - safe (keeps first pair together)
|
|
744
|
-
assert middleware._is_safe_cutoff_point(messages, 3)
|
|
745
|
-
|
|
746
|
-
# Test cutoff at index 5 - unsafe (separates second AI/Tool pair)
|
|
747
|
-
assert not middleware._is_safe_cutoff_point(messages, 5)
|
|
748
|
-
|
|
749
|
-
# Test _cutoff_separates_tool_pair directly
|
|
750
|
-
assert middleware._cutoff_separates_tool_pair(messages, 1, 2, {"call1"})
|
|
751
|
-
assert not middleware._cutoff_separates_tool_pair(messages, 1, 0, {"call1"})
|
|
752
|
-
assert not middleware._cutoff_separates_tool_pair(messages, 1, 3, {"call1"})
|
|
675
|
+
# Starting at a non-ToolMessage returns the same index
|
|
676
|
+
assert middleware._find_safe_cutoff_point(messages, 0) == 0
|
|
677
|
+
assert middleware._find_safe_cutoff_point(messages, 1) == 1
|
|
753
678
|
|
|
679
|
+
# Starting at a ToolMessage advances to the next non-ToolMessage
|
|
680
|
+
assert middleware._find_safe_cutoff_point(messages, 2) == 4
|
|
681
|
+
assert middleware._find_safe_cutoff_point(messages, 3) == 4
|
|
754
682
|
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
model = FakeToolCallingModel()
|
|
758
|
-
middleware = SummarizationMiddleware(
|
|
759
|
-
model=model, trigger=("messages", 10), keep=("messages", 2)
|
|
760
|
-
)
|
|
761
|
-
|
|
762
|
-
# Create messages with tool pair separated by some distance
|
|
763
|
-
# Search range is 5, so messages within 5 positions of cutoff are checked
|
|
764
|
-
messages = [
|
|
765
|
-
HumanMessage(content="msg1"),
|
|
766
|
-
HumanMessage(content="msg2"),
|
|
767
|
-
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
|
|
768
|
-
HumanMessage(content="msg3"),
|
|
769
|
-
HumanMessage(content="msg4"),
|
|
770
|
-
ToolMessage(content="result", tool_call_id="call1"),
|
|
771
|
-
HumanMessage(content="msg6"),
|
|
772
|
-
]
|
|
773
|
-
|
|
774
|
-
# Cutoff at index 3 would separate: [0,1,2] from [3,4,5,6]
|
|
775
|
-
# AI at index 2 is before cutoff, Tool at index 5 is after cutoff - unsafe
|
|
776
|
-
assert not middleware._is_safe_cutoff_point(messages, 3)
|
|
683
|
+
# Starting at the HumanMessage after tools returns that index
|
|
684
|
+
assert middleware._find_safe_cutoff_point(messages, 4) == 4
|
|
777
685
|
|
|
778
|
-
#
|
|
779
|
-
assert middleware.
|
|
686
|
+
# Starting past the end returns the index unchanged
|
|
687
|
+
assert middleware._find_safe_cutoff_point(messages, 5) == 5
|
|
780
688
|
|
|
781
|
-
# Cutoff at
|
|
782
|
-
assert middleware.
|
|
783
|
-
assert middleware.
|
|
689
|
+
# Cutoff at or past length stays the same
|
|
690
|
+
assert middleware._find_safe_cutoff_point(messages, len(messages)) == len(messages)
|
|
691
|
+
assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5
|
|
784
692
|
|
|
785
693
|
|
|
786
694
|
def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
|
|
@@ -880,15 +788,99 @@ def test_summarization_middleware_fraction_trigger_with_no_profile() -> None:
|
|
|
880
788
|
middleware._get_profile_limits = original_method
|
|
881
789
|
|
|
882
790
|
|
|
883
|
-
def
|
|
884
|
-
""
|
|
885
|
-
model = FakeToolCallingModel()
|
|
886
|
-
middleware = SummarizationMiddleware(model=model, trigger=("messages", 5))
|
|
791
|
+
def test_summarization_adjust_token_counts() -> None:
|
|
792
|
+
test_message = HumanMessage(content="a" * 12)
|
|
887
793
|
|
|
888
|
-
|
|
794
|
+
middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5))
|
|
795
|
+
count_1 = middleware.token_counter([test_message])
|
|
796
|
+
|
|
797
|
+
class MockAnthropicModel(MockChatModel):
|
|
798
|
+
@property
|
|
799
|
+
def _llm_type(self) -> str:
|
|
800
|
+
return "anthropic-chat"
|
|
801
|
+
|
|
802
|
+
middleware = SummarizationMiddleware(model=MockAnthropicModel(), trigger=("messages", 5))
|
|
803
|
+
count_2 = middleware.token_counter([test_message])
|
|
804
|
+
|
|
805
|
+
assert count_1 != count_2
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
|
|
809
|
+
"""Test cutoff safety with many parallel tool calls extending beyond old search range."""
|
|
810
|
+
middleware = SummarizationMiddleware(
|
|
811
|
+
model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5)
|
|
812
|
+
)
|
|
813
|
+
tool_calls = [{"name": f"tool_{i}", "args": {}, "id": f"call_{i}"} for i in range(10)]
|
|
814
|
+
human_message = HumanMessage(content="calling 10 tools")
|
|
815
|
+
ai_message = AIMessage(content="calling 10 tools", tool_calls=tool_calls)
|
|
816
|
+
tool_messages = [
|
|
817
|
+
ToolMessage(content=f"result_{i}", tool_call_id=f"call_{i}") for i in range(10)
|
|
818
|
+
]
|
|
819
|
+
messages: list[AnyMessage] = [human_message, ai_message, *tool_messages]
|
|
820
|
+
|
|
821
|
+
# Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages)
|
|
822
|
+
assert middleware._find_safe_cutoff_point(messages, 7) == 12
|
|
823
|
+
|
|
824
|
+
# Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12
|
|
825
|
+
for i in range(2, 12):
|
|
826
|
+
assert middleware._find_safe_cutoff_point(messages, i) == 12
|
|
827
|
+
|
|
828
|
+
# Cutoff at index 0, 1 (before tool messages) stays the same
|
|
829
|
+
assert middleware._find_safe_cutoff_point(messages, 0) == 0
|
|
830
|
+
assert middleware._find_safe_cutoff_point(messages, 1) == 1
|
|
831
|
+
|
|
832
|
+
|
|
833
|
+
def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None:
|
|
834
|
+
"""Test _find_safe_cutoff advances past ToolMessages to find safe cutoff."""
|
|
835
|
+
middleware = SummarizationMiddleware(
|
|
836
|
+
model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3)
|
|
837
|
+
)
|
|
889
838
|
|
|
890
|
-
#
|
|
891
|
-
|
|
839
|
+
# Messages: [Human, AI, Tool, Tool, Tool, Human]
|
|
840
|
+
messages: list[AnyMessage] = [
|
|
841
|
+
HumanMessage(content="msg1"),
|
|
842
|
+
AIMessage(
|
|
843
|
+
content="ai",
|
|
844
|
+
tool_calls=[
|
|
845
|
+
{"name": "tool1", "args": {}, "id": "call1"},
|
|
846
|
+
{"name": "tool2", "args": {}, "id": "call2"},
|
|
847
|
+
{"name": "tool3", "args": {}, "id": "call3"},
|
|
848
|
+
],
|
|
849
|
+
),
|
|
850
|
+
ToolMessage(content="result1", tool_call_id="call1"),
|
|
851
|
+
ToolMessage(content="result2", tool_call_id="call2"),
|
|
852
|
+
ToolMessage(content="result3", tool_call_id="call3"),
|
|
853
|
+
HumanMessage(content="msg2"),
|
|
854
|
+
]
|
|
855
|
+
|
|
856
|
+
# Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3
|
|
857
|
+
# Index 3 is a ToolMessage, so we advance past the tool sequence to index 5
|
|
858
|
+
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3)
|
|
859
|
+
assert cutoff == 5
|
|
860
|
+
|
|
861
|
+
# With messages_to_keep=2, target cutoff index is 6 - 2 = 4
|
|
862
|
+
# Index 4 is a ToolMessage, so we advance past the tool sequence to index 5
|
|
863
|
+
# This is aggressive - we keep only 1 message instead of 2
|
|
864
|
+
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2)
|
|
865
|
+
assert cutoff == 5
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
|
|
869
|
+
"""Test cutoff when target lands exactly at the first ToolMessage."""
|
|
870
|
+
middleware = SummarizationMiddleware(
|
|
871
|
+
model=MockChatModel(), trigger=("messages", 8), keep=("messages", 4)
|
|
872
|
+
)
|
|
873
|
+
|
|
874
|
+
messages: list[AnyMessage] = [
|
|
875
|
+
HumanMessage(content="msg1"),
|
|
876
|
+
HumanMessage(content="msg2"),
|
|
877
|
+
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
|
|
878
|
+
ToolMessage(content="result", tool_call_id="call1"),
|
|
879
|
+
HumanMessage(content="msg3"),
|
|
880
|
+
HumanMessage(content="msg4"),
|
|
881
|
+
]
|
|
892
882
|
|
|
893
|
-
#
|
|
894
|
-
|
|
883
|
+
# Target cutoff index is len(messages) - messages_to_keep = 6 - 4 = 2
|
|
884
|
+
# Index 2 is an AIMessage (safe cutoff point), so no adjustment needed
|
|
885
|
+
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=4)
|
|
886
|
+
assert cutoff == 2
|