langchain 1.0.5__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (34)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +1 -7
  3. langchain/agents/factory.py +153 -79
  4. langchain/agents/middleware/__init__.py +18 -23
  5. langchain/agents/middleware/_execution.py +29 -32
  6. langchain/agents/middleware/_redaction.py +108 -22
  7. langchain/agents/middleware/_retry.py +123 -0
  8. langchain/agents/middleware/context_editing.py +47 -25
  9. langchain/agents/middleware/file_search.py +19 -14
  10. langchain/agents/middleware/human_in_the_loop.py +87 -57
  11. langchain/agents/middleware/model_call_limit.py +64 -18
  12. langchain/agents/middleware/model_fallback.py +7 -9
  13. langchain/agents/middleware/model_retry.py +307 -0
  14. langchain/agents/middleware/pii.py +82 -29
  15. langchain/agents/middleware/shell_tool.py +254 -107
  16. langchain/agents/middleware/summarization.py +469 -95
  17. langchain/agents/middleware/todo.py +129 -31
  18. langchain/agents/middleware/tool_call_limit.py +105 -71
  19. langchain/agents/middleware/tool_emulator.py +47 -38
  20. langchain/agents/middleware/tool_retry.py +183 -164
  21. langchain/agents/middleware/tool_selection.py +81 -37
  22. langchain/agents/middleware/types.py +856 -427
  23. langchain/agents/structured_output.py +65 -42
  24. langchain/chat_models/__init__.py +1 -7
  25. langchain/chat_models/base.py +253 -196
  26. langchain/embeddings/__init__.py +0 -5
  27. langchain/embeddings/base.py +79 -65
  28. langchain/messages/__init__.py +0 -5
  29. langchain/tools/__init__.py +1 -7
  30. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/METADATA +5 -7
  31. langchain-1.2.4.dist-info/RECORD +36 -0
  32. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/WHEEL +1 -1
  33. langchain-1.0.5.dist-info/RECORD +0 -34
  34. {langchain-1.0.5.dist-info → langchain-1.2.4.dist-info}/licenses/LICENSE +0 -0
langchain/agents/middleware/summarization.py

@@ -1,8 +1,10 @@
 """Summarization middleware."""
 
 import uuid
-from collections.abc import Callable, Iterable
-from typing import Any, cast
+import warnings
+from collections.abc import Callable, Iterable, Mapping
+from functools import partial
+from typing import Any, Literal, cast
 
 from langchain_core.messages import (
     AIMessage,
@@ -12,11 +14,16 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.human import HumanMessage
-from langchain_core.messages.utils import count_tokens_approximately, trim_messages
+from langchain_core.messages.utils import (
+    count_tokens_approximately,
+    get_buffer_string,
+    trim_messages,
+)
 from langgraph.graph.message import (
     REMOVE_ALL_MESSAGES,
 )
 from langgraph.runtime import Runtime
+from typing_extensions import override
 
 from langchain.agents.middleware.types import AgentMiddleware, AgentState
 from langchain.chat_models import BaseChatModel, init_chat_model
@@ -51,12 +58,79 @@ Messages to summarize:
 {messages}
 </messages>"""  # noqa: E501
 
-SUMMARY_PREFIX = "## Previous conversation summary:"
-
 _DEFAULT_MESSAGES_TO_KEEP = 20
 _DEFAULT_TRIM_TOKEN_LIMIT = 4000
 _DEFAULT_FALLBACK_MESSAGE_COUNT = 15
-_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
+
+ContextFraction = tuple[Literal["fraction"], float]
+"""Fraction of model's maximum input tokens.
+
+Example:
+    To specify 50% of the model's max input tokens:
+
+    ```python
+    ("fraction", 0.5)
+    ```
+"""
+
+ContextTokens = tuple[Literal["tokens"], int]
+"""Absolute number of tokens.
+
+Example:
+    To specify 3000 tokens:
+
+    ```python
+    ("tokens", 3000)
+    ```
+"""
+
+ContextMessages = tuple[Literal["messages"], int]
+"""Absolute number of messages.
+
+Example:
+    To specify 50 messages:
+
+    ```python
+    ("messages", 50)
+    ```
+"""
+
+ContextSize = ContextFraction | ContextTokens | ContextMessages
+"""Union type for context size specifications.
+
+Can be either:
+
+- [`ContextFraction`][langchain.agents.middleware.summarization.ContextFraction]: A
+    fraction of the model's maximum input tokens.
+- [`ContextTokens`][langchain.agents.middleware.summarization.ContextTokens]: An absolute
+    number of tokens.
+- [`ContextMessages`][langchain.agents.middleware.summarization.ContextMessages]: An
+    absolute number of messages.
+
+Depending on use with `trigger` or `keep` parameters, this type indicates either
+when to trigger summarization or how much context to retain.
+
+Example:
+    ```python
+    # ContextFraction
+    context_size: ContextSize = ("fraction", 0.5)
+
+    # ContextTokens
+    context_size: ContextSize = ("tokens", 3000)
+
+    # ContextMessages
+    context_size: ContextSize = ("messages", 50)
+    ```
+"""
+
+
+def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
+    """Tune parameters of approximate token counter based on model type."""
+    if model._llm_type == "anthropic-chat":  # noqa: SLF001
+        # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
+        # API: https://platform.claude.com/docs/en/build-with-claude/token-counting
+        return partial(count_tokens_approximately, chars_per_token=3.3)
+    return count_tokens_approximately
 
 
 class SummarizationMiddleware(AgentMiddleware):
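
The new `_get_approximate_token_counter` helper only swaps in a denser characters-per-token ratio for Anthropic-style models. A quick sketch of what that `partial` does in practice, calling `count_tokens_approximately` directly (the sample message is illustrative):

```python
from functools import partial

from langchain_core.messages import HumanMessage
from langchain_core.messages.utils import count_tokens_approximately

messages = [HumanMessage("The quick brown fox jumps over the lazy dog.")]

# Default heuristic (roughly 4 characters per token)
default_estimate = count_tokens_approximately(messages)

# Anthropic-tuned heuristic from the diff: fewer characters per token,
# so the same text yields a higher (more conservative) estimate
anthropic_counter = partial(count_tokens_approximately, chars_per_token=3.3)
anthropic_estimate = anthropic_counter(messages)

assert anthropic_estimate >= default_estimate
```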
@@ -70,48 +144,149 @@ class SummarizationMiddleware(AgentMiddleware):
     def __init__(
         self,
         model: str | BaseChatModel,
-        max_tokens_before_summary: int | None = None,
-        messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
+        *,
+        trigger: ContextSize | list[ContextSize] | None = None,
+        keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
         token_counter: TokenCounter = count_tokens_approximately,
         summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
-        summary_prefix: str = SUMMARY_PREFIX,
+        trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
+        **deprecated_kwargs: Any,
     ) -> None:
-        """Initialize the summarization middleware.
+        """Initialize summarization middleware.
 
         Args:
             model: The language model to use for generating summaries.
-            max_tokens_before_summary: Token threshold to trigger summarization.
-                If `None`, summarization is disabled.
-            messages_to_keep: Number of recent messages to preserve after summarization.
+            trigger: One or more thresholds that trigger summarization.
+
+                Provide a single
+                [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
+                tuple or a list of tuples, in which case summarization runs when any
+                threshold is met.
+
+                !!! example
+
+                    ```python
+                    # Trigger summarization when 50 messages is reached
+                    ("messages", 50)
+
+                    # Trigger summarization when 3000 tokens is reached
+                    ("tokens", 3000)
+
+                    # Trigger summarization either when 80% of model's max input tokens
+                    # is reached or when 100 messages is reached (whichever comes first)
+                    [("fraction", 0.8), ("messages", 100)]
+                    ```
+
+                See [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
+                for more details.
+            keep: Context retention policy applied after summarization.
+
+                Provide a [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
+                tuple to specify how much history to preserve.
+
+                Defaults to keeping the most recent `20` messages.
+
+                Does not support multiple values like `trigger`.
+
+                !!! example
+
+                    ```python
+                    # Keep the most recent 20 messages
+                    ("messages", 20)
+
+                    # Keep the most recent 3000 tokens
+                    ("tokens", 3000)
+
+                    # Keep the most recent 30% of the model's max input tokens
+                    ("fraction", 0.3)
+                    ```
             token_counter: Function to count tokens in messages.
             summary_prompt: Prompt template for generating summaries.
-            summary_prefix: Prefix added to system message when including summary.
+            trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
+                the summarization call.
+
+                Pass `None` to skip trimming entirely.
         """
+        # Handle deprecated parameters
+        if "max_tokens_before_summary" in deprecated_kwargs:
+            value = deprecated_kwargs["max_tokens_before_summary"]
+            warnings.warn(
+                "max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if trigger is None and value is not None:
+                trigger = ("tokens", value)
+
+        if "messages_to_keep" in deprecated_kwargs:
+            value = deprecated_kwargs["messages_to_keep"]
+            warnings.warn(
+                "messages_to_keep is deprecated. Use keep=('messages', value) instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
+                keep = ("messages", value)
+
         super().__init__()
 
         if isinstance(model, str):
             model = init_chat_model(model)
 
         self.model = model
-        self.max_tokens_before_summary = max_tokens_before_summary
-        self.messages_to_keep = messages_to_keep
-        self.token_counter = token_counter
+        if trigger is None:
+            self.trigger: ContextSize | list[ContextSize] | None = None
+            trigger_conditions: list[ContextSize] = []
+        elif isinstance(trigger, list):
+            validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
+            self.trigger = validated_list
+            trigger_conditions = validated_list
+        else:
+            validated = self._validate_context_size(trigger, "trigger")
+            self.trigger = validated
+            trigger_conditions = [validated]
+        self._trigger_conditions = trigger_conditions
+
+        self.keep = self._validate_context_size(keep, "keep")
+        if token_counter is count_tokens_approximately:
+            self.token_counter = _get_approximate_token_counter(self.model)
+        else:
+            self.token_counter = token_counter
         self.summary_prompt = summary_prompt
-        self.summary_prefix = summary_prefix
+        self.trim_tokens_to_summarize = trim_tokens_to_summarize
+
+        requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
+        if self.keep[0] == "fraction":
+            requires_profile = True
+        if requires_profile and self._get_profile_limits() is None:
+            msg = (
+                "Model profile information is required to use fractional token limits, "
+                "and is unavailable for the specified model. Please use absolute token "
+                "counts instead, or pass "
+                '`\n\nChatModel(..., profile={"max_input_tokens": ...})`.\n\n'
+                "with a desired integer value of the model's maximum input tokens."
+            )
+            raise ValueError(msg)
+
+    @override
+    def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
+        """Process messages before model invocation, potentially triggering summarization.
 
-    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
-        """Process messages before model invocation, potentially triggering summarization."""
+        Args:
+            state: The agent state.
+            runtime: The runtime environment.
+
+        Returns:
+            An updated state with summarized messages if summarization was performed.
+        """
         messages = state["messages"]
         self._ensure_message_ids(messages)
 
         total_tokens = self.token_counter(messages)
-        if (
-            self.max_tokens_before_summary is not None
-            and total_tokens < self.max_tokens_before_summary
-        ):
+        if not self._should_summarize(messages, total_tokens):
            return None
 
-        cutoff_index = self._find_safe_cutoff(messages)
+        cutoff_index = self._determine_cutoff_index(messages)
 
         if cutoff_index <= 0:
             return None
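
The deprecation shim in `__init__` rewrites the old keyword arguments into the new tuple form before validation. A hedged sketch of both call styles (the model string is a placeholder and assumes the matching provider package and credentials are available):

```python
import warnings

from langchain.agents.middleware import SummarizationMiddleware

# Old style: still accepted via **deprecated_kwargs, but warns and is
# rewritten to trigger=("tokens", 4000)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    middleware = SummarizationMiddleware(
        "openai:gpt-4o-mini",  # placeholder model id
        max_tokens_before_summary=4000,
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert middleware.trigger == ("tokens", 4000)

# New style: summarize when either threshold is hit, keep the last 20 messages
middleware = SummarizationMiddleware(
    "openai:gpt-4o-mini",
    trigger=[("tokens", 3000), ("messages", 100)],
    keep=("messages", 20),
)
```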
@@ -129,19 +304,204 @@ class SummarizationMiddleware(AgentMiddleware):
             ]
         }
 
-    def _build_new_messages(self, summary: str) -> list[HumanMessage]:
+    @override
+    async def abefore_model(
+        self, state: AgentState[Any], runtime: Runtime
+    ) -> dict[str, Any] | None:
+        """Process messages before model invocation, potentially triggering summarization.
+
+        Args:
+            state: The agent state.
+            runtime: The runtime environment.
+
+        Returns:
+            An updated state with summarized messages if summarization was performed.
+        """
+        messages = state["messages"]
+        self._ensure_message_ids(messages)
+
+        total_tokens = self.token_counter(messages)
+        if not self._should_summarize(messages, total_tokens):
+            return None
+
+        cutoff_index = self._determine_cutoff_index(messages)
+
+        if cutoff_index <= 0:
+            return None
+
+        messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
+
+        summary = await self._acreate_summary(messages_to_summarize)
+        new_messages = self._build_new_messages(summary)
+
+        return {
+            "messages": [
+                RemoveMessage(id=REMOVE_ALL_MESSAGES),
+                *new_messages,
+                *preserved_messages,
+            ]
+        }
+
+    def _should_summarize_based_on_reported_tokens(
+        self, messages: list[AnyMessage], threshold: float
+    ) -> bool:
+        """Check if reported token usage from last AIMessage exceeds threshold."""
+        last_ai_message = next(
+            (msg for msg in reversed(messages) if isinstance(msg, AIMessage)),
+            None,
+        )
+        if (  # noqa: SIM103
+            isinstance(last_ai_message, AIMessage)
+            and last_ai_message.usage_metadata is not None
+            and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1))
+            and reported_tokens >= threshold
+            and (message_provider := last_ai_message.response_metadata.get("model_provider"))
+            and message_provider == self.model._get_ls_params().get("ls_provider")  # noqa: SLF001
+        ):
+            return True
+        return False
+
+    def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
+        """Determine whether summarization should run for the current token usage."""
+        if not self._trigger_conditions:
+            return False
+
+        for kind, value in self._trigger_conditions:
+            if kind == "messages" and len(messages) >= value:
+                return True
+            if kind == "tokens" and total_tokens >= value:
+                return True
+            if kind == "tokens" and self._should_summarize_based_on_reported_tokens(
+                messages, value
+            ):
+                return True
+            if kind == "fraction":
+                max_input_tokens = self._get_profile_limits()
+                if max_input_tokens is None:
+                    continue
+                threshold = int(max_input_tokens * value)
+                if threshold <= 0:
+                    threshold = 1
+                if total_tokens >= threshold:
+                    return True
+
+                if self._should_summarize_based_on_reported_tokens(messages, threshold):
+                    return True
+        return False
+
+    def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
+        """Choose cutoff index respecting retention configuration."""
+        kind, value = self.keep
+        if kind in {"tokens", "fraction"}:
+            token_based_cutoff = self._find_token_based_cutoff(messages)
+            if token_based_cutoff is not None:
+                return token_based_cutoff
+            # None cutoff -> model profile data not available (caught in __init__ but
+            # here for safety), fallback to message count
+            return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
+        return self._find_safe_cutoff(messages, cast("int", value))
+
+    def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
+        """Find cutoff index based on target token retention."""
+        if not messages:
+            return 0
+
+        kind, value = self.keep
+        if kind == "fraction":
+            max_input_tokens = self._get_profile_limits()
+            if max_input_tokens is None:
+                return None
+            target_token_count = int(max_input_tokens * value)
+        elif kind == "tokens":
+            target_token_count = int(value)
+        else:
+            return None
+
+        if target_token_count <= 0:
+            target_token_count = 1
+
+        if self.token_counter(messages) <= target_token_count:
+            return 0
+
+        # Use binary search to identify the earliest message index that keeps the
+        # suffix within the token budget.
+        left, right = 0, len(messages)
+        cutoff_candidate = len(messages)
+        max_iterations = len(messages).bit_length() + 1
+        for _ in range(max_iterations):
+            if left >= right:
+                break
+
+            mid = (left + right) // 2
+            if self.token_counter(messages[mid:]) <= target_token_count:
+                cutoff_candidate = mid
+                right = mid
+            else:
+                left = mid + 1
+
+        if cutoff_candidate == len(messages):
+            cutoff_candidate = left
+
+        if cutoff_candidate >= len(messages):
+            if len(messages) == 1:
+                return 0
+            cutoff_candidate = len(messages) - 1
+
+        # Advance past any ToolMessages to avoid splitting AI/Tool pairs
+        return self._find_safe_cutoff_point(messages, cutoff_candidate)
+
+    def _get_profile_limits(self) -> int | None:
+        """Retrieve max input token limit from the model profile."""
+        try:
+            profile = self.model.profile
+        except AttributeError:
+            return None
+
+        if not isinstance(profile, Mapping):
+            return None
+
+        max_input_tokens = profile.get("max_input_tokens")
+
+        if not isinstance(max_input_tokens, int):
+            return None
+
+        return max_input_tokens
+
+    @staticmethod
+    def _validate_context_size(context: ContextSize, parameter_name: str) -> ContextSize:
+        """Validate context configuration tuples."""
+        kind, value = context
+        if kind == "fraction":
+            if not 0 < value <= 1:
+                msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
+                raise ValueError(msg)
+        elif kind in {"tokens", "messages"}:
+            if value <= 0:
+                msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
+                raise ValueError(msg)
+        else:
+            msg = f"Unsupported context size type {kind} for {parameter_name}."
+            raise ValueError(msg)
+        return context
+
+    @staticmethod
+    def _build_new_messages(summary: str) -> list[HumanMessage]:
         return [
-            HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
+            HumanMessage(
+                content=f"Here is a summary of the conversation to date:\n\n{summary}",
+                additional_kwargs={"lc_source": "summarization"},
+            )
         ]
 
-    def _ensure_message_ids(self, messages: list[AnyMessage]) -> None:
+    @staticmethod
+    def _ensure_message_ids(messages: list[AnyMessage]) -> None:
         """Ensure all messages have unique IDs for the add_messages reducer."""
         for msg in messages:
             if msg.id is None:
                 msg.id = str(uuid.uuid4())
 
+    @staticmethod
     def _partition_messages(
-        self,
         conversation_messages: list[AnyMessage],
         cutoff_index: int,
     ) -> tuple[list[AnyMessage], list[AnyMessage]]:
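
`_find_token_based_cutoff` avoids scanning every candidate split point by binary-searching for the earliest index whose suffix fits the retention budget; the search is valid because dropping a longer prefix can only shrink the suffix's token count. A self-contained toy version of the same idea (the names and the word-count "token counter" are illustrative, not the middleware's API):

```python
def count_tokens(msgs: list[str]) -> int:
    """Toy counter: one token per whitespace-separated word."""
    return sum(len(m.split()) for m in msgs)


def find_cutoff(messages: list[str], budget: int) -> int:
    """Earliest index i such that messages[i:] fits within `budget` tokens."""
    if count_tokens(messages) <= budget:
        return 0  # everything already fits
    left, right = 0, len(messages)
    cutoff = len(messages)
    while left < right:
        mid = (left + right) // 2
        if count_tokens(messages[mid:]) <= budget:
            cutoff = mid  # this suffix fits; try an earlier split to keep more
            right = mid
        else:
            left = mid + 1  # suffix still too large; split later
    return min(cutoff, len(messages) - 1)  # always summarize at least one message


assert find_cutoff(["a b c", "d e", "f g h i", "j"], budget=5) == 2
```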
@@ -151,74 +511,77 @@ class SummarizationMiddleware(AgentMiddleware):
 
         return messages_to_summarize, preserved_messages
 
-    def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
+    def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
         """Find safe cutoff point that preserves AI/Tool message pairs.
 
         Returns the index where messages can be safely cut without separating
-        related AI and Tool messages. Returns 0 if no safe cutoff is found.
+        related AI and Tool messages. Returns `0` if no safe cutoff is found.
+
+        This is aggressive with summarization - if the target cutoff lands in the
+        middle of tool messages, we advance past all of them (summarizing more).
         """
-        if len(messages) <= self.messages_to_keep:
+        if len(messages) <= messages_to_keep:
             return 0
 
-        target_cutoff = len(messages) - self.messages_to_keep
+        target_cutoff = len(messages) - messages_to_keep
+        return self._find_safe_cutoff_point(messages, target_cutoff)
 
-        for i in range(target_cutoff, -1, -1):
-            if self._is_safe_cutoff_point(messages, i):
-                return i
+    @staticmethod
+    def _find_safe_cutoff_point(messages: list[AnyMessage], cutoff_index: int) -> int:
+        """Find a safe cutoff point that doesn't split AI/Tool message pairs.
 
-        return 0
+        If the message at `cutoff_index` is a `ToolMessage`, search backward for the
+        `AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
+        include it. This ensures tool call requests and responses stay together.
 
-    def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool:
-        """Check if cutting at index would separate AI/Tool message pairs."""
-        if cutoff_index >= len(messages):
-            return True
-
-        search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS)
-        search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS)
-
-        for i in range(search_start, search_end):
-            if not self._has_tool_calls(messages[i]):
-                continue
-
-            tool_call_ids = self._extract_tool_call_ids(cast("AIMessage", messages[i]))
-            if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids):
-                return False
+        Falls back to advancing forward past `ToolMessage` objects only if no matching
+        `AIMessage` is found (edge case).
+        """
+        if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
+            return cutoff_index
+
+        # Collect tool_call_ids from consecutive ToolMessages at/after cutoff
+        tool_call_ids: set[str] = set()
+        idx = cutoff_index
+        while idx < len(messages) and isinstance(messages[idx], ToolMessage):
+            tool_msg = cast("ToolMessage", messages[idx])
+            if tool_msg.tool_call_id:
+                tool_call_ids.add(tool_msg.tool_call_id)
+            idx += 1
+
+        # Search backward for AIMessage with matching tool_calls
+        for i in range(cutoff_index - 1, -1, -1):
+            msg = messages[i]
+            if isinstance(msg, AIMessage) and msg.tool_calls:
+                ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")}
+                if tool_call_ids & ai_tool_call_ids:
+                    # Found the AIMessage - move cutoff to include it
+                    return i
+
+        # Fallback: no matching AIMessage found, advance past ToolMessages to avoid
+        # orphaned tool responses
+        return idx
 
-        return True
+    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
+        """Generate summary for the given messages."""
+        if not messages_to_summarize:
+            return "No previous conversation history."
 
-    def _has_tool_calls(self, message: AnyMessage) -> bool:
-        """Check if message is an AI message with tool calls."""
-        return (
-            isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls  # type: ignore[return-value]
-        )
+        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
+        if not trimmed_messages:
+            return "Previous conversation was too long to summarize."
 
-    def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]:
-        """Extract tool call IDs from an AI message."""
-        tool_call_ids = set()
-        for tc in ai_message.tool_calls:
-            call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
-            if call_id is not None:
-                tool_call_ids.add(call_id)
-        return tool_call_ids
+        # Format messages to avoid token inflation from metadata when str() is called on
+        # message objects
+        formatted_messages = get_buffer_string(trimmed_messages)
 
-    def _cutoff_separates_tool_pair(
-        self,
-        messages: list[AnyMessage],
-        ai_message_index: int,
-        cutoff_index: int,
-        tool_call_ids: set[str],
-    ) -> bool:
-        """Check if cutoff separates an AI message from its corresponding tool messages."""
-        for j in range(ai_message_index + 1, len(messages)):
-            message = messages[j]
-            if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids:
-                ai_before_cutoff = ai_message_index < cutoff_index
-                tool_before_cutoff = j < cutoff_index
-                if ai_before_cutoff != tool_before_cutoff:
-                    return True
-        return False
+        try:
+            response = self.model.invoke(self.summary_prompt.format(messages=formatted_messages))
+            return response.text.strip()
+        except Exception as e:
+            return f"Error generating summary: {e!s}"
 
-    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
+    async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
         """Generate summary for the given messages."""
         if not messages_to_summarize:
             return "No previous conversation history."
@@ -227,23 +590,34 @@ class SummarizationMiddleware(AgentMiddleware):
         if not trimmed_messages:
             return "Previous conversation was too long to summarize."
 
+        # Format messages to avoid token inflation from metadata when str() is called on
+        # message objects
+        formatted_messages = get_buffer_string(trimmed_messages)
+
         try:
-            response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
-            return cast("str", response.content).strip()
-        except Exception as e:  # noqa: BLE001
+            response = await self.model.ainvoke(
+                self.summary_prompt.format(messages=formatted_messages)
+            )
+            return response.text.strip()
+        except Exception as e:
             return f"Error generating summary: {e!s}"
 
     def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
         """Trim messages to fit within summary generation limits."""
         try:
-            return trim_messages(
-                messages,
-                max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
-                token_counter=self.token_counter,
-                start_on="human",
-                strategy="last",
-                allow_partial=True,
-                include_system=True,
+            if self.trim_tokens_to_summarize is None:
+                return messages
+            return cast(
+                "list[AnyMessage]",
+                trim_messages(
+                    messages,
+                    max_tokens=self.trim_tokens_to_summarize,
+                    token_counter=self.token_counter,
+                    start_on="human",
+                    strategy="last",
+                    allow_partial=True,
+                    include_system=True,
+                ),
             )
-        except Exception:  # noqa: BLE001
+        except Exception:
             return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]
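
End to end, the reworked middleware is configured roughly like this (a sketch assuming the langchain 1.x `create_agent` entry point and a placeholder model id):

```python
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware

agent = create_agent(
    model="openai:gpt-4o-mini",  # placeholder model id
    tools=[],
    middleware=[
        SummarizationMiddleware(
            "openai:gpt-4o-mini",
            # Summarize once either threshold is reached
            trigger=[("tokens", 3000), ("messages", 100)],
            # After summarizing, retain only the most recent 20 messages
            keep=("messages", 20),
            # Cap the history passed to the summarization call itself
            trim_tokens_to_summarize=4000,
        ),
    ],
)
```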