langchain 1.0.5__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +1 -7
- langchain/agents/factory.py +153 -79
- langchain/agents/middleware/__init__.py +18 -23
- langchain/agents/middleware/_execution.py +29 -32
- langchain/agents/middleware/_redaction.py +108 -22
- langchain/agents/middleware/_retry.py +123 -0
- langchain/agents/middleware/context_editing.py +47 -25
- langchain/agents/middleware/file_search.py +19 -14
- langchain/agents/middleware/human_in_the_loop.py +87 -57
- langchain/agents/middleware/model_call_limit.py +64 -18
- langchain/agents/middleware/model_fallback.py +7 -9
- langchain/agents/middleware/model_retry.py +307 -0
- langchain/agents/middleware/pii.py +82 -29
- langchain/agents/middleware/shell_tool.py +254 -107
- langchain/agents/middleware/summarization.py +469 -95
- langchain/agents/middleware/todo.py +129 -31
- langchain/agents/middleware/tool_call_limit.py +105 -71
- langchain/agents/middleware/tool_emulator.py +47 -38
- langchain/agents/middleware/tool_retry.py +183 -164
- langchain/agents/middleware/tool_selection.py +81 -37
- langchain/agents/middleware/types.py +856 -427
- langchain/agents/structured_output.py +65 -42
- langchain/chat_models/__init__.py +1 -7
- langchain/chat_models/base.py +253 -196
- langchain/embeddings/__init__.py +0 -5
- langchain/embeddings/base.py +79 -65
- langchain/messages/__init__.py +0 -5
- langchain/tools/__init__.py +1 -7
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/METADATA +5 -7
- langchain-1.2.4.dist-info/RECORD +36 -0
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/WHEEL +1 -1
- langchain-1.0.5.dist-info/RECORD +0 -34
- {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
langchain/agents/middleware/summarization.py

@@ -1,8 +1,10 @@
 """Summarization middleware."""
 
 import uuid
-
-from
+import warnings
+from collections.abc import Callable, Iterable, Mapping
+from functools import partial
+from typing import Any, Literal, cast
 
 from langchain_core.messages import (
     AIMessage,
@@ -12,11 +14,16 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.human import HumanMessage
-from langchain_core.messages.utils import
+from langchain_core.messages.utils import (
+    count_tokens_approximately,
+    get_buffer_string,
+    trim_messages,
+)
 from langgraph.graph.message import (
     REMOVE_ALL_MESSAGES,
 )
 from langgraph.runtime import Runtime
+from typing_extensions import override
 
 from langchain.agents.middleware.types import AgentMiddleware, AgentState
 from langchain.chat_models import BaseChatModel, init_chat_model
@@ -51,12 +58,79 @@ Messages to summarize:
 {messages}
 </messages>"""  # noqa: E501
 
-SUMMARY_PREFIX = "## Previous conversation summary:"
-
 _DEFAULT_MESSAGES_TO_KEEP = 20
 _DEFAULT_TRIM_TOKEN_LIMIT = 4000
 _DEFAULT_FALLBACK_MESSAGE_COUNT = 15
-
+
+ContextFraction = tuple[Literal["fraction"], float]
+"""Fraction of model's maximum input tokens.
+
+Example:
+    To specify 50% of the model's max input tokens:
+
+    ```python
+    ("fraction", 0.5)
+    ```
+"""
+
+ContextTokens = tuple[Literal["tokens"], int]
+"""Absolute number of tokens.
+
+Example:
+    To specify 3000 tokens:
+
+    ```python
+    ("tokens", 3000)
+    ```
+"""
+
+ContextMessages = tuple[Literal["messages"], int]
+"""Absolute number of messages.
+
+Example:
+    To specify 50 messages:
+
+    ```python
+    ("messages", 50)
+    ```
+"""
+
+ContextSize = ContextFraction | ContextTokens | ContextMessages
+"""Union type for context size specifications.
+
+Can be either:
+
+- [`ContextFraction`][langchain.agents.middleware.summarization.ContextFraction]: A
+  fraction of the model's maximum input tokens.
+- [`ContextTokens`][langchain.agents.middleware.summarization.ContextTokens]: An absolute
+  number of tokens.
+- [`ContextMessages`][langchain.agents.middleware.summarization.ContextMessages]: An
+  absolute number of messages.
+
+Depending on use with `trigger` or `keep` parameters, this type indicates either
+when to trigger summarization or how much context to retain.
+
+Example:
+    ```python
+    # ContextFraction
+    context_size: ContextSize = ("fraction", 0.5)
+
+    # ContextTokens
+    context_size: ContextSize = ("tokens", 3000)
+
+    # ContextMessages
+    context_size: ContextSize = ("messages", 50)
+    ```
+"""
+
+
+def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
+    """Tune parameters of approximate token counter based on model type."""
+    if model._llm_type == "anthropic-chat":  # noqa: SLF001
+        # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
+        # API: https://platform.claude.com/docs/en/build-with-claude/token-counting
+        return partial(count_tokens_approximately, chars_per_token=3.3)
+    return count_tokens_approximately
 
 
 class SummarizationMiddleware(AgentMiddleware):
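For orientation, here is a small sketch (not part of the diff) of how the approximate token counter introduced in this hunk behaves. `count_tokens_approximately` and its `chars_per_token` parameter are real helpers in `langchain_core.messages.utils`; the message text is illustrative, and the 3.3 chars-per-token figure comes from the comment above.

```python
# Illustrative sketch only: compare the default character-based estimate with
# the Anthropic-tuned counter that _get_approximate_token_counter returns.
from functools import partial

from langchain_core.messages import HumanMessage
from langchain_core.messages.utils import count_tokens_approximately

# Tuned counter for Claude models, per the 3.3 chars-per-token estimate in the diff.
anthropic_counter = partial(count_tokens_approximately, chars_per_token=3.3)

messages = [HumanMessage("Summarize the key decisions from this conversation.")]
print(count_tokens_approximately(messages))  # default approximation
print(anthropic_counter(messages))           # denser approximation used for Claude models
```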
@@ -70,48 +144,149 @@ class SummarizationMiddleware(AgentMiddleware):
     def __init__(
         self,
         model: str | BaseChatModel,
-
-
+        *,
+        trigger: ContextSize | list[ContextSize] | None = None,
+        keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
         token_counter: TokenCounter = count_tokens_approximately,
         summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
-
+        trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
+        **deprecated_kwargs: Any,
     ) -> None:
-        """Initialize
+        """Initialize summarization middleware.
 
         Args:
             model: The language model to use for generating summaries.
-
-
-
+            trigger: One or more thresholds that trigger summarization.
+
+                Provide a single
+                [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
+                tuple or a list of tuples, in which case summarization runs when any
+                threshold is met.
+
+                !!! example
+
+                    ```python
+                    # Trigger summarization when 50 messages is reached
+                    ("messages", 50)
+
+                    # Trigger summarization when 3000 tokens is reached
+                    ("tokens", 3000)
+
+                    # Trigger summarization either when 80% of model's max input tokens
+                    # is reached or when 100 messages is reached (whichever comes first)
+                    [("fraction", 0.8), ("messages", 100)]
+                    ```
+
+                See [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
+                for more details.
+            keep: Context retention policy applied after summarization.
+
+                Provide a [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
+                tuple to specify how much history to preserve.
+
+                Defaults to keeping the most recent `20` messages.
+
+                Does not support multiple values like `trigger`.
+
+                !!! example
+
+                    ```python
+                    # Keep the most recent 20 messages
+                    ("messages", 20)
+
+                    # Keep the most recent 3000 tokens
+                    ("tokens", 3000)
+
+                    # Keep the most recent 30% of the model's max input tokens
+                    ("fraction", 0.3)
+                    ```
             token_counter: Function to count tokens in messages.
             summary_prompt: Prompt template for generating summaries.
-
+            trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
+                the summarization call.
+
+                Pass `None` to skip trimming entirely.
         """
+        # Handle deprecated parameters
+        if "max_tokens_before_summary" in deprecated_kwargs:
+            value = deprecated_kwargs["max_tokens_before_summary"]
+            warnings.warn(
+                "max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if trigger is None and value is not None:
+                trigger = ("tokens", value)
+
+        if "messages_to_keep" in deprecated_kwargs:
+            value = deprecated_kwargs["messages_to_keep"]
+            warnings.warn(
+                "messages_to_keep is deprecated. Use keep=('messages', value) instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
+                keep = ("messages", value)
+
         super().__init__()
 
         if isinstance(model, str):
             model = init_chat_model(model)
 
         self.model = model
-
-
-
+        if trigger is None:
+            self.trigger: ContextSize | list[ContextSize] | None = None
+            trigger_conditions: list[ContextSize] = []
+        elif isinstance(trigger, list):
+            validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
+            self.trigger = validated_list
+            trigger_conditions = validated_list
+        else:
+            validated = self._validate_context_size(trigger, "trigger")
+            self.trigger = validated
+            trigger_conditions = [validated]
+        self._trigger_conditions = trigger_conditions
+
+        self.keep = self._validate_context_size(keep, "keep")
+        if token_counter is count_tokens_approximately:
+            self.token_counter = _get_approximate_token_counter(self.model)
+        else:
+            self.token_counter = token_counter
         self.summary_prompt = summary_prompt
-        self.
+        self.trim_tokens_to_summarize = trim_tokens_to_summarize
+
+        requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
+        if self.keep[0] == "fraction":
+            requires_profile = True
+        if requires_profile and self._get_profile_limits() is None:
+            msg = (
+                "Model profile information is required to use fractional token limits, "
+                "and is unavailable for the specified model. Please use absolute token "
+                "counts instead, or pass "
+                '`\n\nChatModel(..., profile={"max_input_tokens": ...})`.\n\n'
+                "with a desired integer value of the model's maximum input tokens."
+            )
+            raise ValueError(msg)
+
+    @override
+    def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
+        """Process messages before model invocation, potentially triggering summarization.
 
-
-
+        Args:
+            state: The agent state.
+            runtime: The runtime environment.
+
+        Returns:
+            An updated state with summarized messages if summarization was performed.
+        """
         messages = state["messages"]
         self._ensure_message_ids(messages)
 
         total_tokens = self.token_counter(messages)
-        if (
-            self.max_tokens_before_summary is not None
-            and total_tokens < self.max_tokens_before_summary
-        ):
+        if not self._should_summarize(messages, total_tokens):
             return None
 
-        cutoff_index = self.
+        cutoff_index = self._determine_cutoff_index(messages)
 
         if cutoff_index <= 0:
             return None
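To make the new constructor concrete, here is a minimal usage sketch based on the signature and docstring in this hunk. The model identifier and the `create_agent` wiring are assumptions for illustration; only the `SummarizationMiddleware` arguments come from the diff itself.

```python
# Minimal sketch (assumptions noted inline), based on the __init__ signature above.
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware

summarizer = SummarizationMiddleware(
    model="openai:gpt-4o-mini",  # assumed model string; any init_chat_model id works
    # Summarize once either threshold is hit, whichever comes first.
    # ("fraction", ...) triggers also work, but require model profile data.
    trigger=[("tokens", 3000), ("messages", 100)],
    # After summarizing, keep only the most recent 20 messages.
    keep=("messages", 20),
)

# Pre-1.2 keyword arguments still work but now emit a DeprecationWarning:
#   SummarizationMiddleware(model=..., max_tokens_before_summary=3000, messages_to_keep=20)

agent = create_agent(
    model="openai:gpt-4o-mini",  # assumed
    tools=[],
    middleware=[summarizer],
)
```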
@@ -129,19 +304,204 @@ class SummarizationMiddleware(AgentMiddleware):
             ]
         }
 
-
+    @override
+    async def abefore_model(
+        self, state: AgentState[Any], runtime: Runtime
+    ) -> dict[str, Any] | None:
+        """Process messages before model invocation, potentially triggering summarization.
+
+        Args:
+            state: The agent state.
+            runtime: The runtime environment.
+
+        Returns:
+            An updated state with summarized messages if summarization was performed.
+        """
+        messages = state["messages"]
+        self._ensure_message_ids(messages)
+
+        total_tokens = self.token_counter(messages)
+        if not self._should_summarize(messages, total_tokens):
+            return None
+
+        cutoff_index = self._determine_cutoff_index(messages)
+
+        if cutoff_index <= 0:
+            return None
+
+        messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
+
+        summary = await self._acreate_summary(messages_to_summarize)
+        new_messages = self._build_new_messages(summary)
+
+        return {
+            "messages": [
+                RemoveMessage(id=REMOVE_ALL_MESSAGES),
+                *new_messages,
+                *preserved_messages,
+            ]
+        }
+
+    def _should_summarize_based_on_reported_tokens(
+        self, messages: list[AnyMessage], threshold: float
+    ) -> bool:
+        """Check if reported token usage from last AIMessage exceeds threshold."""
+        last_ai_message = next(
+            (msg for msg in reversed(messages) if isinstance(msg, AIMessage)),
+            None,
+        )
+        if (  # noqa: SIM103
+            isinstance(last_ai_message, AIMessage)
+            and last_ai_message.usage_metadata is not None
+            and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1))
+            and reported_tokens >= threshold
+            and (message_provider := last_ai_message.response_metadata.get("model_provider"))
+            and message_provider == self.model._get_ls_params().get("ls_provider")  # noqa: SLF001
+        ):
+            return True
+        return False
+
+    def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
+        """Determine whether summarization should run for the current token usage."""
+        if not self._trigger_conditions:
+            return False
+
+        for kind, value in self._trigger_conditions:
+            if kind == "messages" and len(messages) >= value:
+                return True
+            if kind == "tokens" and total_tokens >= value:
+                return True
+            if kind == "tokens" and self._should_summarize_based_on_reported_tokens(
+                messages, value
+            ):
+                return True
+            if kind == "fraction":
+                max_input_tokens = self._get_profile_limits()
+                if max_input_tokens is None:
+                    continue
+                threshold = int(max_input_tokens * value)
+                if threshold <= 0:
+                    threshold = 1
+                if total_tokens >= threshold:
+                    return True
+
+                if self._should_summarize_based_on_reported_tokens(messages, threshold):
+                    return True
+        return False
+
+    def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
+        """Choose cutoff index respecting retention configuration."""
+        kind, value = self.keep
+        if kind in {"tokens", "fraction"}:
+            token_based_cutoff = self._find_token_based_cutoff(messages)
+            if token_based_cutoff is not None:
+                return token_based_cutoff
+            # None cutoff -> model profile data not available (caught in __init__ but
+            # here for safety), fallback to message count
+            return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
+        return self._find_safe_cutoff(messages, cast("int", value))
+
+    def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
+        """Find cutoff index based on target token retention."""
+        if not messages:
+            return 0
+
+        kind, value = self.keep
+        if kind == "fraction":
+            max_input_tokens = self._get_profile_limits()
+            if max_input_tokens is None:
+                return None
+            target_token_count = int(max_input_tokens * value)
+        elif kind == "tokens":
+            target_token_count = int(value)
+        else:
+            return None
+
+        if target_token_count <= 0:
+            target_token_count = 1
+
+        if self.token_counter(messages) <= target_token_count:
+            return 0
+
+        # Use binary search to identify the earliest message index that keeps the
+        # suffix within the token budget.
+        left, right = 0, len(messages)
+        cutoff_candidate = len(messages)
+        max_iterations = len(messages).bit_length() + 1
+        for _ in range(max_iterations):
+            if left >= right:
+                break
+
+            mid = (left + right) // 2
+            if self.token_counter(messages[mid:]) <= target_token_count:
+                cutoff_candidate = mid
+                right = mid
+            else:
+                left = mid + 1
+
+        if cutoff_candidate == len(messages):
+            cutoff_candidate = left
+
+        if cutoff_candidate >= len(messages):
+            if len(messages) == 1:
+                return 0
+            cutoff_candidate = len(messages) - 1
+
+        # Advance past any ToolMessages to avoid splitting AI/Tool pairs
+        return self._find_safe_cutoff_point(messages, cutoff_candidate)
+
+    def _get_profile_limits(self) -> int | None:
+        """Retrieve max input token limit from the model profile."""
+        try:
+            profile = self.model.profile
+        except AttributeError:
+            return None
+
+        if not isinstance(profile, Mapping):
+            return None
+
+        max_input_tokens = profile.get("max_input_tokens")
+
+        if not isinstance(max_input_tokens, int):
+            return None
+
+        return max_input_tokens
+
+    @staticmethod
+    def _validate_context_size(context: ContextSize, parameter_name: str) -> ContextSize:
+        """Validate context configuration tuples."""
+        kind, value = context
+        if kind == "fraction":
+            if not 0 < value <= 1:
+                msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
+                raise ValueError(msg)
+        elif kind in {"tokens", "messages"}:
+            if value <= 0:
+                msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
+                raise ValueError(msg)
+        else:
+            msg = f"Unsupported context size type {kind} for {parameter_name}."
+            raise ValueError(msg)
+        return context
+
+    @staticmethod
+    def _build_new_messages(summary: str) -> list[HumanMessage]:
         return [
-            HumanMessage(
+            HumanMessage(
+                content=f"Here is a summary of the conversation to date:\n\n{summary}",
+                additional_kwargs={"lc_source": "summarization"},
+            )
         ]
 
-
+    @staticmethod
+    def _ensure_message_ids(messages: list[AnyMessage]) -> None:
         """Ensure all messages have unique IDs for the add_messages reducer."""
         for msg in messages:
             if msg.id is None:
                 msg.id = str(uuid.uuid4())
 
+    @staticmethod
     def _partition_messages(
-        self,
         conversation_messages: list[AnyMessage],
         cutoff_index: int,
     ) -> tuple[list[AnyMessage], list[AnyMessage]]:
@@ -151,74 +511,77 @@ class SummarizationMiddleware(AgentMiddleware):
 
         return messages_to_summarize, preserved_messages
 
-    def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
+    def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
         """Find safe cutoff point that preserves AI/Tool message pairs.
 
         Returns the index where messages can be safely cut without separating
-        related AI and Tool messages. Returns 0 if no safe cutoff is found.
+        related AI and Tool messages. Returns `0` if no safe cutoff is found.
+
+        This is aggressive with summarization - if the target cutoff lands in the
+        middle of tool messages, we advance past all of them (summarizing more).
         """
-        if len(messages) <=
+        if len(messages) <= messages_to_keep:
             return 0
 
-        target_cutoff = len(messages) -
+        target_cutoff = len(messages) - messages_to_keep
+        return self._find_safe_cutoff_point(messages, target_cutoff)
 
-
-
-
+    @staticmethod
+    def _find_safe_cutoff_point(messages: list[AnyMessage], cutoff_index: int) -> int:
+        """Find a safe cutoff point that doesn't split AI/Tool message pairs.
 
-
+        If the message at `cutoff_index` is a `ToolMessage`, search backward for the
+        `AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
+        include it. This ensures tool call requests and responses stay together.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        Falls back to advancing forward past `ToolMessage` objects only if no matching
+        `AIMessage` is found (edge case).
+        """
+        if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
+            return cutoff_index
+
+        # Collect tool_call_ids from consecutive ToolMessages at/after cutoff
+        tool_call_ids: set[str] = set()
+        idx = cutoff_index
+        while idx < len(messages) and isinstance(messages[idx], ToolMessage):
+            tool_msg = cast("ToolMessage", messages[idx])
+            if tool_msg.tool_call_id:
+                tool_call_ids.add(tool_msg.tool_call_id)
+            idx += 1
+
+        # Search backward for AIMessage with matching tool_calls
+        for i in range(cutoff_index - 1, -1, -1):
+            msg = messages[i]
+            if isinstance(msg, AIMessage) and msg.tool_calls:
+                ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")}
+                if tool_call_ids & ai_tool_call_ids:
+                    # Found the AIMessage - move cutoff to include it
+                    return i
+
+        # Fallback: no matching AIMessage found, advance past ToolMessages to avoid
+        # orphaned tool responses
+        return idx
 
-
+    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
+        """Generate summary for the given messages."""
+        if not messages_to_summarize:
+            return "No previous conversation history."
 
-
-
-
-            isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls  # type: ignore[return-value]
-        )
+        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
+        if not trimmed_messages:
+            return "Previous conversation was too long to summarize."
 
-
-
-
-        for tc in ai_message.tool_calls:
-            call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
-            if call_id is not None:
-                tool_call_ids.add(call_id)
-        return tool_call_ids
+        # Format messages to avoid token inflation from metadata when str() is called on
+        # message objects
+        formatted_messages = get_buffer_string(trimmed_messages)
 
-
-
-
-
-
-        tool_call_ids: set[str],
-    ) -> bool:
-        """Check if cutoff separates an AI message from its corresponding tool messages."""
-        for j in range(ai_message_index + 1, len(messages)):
-            message = messages[j]
-            if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids:
-                ai_before_cutoff = ai_message_index < cutoff_index
-                tool_before_cutoff = j < cutoff_index
-                if ai_before_cutoff != tool_before_cutoff:
-                    return True
-        return False
+        try:
+            response = self.model.invoke(self.summary_prompt.format(messages=formatted_messages))
+            return response.text.strip()
+        except Exception as e:
+            return f"Error generating summary: {e!s}"
 
-    def
+    async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
         """Generate summary for the given messages."""
         if not messages_to_summarize:
             return "No previous conversation history."
@@ -227,23 +590,34 @@ class SummarizationMiddleware(AgentMiddleware):
         if not trimmed_messages:
             return "Previous conversation was too long to summarize."
 
+        # Format messages to avoid token inflation from metadata when str() is called on
+        # message objects
+        formatted_messages = get_buffer_string(trimmed_messages)
+
         try:
-            response = self.model.
-
-
+            response = await self.model.ainvoke(
+                self.summary_prompt.format(messages=formatted_messages)
+            )
+            return response.text.strip()
+        except Exception as e:
             return f"Error generating summary: {e!s}"
 
     def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
         """Trim messages to fit within summary generation limits."""
         try:
-
-            messages
-
-
-
-
-
-
+            if self.trim_tokens_to_summarize is None:
+                return messages
+            return cast(
+                "list[AnyMessage]",
+                trim_messages(
+                    messages,
+                    max_tokens=self.trim_tokens_to_summarize,
+                    token_counter=self.token_counter,
+                    start_on="human",
+                    strategy="last",
+                    allow_partial=True,
+                    include_system=True,
+                ),
             )
-        except Exception:
+        except Exception:
             return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]