langchain 1.0.4__py3-none-any.whl → 1.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. langchain/__init__.py +1 -1
  2. langchain/agents/__init__.py +1 -7
  3. langchain/agents/factory.py +100 -41
  4. langchain/agents/middleware/__init__.py +5 -7
  5. langchain/agents/middleware/_execution.py +21 -20
  6. langchain/agents/middleware/_redaction.py +27 -12
  7. langchain/agents/middleware/_retry.py +123 -0
  8. langchain/agents/middleware/context_editing.py +26 -22
  9. langchain/agents/middleware/file_search.py +18 -13
  10. langchain/agents/middleware/human_in_the_loop.py +60 -54
  11. langchain/agents/middleware/model_call_limit.py +63 -17
  12. langchain/agents/middleware/model_fallback.py +7 -9
  13. langchain/agents/middleware/model_retry.py +300 -0
  14. langchain/agents/middleware/pii.py +80 -27
  15. langchain/agents/middleware/shell_tool.py +230 -103
  16. langchain/agents/middleware/summarization.py +439 -90
  17. langchain/agents/middleware/todo.py +111 -27
  18. langchain/agents/middleware/tool_call_limit.py +105 -71
  19. langchain/agents/middleware/tool_emulator.py +42 -33
  20. langchain/agents/middleware/tool_retry.py +171 -159
  21. langchain/agents/middleware/tool_selection.py +37 -27
  22. langchain/agents/middleware/types.py +754 -392
  23. langchain/agents/structured_output.py +22 -12
  24. langchain/chat_models/__init__.py +1 -7
  25. langchain/chat_models/base.py +234 -185
  26. langchain/embeddings/__init__.py +0 -5
  27. langchain/embeddings/base.py +80 -66
  28. langchain/messages/__init__.py +0 -5
  29. langchain/tools/__init__.py +1 -7
  30. {langchain-1.0.4.dist-info → langchain-1.2.3.dist-info}/METADATA +3 -5
  31. langchain-1.2.3.dist-info/RECORD +36 -0
  32. {langchain-1.0.4.dist-info → langchain-1.2.3.dist-info}/WHEEL +1 -1
  33. langchain-1.0.4.dist-info/RECORD +0 -34
  34. {langchain-1.0.4.dist-info → langchain-1.2.3.dist-info}/licenses/LICENSE +0 -0
langchain/agents/middleware/summarization.py +439 -90
@@ -1,8 +1,10 @@
 """Summarization middleware."""
 
 import uuid
-from collections.abc import Callable, Iterable
-from typing import Any, cast
+import warnings
+from collections.abc import Callable, Iterable, Mapping
+from functools import partial
+from typing import Any, Literal, cast
 
 from langchain_core.messages import (
     AIMessage,
@@ -12,11 +14,16 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.human import HumanMessage
-from langchain_core.messages.utils import count_tokens_approximately, trim_messages
+from langchain_core.messages.utils import (
+    count_tokens_approximately,
+    get_buffer_string,
+    trim_messages,
+)
 from langgraph.graph.message import (
     REMOVE_ALL_MESSAGES,
 )
 from langgraph.runtime import Runtime
+from typing_extensions import override
 
 from langchain.agents.middleware.types import AgentMiddleware, AgentState
 from langchain.chat_models import BaseChatModel, init_chat_model
@@ -51,12 +58,79 @@ Messages to summarize:
 {messages}
 </messages>"""  # noqa: E501
 
-SUMMARY_PREFIX = "## Previous conversation summary:"
-
 _DEFAULT_MESSAGES_TO_KEEP = 20
 _DEFAULT_TRIM_TOKEN_LIMIT = 4000
 _DEFAULT_FALLBACK_MESSAGE_COUNT = 15
-_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
+
+ContextFraction = tuple[Literal["fraction"], float]
+"""Fraction of model's maximum input tokens.
+
+Example:
+    To specify 50% of the model's max input tokens:
+
+    ```python
+    ("fraction", 0.5)
+    ```
+"""
+
+ContextTokens = tuple[Literal["tokens"], int]
+"""Absolute number of tokens.
+
+Example:
+    To specify 3000 tokens:
+
+    ```python
+    ("tokens", 3000)
+    ```
+"""
+
+ContextMessages = tuple[Literal["messages"], int]
+"""Absolute number of messages.
+
+Example:
+    To specify 50 messages:
+
+    ```python
+    ("messages", 50)
+    ```
+"""
+
+ContextSize = ContextFraction | ContextTokens | ContextMessages
+"""Union type for context size specifications.
+
+Can be either:
+
+- [`ContextFraction`][langchain.agents.middleware.summarization.ContextFraction]: A
+  fraction of the model's maximum input tokens.
+- [`ContextTokens`][langchain.agents.middleware.summarization.ContextTokens]: An absolute
+  number of tokens.
+- [`ContextMessages`][langchain.agents.middleware.summarization.ContextMessages]: An
+  absolute number of messages.
+
+Depending on use with `trigger` or `keep` parameters, this type indicates either
+when to trigger summarization or how much context to retain.
+
+Example:
+    ```python
+    # ContextFraction
+    context_size: ContextSize = ("fraction", 0.5)
+
+    # ContextTokens
+    context_size: ContextSize = ("tokens", 3000)
+
+    # ContextMessages
+    context_size: ContextSize = ("messages", 50)
+    ```
+"""
+
+
+def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
+    """Tune parameters of approximate token counter based on model type."""
+    if model._llm_type == "anthropic-chat":  # noqa: SLF001
+        # 3.3 was estimated in an offline experiment, comparing with Claude's token-counting
+        # API: https://platform.claude.com/docs/en/build-with-claude/token-counting
+        return partial(count_tokens_approximately, chars_per_token=3.3)
+    return count_tokens_approximately
 
 
 class SummarizationMiddleware(AgentMiddleware):
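The new `_get_approximate_token_counter` helper only swaps in a denser characters-per-token estimate for Anthropic models. A standalone sketch of the effect (toy message, not part of the diff):

```python
from functools import partial

from langchain_core.messages import HumanMessage
from langchain_core.messages.utils import count_tokens_approximately

# The same tuning the helper applies when model._llm_type == "anthropic-chat":
# ~3.3 characters per token instead of the default heuristic.
anthropic_counter = partial(count_tokens_approximately, chars_per_token=3.3)

messages = [HumanMessage(content="hello world " * 200)]
print(count_tokens_approximately(messages))  # default estimate
print(anthropic_counter(messages))  # higher estimate for the same text
```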
@@ -70,48 +144,141 @@ class SummarizationMiddleware(AgentMiddleware):
     def __init__(
         self,
         model: str | BaseChatModel,
-        max_tokens_before_summary: int | None = None,
-        messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
+        *,
+        trigger: ContextSize | list[ContextSize] | None = None,
+        keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
         token_counter: TokenCounter = count_tokens_approximately,
         summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
-        summary_prefix: str = SUMMARY_PREFIX,
+        trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
+        **deprecated_kwargs: Any,
     ) -> None:
-        """Initialize the summarization middleware.
+        """Initialize summarization middleware.
 
         Args:
             model: The language model to use for generating summaries.
-            max_tokens_before_summary: Token threshold to trigger summarization.
-                If `None`, summarization is disabled.
-            messages_to_keep: Number of recent messages to preserve after summarization.
+            trigger: One or more thresholds that trigger summarization.
+
+                Provide a single
+                [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
+                tuple or a list of tuples, in which case summarization runs when any
+                threshold is met.
+
+                !!! example
+
+                    ```python
+                    # Trigger summarization when 50 messages is reached
+                    ("messages", 50)
+
+                    # Trigger summarization when 3000 tokens is reached
+                    ("tokens", 3000)
+
+                    # Trigger summarization either when 80% of model's max input tokens
+                    # is reached or when 100 messages is reached (whichever comes first)
+                    [("fraction", 0.8), ("messages", 100)]
+                    ```
+
+                See [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
+                for more details.
+            keep: Context retention policy applied after summarization.
+
+                Provide a [`ContextSize`][langchain.agents.middleware.summarization.ContextSize]
+                tuple to specify how much history to preserve.
+
+                Defaults to keeping the most recent `20` messages.
+
+                Does not support multiple values like `trigger`.
+
+                !!! example
+
+                    ```python
+                    # Keep the most recent 20 messages
+                    ("messages", 20)
+
+                    # Keep the most recent 3000 tokens
+                    ("tokens", 3000)
+
+                    # Keep the most recent 30% of the model's max input tokens
+                    ("fraction", 0.3)
+                    ```
             token_counter: Function to count tokens in messages.
             summary_prompt: Prompt template for generating summaries.
-            summary_prefix: Prefix added to system message when including summary.
+            trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for
+                the summarization call.
+
+                Pass `None` to skip trimming entirely.
         """
+        # Handle deprecated parameters
+        if "max_tokens_before_summary" in deprecated_kwargs:
+            value = deprecated_kwargs["max_tokens_before_summary"]
+            warnings.warn(
+                "max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if trigger is None and value is not None:
+                trigger = ("tokens", value)
+
+        if "messages_to_keep" in deprecated_kwargs:
+            value = deprecated_kwargs["messages_to_keep"]
+            warnings.warn(
+                "messages_to_keep is deprecated. Use keep=('messages', value) instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
+                keep = ("messages", value)
+
         super().__init__()
 
         if isinstance(model, str):
             model = init_chat_model(model)
 
         self.model = model
-        self.max_tokens_before_summary = max_tokens_before_summary
-        self.messages_to_keep = messages_to_keep
-        self.token_counter = token_counter
+        if trigger is None:
+            self.trigger: ContextSize | list[ContextSize] | None = None
+            trigger_conditions: list[ContextSize] = []
+        elif isinstance(trigger, list):
+            validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
+            self.trigger = validated_list
+            trigger_conditions = validated_list
+        else:
+            validated = self._validate_context_size(trigger, "trigger")
+            self.trigger = validated
+            trigger_conditions = [validated]
+        self._trigger_conditions = trigger_conditions
+
+        self.keep = self._validate_context_size(keep, "keep")
+        if token_counter is count_tokens_approximately:
+            self.token_counter = _get_approximate_token_counter(self.model)
+        else:
+            self.token_counter = token_counter
         self.summary_prompt = summary_prompt
-        self.summary_prefix = summary_prefix
+        self.trim_tokens_to_summarize = trim_tokens_to_summarize
+
+        requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
+        if self.keep[0] == "fraction":
+            requires_profile = True
+        if requires_profile and self._get_profile_limits() is None:
+            msg = (
+                "Model profile information is required to use fractional token limits, "
+                "and is unavailable for the specified model. Please use absolute token "
+                "counts instead, or pass\n\n"
+                '`ChatModel(..., profile={"max_input_tokens": ...})`\n\n'
+                "with a desired integer value of the model's maximum input tokens."
+            )
+            raise ValueError(msg)
 
-    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+    @override
+    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
         """Process messages before model invocation, potentially triggering summarization."""
         messages = state["messages"]
         self._ensure_message_ids(messages)
 
         total_tokens = self.token_counter(messages)
-        if (
-            self.max_tokens_before_summary is not None
-            and total_tokens < self.max_tokens_before_summary
-        ):
+        if not self._should_summarize(messages, total_tokens):
            return None
 
-        cutoff_index = self._find_safe_cutoff(messages)
+        cutoff_index = self._determine_cutoff_index(messages)
 
         if cutoff_index <= 0:
             return None
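Given the deprecation shims above, a 1.0.x call site maps onto the new keyword-only `trigger`/`keep` parameters like this (the model id is illustrative and assumes the matching provider package is installed):

```python
from langchain.agents.middleware.summarization import SummarizationMiddleware

# 1.0.x style -- still accepted via **deprecated_kwargs, but emits DeprecationWarning
old_style = SummarizationMiddleware(
    "openai:gpt-4o-mini",
    max_tokens_before_summary=3000,
    messages_to_keep=20,
)

# 1.2.x equivalent, expressed as ContextSize tuples
new_style = SummarizationMiddleware(
    "openai:gpt-4o-mini",
    trigger=("tokens", 3000),
    keep=("messages", 20),
)
```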
@@ -129,6 +296,175 @@
             ]
         }
 
+    @override
+    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+        """Process messages before model invocation, potentially triggering summarization."""
+        messages = state["messages"]
+        self._ensure_message_ids(messages)
+
+        total_tokens = self.token_counter(messages)
+        if not self._should_summarize(messages, total_tokens):
+            return None
+
+        cutoff_index = self._determine_cutoff_index(messages)
+
+        if cutoff_index <= 0:
+            return None
+
+        messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
+
+        summary = await self._acreate_summary(messages_to_summarize)
+        new_messages = self._build_new_messages(summary)
+
+        return {
+            "messages": [
+                RemoveMessage(id=REMOVE_ALL_MESSAGES),
+                *new_messages,
+                *preserved_messages,
+            ]
+        }
+
+    def _should_summarize_based_on_reported_tokens(
+        self, messages: list[AnyMessage], threshold: float
+    ) -> bool:
+        """Check if reported token usage from last AIMessage exceeds threshold."""
+        last_ai_message = next(
+            (msg for msg in reversed(messages) if isinstance(msg, AIMessage)),
+            None,
+        )
+        if (  # noqa: SIM103
+            isinstance(last_ai_message, AIMessage)
+            and last_ai_message.usage_metadata is not None
+            and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1))
+            and reported_tokens >= threshold
+            and (message_provider := last_ai_message.response_metadata.get("model_provider"))
+            and message_provider == self.model._get_ls_params().get("ls_provider")  # noqa: SLF001
+        ):
+            return True
+        return False
+
+    def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
+        """Determine whether summarization should run for the current token usage."""
+        if not self._trigger_conditions:
+            return False
+
+        for kind, value in self._trigger_conditions:
+            if kind == "messages" and len(messages) >= value:
+                return True
+            if kind == "tokens" and total_tokens >= value:
+                return True
+            if kind == "tokens" and self._should_summarize_based_on_reported_tokens(
+                messages, value
+            ):
+                return True
+            if kind == "fraction":
+                max_input_tokens = self._get_profile_limits()
+                if max_input_tokens is None:
+                    continue
+                threshold = int(max_input_tokens * value)
+                if threshold <= 0:
+                    threshold = 1
+                if total_tokens >= threshold:
+                    return True
+
+                if self._should_summarize_based_on_reported_tokens(messages, threshold):
+                    return True
+        return False
+
+    def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
+        """Choose cutoff index respecting retention configuration."""
+        kind, value = self.keep
+        if kind in {"tokens", "fraction"}:
+            token_based_cutoff = self._find_token_based_cutoff(messages)
+            if token_based_cutoff is not None:
+                return token_based_cutoff
+            # None cutoff -> model profile data not available (caught in __init__ but
+            # here for safety), fallback to message count
+            return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
+        return self._find_safe_cutoff(messages, cast("int", value))
+
+    def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
+        """Find cutoff index based on target token retention."""
+        if not messages:
+            return 0
+
+        kind, value = self.keep
+        if kind == "fraction":
+            max_input_tokens = self._get_profile_limits()
+            if max_input_tokens is None:
+                return None
+            target_token_count = int(max_input_tokens * value)
+        elif kind == "tokens":
+            target_token_count = int(value)
+        else:
+            return None
+
+        if target_token_count <= 0:
+            target_token_count = 1
+
+        if self.token_counter(messages) <= target_token_count:
+            return 0
+
+        # Use binary search to identify the earliest message index that keeps the
+        # suffix within the token budget.
+        left, right = 0, len(messages)
+        cutoff_candidate = len(messages)
+        max_iterations = len(messages).bit_length() + 1
+        for _ in range(max_iterations):
+            if left >= right:
+                break
+
+            mid = (left + right) // 2
+            if self.token_counter(messages[mid:]) <= target_token_count:
+                cutoff_candidate = mid
+                right = mid
+            else:
+                left = mid + 1
+
+        if cutoff_candidate == len(messages):
+            cutoff_candidate = left
+
+        if cutoff_candidate >= len(messages):
+            if len(messages) == 1:
+                return 0
+            cutoff_candidate = len(messages) - 1
+
+        # Advance past any ToolMessages to avoid splitting AI/Tool pairs
+        return self._find_safe_cutoff_point(messages, cutoff_candidate)
+
+    def _get_profile_limits(self) -> int | None:
+        """Retrieve max input token limit from the model profile."""
+        try:
+            profile = self.model.profile
+        except AttributeError:
+            return None
+
+        if not isinstance(profile, Mapping):
+            return None
+
+        max_input_tokens = profile.get("max_input_tokens")
+
+        if not isinstance(max_input_tokens, int):
+            return None
+
+        return max_input_tokens
+
+    def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
+        """Validate context configuration tuples."""
+        kind, value = context
+        if kind == "fraction":
+            if not 0 < value <= 1:
+                msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
+                raise ValueError(msg)
+        elif kind in {"tokens", "messages"}:
+            if value <= 0:
+                msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
+                raise ValueError(msg)
+        else:
+            msg = f"Unsupported context size type {kind} for {parameter_name}."
+            raise ValueError(msg)
+        return context
+
     def _build_new_messages(self, summary: str) -> list[HumanMessage]:
         return [
             HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
@@ -151,74 +487,76 @@
 
         return messages_to_summarize, preserved_messages
 
-    def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
+    def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
         """Find safe cutoff point that preserves AI/Tool message pairs.
 
         Returns the index where messages can be safely cut without separating
-        related AI and Tool messages. Returns 0 if no safe cutoff is found.
+        related AI and Tool messages. Returns `0` if no safe cutoff is found.
+
+        This is aggressive with summarization - if the target cutoff lands in the
+        middle of tool messages, we advance past all of them (summarizing more).
         """
-        if len(messages) <= self.messages_to_keep:
+        if len(messages) <= messages_to_keep:
             return 0
 
-        target_cutoff = len(messages) - self.messages_to_keep
-
-        for i in range(target_cutoff, -1, -1):
-            if self._is_safe_cutoff_point(messages, i):
-                return i
-
-        return 0
+        target_cutoff = len(messages) - messages_to_keep
+        return self._find_safe_cutoff_point(messages, target_cutoff)
 
-    def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool:
-        """Check if cutting at index would separate AI/Tool message pairs."""
-        if cutoff_index >= len(messages):
-            return True
-
-        search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS)
-        search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS)
+    def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
+        """Find a safe cutoff point that doesn't split AI/Tool message pairs.
 
-        for i in range(search_start, search_end):
-            if not self._has_tool_calls(messages[i]):
-                continue
+        If the message at `cutoff_index` is a `ToolMessage`, search backward for the
+        `AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
+        include it. This ensures tool call requests and responses stay together.
 
-            tool_call_ids = self._extract_tool_call_ids(cast("AIMessage", messages[i]))
-            if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids):
-                return False
+        Falls back to advancing forward past `ToolMessage` objects only if no matching
+        `AIMessage` is found (edge case).
+        """
+        if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
+            return cutoff_index
+
+        # Collect tool_call_ids from consecutive ToolMessages at/after cutoff
+        tool_call_ids: set[str] = set()
+        idx = cutoff_index
+        while idx < len(messages) and isinstance(messages[idx], ToolMessage):
+            tool_msg = cast("ToolMessage", messages[idx])
+            if tool_msg.tool_call_id:
+                tool_call_ids.add(tool_msg.tool_call_id)
+            idx += 1
+
+        # Search backward for AIMessage with matching tool_calls
+        for i in range(cutoff_index - 1, -1, -1):
+            msg = messages[i]
+            if isinstance(msg, AIMessage) and msg.tool_calls:
+                ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")}
+                if tool_call_ids & ai_tool_call_ids:
+                    # Found the AIMessage - move cutoff to include it
+                    return i
+
+        # Fallback: no matching AIMessage found, advance past ToolMessages to avoid
+        # orphaned tool responses
+        return idx
 
-            return True
+    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
+        """Generate summary for the given messages."""
+        if not messages_to_summarize:
+            return "No previous conversation history."
 
-    def _has_tool_calls(self, message: AnyMessage) -> bool:
-        """Check if message is an AI message with tool calls."""
-        return (
-            isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls  # type: ignore[return-value]
-        )
+        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
+        if not trimmed_messages:
+            return "Previous conversation was too long to summarize."
 
-    def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]:
-        """Extract tool call IDs from an AI message."""
-        tool_call_ids = set()
-        for tc in ai_message.tool_calls:
-            call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
-            if call_id is not None:
-                tool_call_ids.add(call_id)
-        return tool_call_ids
+        # Format messages to avoid token inflation from metadata when str() is called on
+        # message objects
+        formatted_messages = get_buffer_string(trimmed_messages)
 
-    def _cutoff_separates_tool_pair(
-        self,
-        messages: list[AnyMessage],
-        ai_message_index: int,
-        cutoff_index: int,
-        tool_call_ids: set[str],
-    ) -> bool:
-        """Check if cutoff separates an AI message from its corresponding tool messages."""
-        for j in range(ai_message_index + 1, len(messages)):
-            message = messages[j]
-            if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids:
-                ai_before_cutoff = ai_message_index < cutoff_index
-                tool_before_cutoff = j < cutoff_index
-                if ai_before_cutoff != tool_before_cutoff:
-                    return True
-        return False
+        try:
+            response = self.model.invoke(self.summary_prompt.format(messages=formatted_messages))
+            return response.text.strip()
+        except Exception as e:
+            return f"Error generating summary: {e!s}"
 
-    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
+    async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
         """Generate summary for the given messages."""
         if not messages_to_summarize:
             return "No previous conversation history."
@@ -227,23 +565,34 @@
         if not trimmed_messages:
             return "Previous conversation was too long to summarize."
 
+        # Format messages to avoid token inflation from metadata when str() is called on
+        # message objects
+        formatted_messages = get_buffer_string(trimmed_messages)
+
         try:
-            response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
-            return cast("str", response.content).strip()
-        except Exception as e:  # noqa: BLE001
+            response = await self.model.ainvoke(
+                self.summary_prompt.format(messages=formatted_messages)
+            )
+            return response.text.strip()
+        except Exception as e:
             return f"Error generating summary: {e!s}"
 
     def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
         """Trim messages to fit within summary generation limits."""
         try:
-            return trim_messages(
-                messages,
-                max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
-                token_counter=self.token_counter,
-                start_on="human",
-                strategy="last",
-                allow_partial=True,
-                include_system=True,
+            if self.trim_tokens_to_summarize is None:
+                return messages
+            return cast(
+                "list[AnyMessage]",
+                trim_messages(
+                    messages,
+                    max_tokens=self.trim_tokens_to_summarize,
+                    token_counter=self.token_counter,
+                    start_on="human",
+                    strategy="last",
+                    allow_partial=True,
+                    include_system=True,
+                ),
             )
-        except Exception:  # noqa: BLE001
+        except Exception:
             return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]
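For reference, the trimming step can be reproduced directly with `trim_messages` using the same arguments `_trim_messages_for_summary` passes; the new `trim_tokens_to_summarize=None` option bypasses it entirely. A sketch with synthetic history:

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages.utils import count_tokens_approximately, trim_messages

history = [SystemMessage(content="You are helpful.")] + [
    HumanMessage(content=f"message {i}: " + "x" * 400) for i in range(50)
]

# Same call shape as _trim_messages_for_summary with the default 4000-token cap
trimmed = trim_messages(
    history,
    max_tokens=4000,
    token_counter=count_tokens_approximately,
    start_on="human",
    strategy="last",
    allow_partial=True,
    include_system=True,
)
print(len(history), "->", len(trimmed))  # system message kept, oldest turns dropped
```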