langchain 1.2.2__tar.gz → 1.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. {langchain-1.2.2 → langchain-1.2.3}/PKG-INFO +1 -1
  2. {langchain-1.2.2 → langchain-1.2.3}/langchain/__init__.py +1 -1
  3. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/summarization.py +57 -6
  4. {langchain-1.2.2 → langchain-1.2.3}/langchain/embeddings/base.py +1 -1
  5. {langchain-1.2.2 → langchain-1.2.3}/pyproject.toml +1 -1
  6. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_summarization.py +171 -26
  7. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/chat_models/test_chat_models.py +29 -1
  8. {langchain-1.2.2 → langchain-1.2.3}/.gitignore +0 -0
  9. {langchain-1.2.2 → langchain-1.2.3}/LICENSE +0 -0
  10. {langchain-1.2.2 → langchain-1.2.3}/Makefile +0 -0
  11. {langchain-1.2.2 → langchain-1.2.3}/README.md +0 -0
  12. {langchain-1.2.2 → langchain-1.2.3}/extended_testing_deps.txt +0 -0
  13. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/__init__.py +0 -0
  14. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/factory.py +0 -0
  15. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/__init__.py +0 -0
  16. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/_execution.py +0 -0
  17. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/_redaction.py +0 -0
  18. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/_retry.py +0 -0
  19. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/context_editing.py +0 -0
  20. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/file_search.py +0 -0
  21. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/human_in_the_loop.py +0 -0
  22. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/model_call_limit.py +0 -0
  23. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/model_fallback.py +0 -0
  24. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/model_retry.py +0 -0
  25. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/pii.py +0 -0
  26. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/shell_tool.py +0 -0
  27. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/todo.py +0 -0
  28. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/tool_call_limit.py +0 -0
  29. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/tool_emulator.py +0 -0
  30. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/tool_retry.py +0 -0
  31. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/tool_selection.py +0 -0
  32. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/middleware/types.py +0 -0
  33. {langchain-1.2.2 → langchain-1.2.3}/langchain/agents/structured_output.py +0 -0
  34. {langchain-1.2.2 → langchain-1.2.3}/langchain/chat_models/__init__.py +0 -0
  35. {langchain-1.2.2 → langchain-1.2.3}/langchain/chat_models/base.py +0 -0
  36. {langchain-1.2.2 → langchain-1.2.3}/langchain/embeddings/__init__.py +0 -0
  37. {langchain-1.2.2 → langchain-1.2.3}/langchain/messages/__init__.py +0 -0
  38. {langchain-1.2.2 → langchain-1.2.3}/langchain/py.typed +0 -0
  39. {langchain-1.2.2 → langchain-1.2.3}/langchain/rate_limiters/__init__.py +0 -0
  40. {langchain-1.2.2 → langchain-1.2.3}/langchain/tools/__init__.py +0 -0
  41. {langchain-1.2.2 → langchain-1.2.3}/langchain/tools/tool_node.py +0 -0
  42. {langchain-1.2.2 → langchain-1.2.3}/scripts/check_imports.py +0 -0
  43. {langchain-1.2.2 → langchain-1.2.3}/tests/__init__.py +0 -0
  44. {langchain-1.2.2 → langchain-1.2.3}/tests/cassettes/test_inference_to_native_output[False].yaml.gz +0 -0
  45. {langchain-1.2.2 → langchain-1.2.3}/tests/cassettes/test_inference_to_native_output[True].yaml.gz +0 -0
  46. {langchain-1.2.2 → langchain-1.2.3}/tests/cassettes/test_inference_to_tool_output[False].yaml.gz +0 -0
  47. {langchain-1.2.2 → langchain-1.2.3}/tests/cassettes/test_inference_to_tool_output[True].yaml.gz +0 -0
  48. {langchain-1.2.2 → langchain-1.2.3}/tests/cassettes/test_strict_mode[False].yaml.gz +0 -0
  49. {langchain-1.2.2 → langchain-1.2.3}/tests/cassettes/test_strict_mode[True].yaml.gz +0 -0
  50. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/__init__.py +0 -0
  51. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/agents/__init__.py +0 -0
  52. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/agents/middleware/__init__.py +0 -0
  53. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/agents/middleware/test_shell_tool_integration.py +0 -0
  54. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/cache/__init__.py +0 -0
  55. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/cache/fake_embeddings.py +0 -0
  56. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/chat_models/__init__.py +0 -0
  57. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/chat_models/test_base.py +0 -0
  58. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/conftest.py +0 -0
  59. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/embeddings/__init__.py +0 -0
  60. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/embeddings/test_base.py +0 -0
  61. {langchain-1.2.2 → langchain-1.2.3}/tests/integration_tests/test_compile.py +0 -0
  62. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/__init__.py +0 -0
  63. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/__init__.py +0 -0
  64. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr +0 -0
  65. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/__snapshots__/test_middleware_decorators.ambr +0 -0
  66. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/__snapshots__/test_middleware_framework.ambr +0 -0
  67. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/__snapshots__/test_return_direct_graph.ambr +0 -0
  68. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/any_str.py +0 -0
  69. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/compose-postgres.yml +0 -0
  70. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/compose-redis.yml +0 -0
  71. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/conftest.py +0 -0
  72. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/conftest_checkpointer.py +0 -0
  73. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/conftest_store.py +0 -0
  74. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/memory_assert.py +0 -0
  75. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/messages.py +0 -0
  76. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/__init__.py +0 -0
  77. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/__snapshots__/test_middleware_decorators.ambr +0 -0
  78. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/__snapshots__/test_middleware_diagram.ambr +0 -0
  79. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/__snapshots__/test_middleware_framework.ambr +0 -0
  80. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/__init__.py +0 -0
  81. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/__snapshots__/test_decorators.ambr +0 -0
  82. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/__snapshots__/test_diagram.ambr +0 -0
  83. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/__snapshots__/test_framework.ambr +0 -0
  84. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/test_composition.py +0 -0
  85. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/test_decorators.py +0 -0
  86. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/test_diagram.py +0 -0
  87. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/test_framework.py +0 -0
  88. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/test_overrides.py +0 -0
  89. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/test_sync_async_wrappers.py +0 -0
  90. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/test_tools.py +0 -0
  91. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/test_wrap_model_call.py +0 -0
  92. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/core/test_wrap_tool_call.py +0 -0
  93. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/__init__.py +0 -0
  94. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_context_editing.py +0 -0
  95. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_file_search.py +0 -0
  96. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_human_in_the_loop.py +0 -0
  97. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_model_call_limit.py +0 -0
  98. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_model_fallback.py +0 -0
  99. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_model_retry.py +0 -0
  100. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_pii.py +0 -0
  101. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_shell_execution_policies.py +0 -0
  102. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_shell_tool.py +0 -0
  103. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_structured_output_retry.py +0 -0
  104. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_todo.py +0 -0
  105. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_tool_call_limit.py +0 -0
  106. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_tool_emulator.py +0 -0
  107. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_tool_retry.py +0 -0
  108. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/middleware/implementations/test_tool_selection.py +0 -0
  109. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/model.py +0 -0
  110. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/specifications/responses.json +0 -0
  111. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/specifications/return_direct.json +0 -0
  112. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_agent_name.py +0 -0
  113. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_create_agent_tool_validation.py +0 -0
  114. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_injected_runtime_create_agent.py +0 -0
  115. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_react_agent.py +0 -0
  116. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_response_format.py +0 -0
  117. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_response_format_integration.py +0 -0
  118. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_responses.py +0 -0
  119. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_responses_spec.py +0 -0
  120. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_return_direct_graph.py +0 -0
  121. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_return_direct_spec.py +0 -0
  122. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_state_schema.py +0 -0
  123. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/test_system_message.py +0 -0
  124. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/agents/utils.py +0 -0
  125. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/chat_models/__init__.py +0 -0
  126. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/conftest.py +0 -0
  127. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/embeddings/__init__.py +0 -0
  128. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/embeddings/test_base.py +0 -0
  129. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/embeddings/test_imports.py +0 -0
  130. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/test_dependencies.py +0 -0
  131. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/test_imports.py +0 -0
  132. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/test_pytest_config.py +0 -0
  133. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/test_version.py +0 -0
  134. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/tools/__init__.py +0 -0
  135. {langchain-1.2.2 → langchain-1.2.3}/tests/unit_tests/tools/test_imports.py +0 -0
  136. {langchain-1.2.2 → langchain-1.2.3}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langchain
3
- Version: 1.2.2
3
+ Version: 1.2.3
4
4
  Summary: Building applications with LLMs through composability
5
5
  Project-URL: Homepage, https://docs.langchain.com/
6
6
  Project-URL: Documentation, https://reference.langchain.com/python/langchain/langchain/
@@ -1,3 +1,3 @@
1
1
  """Main entrypoint into LangChain."""
2
2
 
3
- __version__ = "1.2.2"
3
+ __version__ = "1.2.3"
@@ -7,6 +7,7 @@ from functools import partial
7
7
  from typing import Any, Literal, cast
8
8
 
9
9
  from langchain_core.messages import (
10
+ AIMessage,
10
11
  AnyMessage,
11
12
  MessageLikeRepresentation,
12
13
  RemoveMessage,
@@ -323,6 +324,25 @@ class SummarizationMiddleware(AgentMiddleware):
323
324
  ]
324
325
  }
325
326
 
327
+ def _should_summarize_based_on_reported_tokens(
328
+ self, messages: list[AnyMessage], threshold: float
329
+ ) -> bool:
330
+ """Check if reported token usage from last AIMessage exceeds threshold."""
331
+ last_ai_message = next(
332
+ (msg for msg in reversed(messages) if isinstance(msg, AIMessage)),
333
+ None,
334
+ )
335
+ if ( # noqa: SIM103
336
+ isinstance(last_ai_message, AIMessage)
337
+ and last_ai_message.usage_metadata is not None
338
+ and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1))
339
+ and reported_tokens >= threshold
340
+ and (message_provider := last_ai_message.response_metadata.get("model_provider"))
341
+ and message_provider == self.model._get_ls_params().get("ls_provider") # noqa: SLF001
342
+ ):
343
+ return True
344
+ return False
345
+
326
346
  def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
327
347
  """Determine whether summarization should run for the current token usage."""
328
348
  if not self._trigger_conditions:
@@ -333,6 +353,10 @@ class SummarizationMiddleware(AgentMiddleware):
333
353
  return True
334
354
  if kind == "tokens" and total_tokens >= value:
335
355
  return True
356
+ if kind == "tokens" and self._should_summarize_based_on_reported_tokens(
357
+ messages, value
358
+ ):
359
+ return True
336
360
  if kind == "fraction":
337
361
  max_input_tokens = self._get_profile_limits()
338
362
  if max_input_tokens is None:
@@ -342,6 +366,9 @@ class SummarizationMiddleware(AgentMiddleware):
342
366
  threshold = 1
343
367
  if total_tokens >= threshold:
344
368
  return True
369
+
370
+ if self._should_summarize_based_on_reported_tokens(messages, threshold):
371
+ return True
345
372
  return False
346
373
 
347
374
  def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
@@ -478,13 +505,37 @@ class SummarizationMiddleware(AgentMiddleware):
478
505
  def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
479
506
  """Find a safe cutoff point that doesn't split AI/Tool message pairs.
480
507
 
481
- If the message at cutoff_index is a ToolMessage, advance until we find
482
- a non-ToolMessage. This ensures we never cut in the middle of parallel
483
- tool call responses.
508
+ If the message at `cutoff_index` is a `ToolMessage`, search backward for the
509
+ `AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
510
+ include it. This ensures tool call requests and responses stay together.
511
+
512
+ Falls back to advancing forward past `ToolMessage` objects only if no matching
513
+ `AIMessage` is found (edge case).
484
514
  """
485
- while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
486
- cutoff_index += 1
487
- return cutoff_index
515
+ if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
516
+ return cutoff_index
517
+
518
+ # Collect tool_call_ids from consecutive ToolMessages at/after cutoff
519
+ tool_call_ids: set[str] = set()
520
+ idx = cutoff_index
521
+ while idx < len(messages) and isinstance(messages[idx], ToolMessage):
522
+ tool_msg = cast("ToolMessage", messages[idx])
523
+ if tool_msg.tool_call_id:
524
+ tool_call_ids.add(tool_msg.tool_call_id)
525
+ idx += 1
526
+
527
+ # Search backward for AIMessage with matching tool_calls
528
+ for i in range(cutoff_index - 1, -1, -1):
529
+ msg = messages[i]
530
+ if isinstance(msg, AIMessage) and msg.tool_calls:
531
+ ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")}
532
+ if tool_call_ids & ai_tool_call_ids:
533
+ # Found the AIMessage - move cutoff to include it
534
+ return i
535
+
536
+ # Fallback: no matching AIMessage found, advance past ToolMessages to avoid
537
+ # orphaned tool responses
538
+ return idx
488
539
 
489
540
  def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
490
541
  """Generate summary for the given messages."""
@@ -13,7 +13,7 @@ def _call(cls: type[Embeddings], **kwargs: Any) -> Embeddings:
13
13
 
14
14
 
15
15
  _SUPPORTED_PROVIDERS: dict[str, tuple[str, str, Callable[..., Embeddings]]] = {
16
- "azure_openai": ("langchain_openai", "OpenAIEmbeddings", _call),
16
+ "azure_openai": ("langchain_openai", "AzureOpenAIEmbeddings", _call),
17
17
  "bedrock": (
18
18
  "langchain_aws",
19
19
  "BedrockEmbeddings",
@@ -9,7 +9,7 @@ license = { text = "MIT" }
9
9
  readme = "README.md"
10
10
  authors = []
11
11
 
12
- version = "1.2.2"
12
+ version = "1.2.3"
13
13
  requires-python = ">=3.10.0,<4.0.0"
14
14
  dependencies = [
15
15
  "langchain-core>=1.2.1,<2.0.0",
@@ -9,6 +9,7 @@ from langchain_core.outputs import ChatGeneration, ChatResult
9
9
  from langgraph.graph.message import REMOVE_ALL_MESSAGES
10
10
 
11
11
  from langchain.agents.middleware.summarization import SummarizationMiddleware
12
+ from langchain.chat_models import init_chat_model
12
13
  from tests.unit_tests.agents.model import FakeToolCallingModel
13
14
 
14
15
 
@@ -281,8 +282,8 @@ def test_summarization_middleware_profile_inference_triggers_summary() -> None:
281
282
  ]
282
283
 
283
284
 
284
- def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None:
285
- """Ensure token retention advances past tool messages for aggressive summarization."""
285
+ def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None:
286
+ """Ensure token retention preserves AI/Tool message pairs together."""
286
287
 
287
288
  def token_counter(messages: list[AnyMessage]) -> int:
288
289
  return sum(len(getattr(message, "content", "")) for message in messages)
@@ -297,7 +298,7 @@ def test_summarization_middleware_token_retention_advances_past_tool_messages()
297
298
  # Total tokens: 300 + 200 + 50 + 180 + 160 = 890
298
299
  # Target keep: 500 tokens (50% of 1000)
299
300
  # Binary search finds cutoff around index 2 (ToolMessage)
300
- # We advance past it to index 3 (HumanMessage)
301
+ # We move back to index 1 to preserve the AIMessage with its ToolMessage
301
302
  messages: list[AnyMessage] = [
302
303
  HumanMessage(content="H" * 300),
303
304
  AIMessage(
@@ -314,14 +315,15 @@ def test_summarization_middleware_token_retention_advances_past_tool_messages()
314
315
  assert result is not None
315
316
 
316
317
  preserved_messages = result["messages"][2:]
317
- # With aggressive summarization, we advance past the ToolMessage
318
- # So we preserve messages from index 3 onward (the two HumanMessages)
319
- assert preserved_messages == messages[3:]
318
+ # We move the cutoff back to include the AIMessage with its ToolMessage
319
+ # So we preserve messages from index 1 onward (AI + Tool + Human + Human)
320
+ assert preserved_messages == messages[1:]
320
321
 
321
- # Verify preserved tokens are within budget
322
- target_token_count = int(1000 * 0.5)
323
- preserved_tokens = middleware.token_counter(preserved_messages)
324
- assert preserved_tokens <= target_token_count
322
+ # Verify the AI/Tool pair is preserved together
323
+ assert isinstance(preserved_messages[0], AIMessage)
324
+ assert preserved_messages[0].tool_calls
325
+ assert isinstance(preserved_messages[1], ToolMessage)
326
+ assert preserved_messages[1].tool_call_id == preserved_messages[0].tool_calls[0]["id"]
325
327
 
326
328
 
327
329
  def test_summarization_middleware_missing_profile() -> None:
@@ -666,7 +668,7 @@ def test_summarization_middleware_binary_search_edge_cases() -> None:
666
668
 
667
669
 
668
670
  def test_summarization_middleware_find_safe_cutoff_point() -> None:
669
- """Test _find_safe_cutoff_point finds safe cutoff past ToolMessages."""
671
+ """Test `_find_safe_cutoff_point` preserves AI/Tool message pairs."""
670
672
  model = FakeToolCallingModel()
671
673
  middleware = SummarizationMiddleware(
672
674
  model=model, trigger=("messages", 10), keep=("messages", 2)
@@ -676,7 +678,7 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
676
678
  HumanMessage(content="msg1"),
677
679
  AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
678
680
  ToolMessage(content="result1", tool_call_id="call1"),
679
- ToolMessage(content="result2", tool_call_id="call2"),
681
+ ToolMessage(content="result2", tool_call_id="call2"), # orphan - no matching AI
680
682
  HumanMessage(content="msg2"),
681
683
  ]
682
684
 
@@ -684,8 +686,14 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
684
686
  assert middleware._find_safe_cutoff_point(messages, 0) == 0
685
687
  assert middleware._find_safe_cutoff_point(messages, 1) == 1
686
688
 
687
- # Starting at a ToolMessage advances to the next non-ToolMessage
688
- assert middleware._find_safe_cutoff_point(messages, 2) == 4
689
+ # Starting at ToolMessage with matching AIMessage moves back to include it
690
+ # ToolMessage at index 2 has tool_call_id="call1" which matches AIMessage at index 1
691
+ assert middleware._find_safe_cutoff_point(messages, 2) == 1
692
+
693
+ # Starting at orphan ToolMessage (no matching AIMessage) falls back to advancing
694
+ # ToolMessage at index 3 has tool_call_id="call2" with no matching AIMessage
695
+ # Since we only collect from cutoff_index onwards, only {call2} is collected
696
+ # No match found, so we fall back to advancing past ToolMessages
689
697
  assert middleware._find_safe_cutoff_point(messages, 3) == 4
690
698
 
691
699
  # Starting at the HumanMessage after tools returns that index
@@ -699,6 +707,65 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
699
707
  assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5
700
708
 
701
709
 
710
+ def test_summarization_middleware_find_safe_cutoff_point_orphan_tool() -> None:
711
+ """Test `_find_safe_cutoff_point` with truly orphan `ToolMessage` (no matching `AIMessage`)."""
712
+ model = FakeToolCallingModel()
713
+ middleware = SummarizationMiddleware(
714
+ model=model, trigger=("messages", 10), keep=("messages", 2)
715
+ )
716
+
717
+ # Messages where ToolMessage has no matching AIMessage at all
718
+ messages: list[AnyMessage] = [
719
+ HumanMessage(content="msg1"),
720
+ AIMessage(content="ai_no_tools"), # No tool_calls
721
+ ToolMessage(content="orphan_result", tool_call_id="orphan_call"),
722
+ HumanMessage(content="msg2"),
723
+ ]
724
+
725
+ # Starting at orphan ToolMessage falls back to advancing forward
726
+ assert middleware._find_safe_cutoff_point(messages, 2) == 3
727
+
728
+
729
+ def test_summarization_cutoff_moves_backward_to_include_ai_message() -> None:
730
+ """Test that cutoff moves backward to include `AIMessage` with its `ToolMessage`s.
731
+
732
+ Previously, when the cutoff landed on a `ToolMessage`, the code would advance
733
+ FORWARD past all `ToolMessage`s. This could result in orphaned `ToolMessage`s (kept
734
+ without their `AIMessage`) or aggressive summarization that removed AI/Tool pairs.
735
+
736
+ The fix searches backward from a `ToolMessage` to find the `AIMessage` with matching
737
+ `tool_calls`, ensuring the pair stays together in the preserved messages.
738
+ """
739
+ model = FakeToolCallingModel()
740
+ middleware = SummarizationMiddleware(
741
+ model=model, trigger=("messages", 10), keep=("messages", 2)
742
+ )
743
+
744
+ # Scenario: cutoff lands on ToolMessage that has a matching AIMessage before it
745
+ messages: list[AnyMessage] = [
746
+ HumanMessage(content="initial question"), # index 0
747
+ AIMessage(
748
+ content="I'll use a tool",
749
+ tool_calls=[{"name": "search", "args": {"q": "test"}, "id": "call_abc"}],
750
+ ), # index 1
751
+ ToolMessage(content="search result", tool_call_id="call_abc"), # index 2
752
+ HumanMessage(content="followup"), # index 3
753
+ ]
754
+
755
+ # When cutoff is at index 2 (ToolMessage), it should move BACKWARD to index 1
756
+ # to include the AIMessage that generated the tool call
757
+ result = middleware._find_safe_cutoff_point(messages, 2)
758
+
759
+ assert result == 1, (
760
+ f"Expected cutoff to move backward to index 1 (AIMessage), got {result}. "
761
+ "The cutoff should preserve AI/Tool pairs together."
762
+ )
763
+
764
+ assert isinstance(messages[result], AIMessage)
765
+ assert messages[result].tool_calls # type: ignore[union-attr]
766
+ assert messages[result].tool_calls[0]["id"] == "call_abc" # type: ignore[union-attr]
767
+
768
+
702
769
  def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
703
770
  """Test handling of edge cases with target token calculations."""
704
771
  # Test with very small fraction that rounds to zero
@@ -814,7 +881,7 @@ def test_summarization_adjust_token_counts() -> None:
814
881
 
815
882
 
816
883
  def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
817
- """Test cutoff safety with many parallel tool calls extending beyond old search range."""
884
+ """Test cutoff safety preserves AI message with many parallel tool calls."""
818
885
  middleware = SummarizationMiddleware(
819
886
  model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5)
820
887
  )
@@ -826,20 +893,21 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
826
893
  ]
827
894
  messages: list[AnyMessage] = [human_message, ai_message, *tool_messages]
828
895
 
829
- # Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages)
830
- assert middleware._find_safe_cutoff_point(messages, 7) == 12
896
+ # Cutoff at index 7 (a ToolMessage) moves back to index 1 (AIMessage)
897
+ # to preserve the AI/Tool pair together
898
+ assert middleware._find_safe_cutoff_point(messages, 7) == 1
831
899
 
832
- # Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12
900
+ # Any cutoff pointing at a ToolMessage (indices 2-11) moves back to index 1
833
901
  for i in range(2, 12):
834
- assert middleware._find_safe_cutoff_point(messages, i) == 12
902
+ assert middleware._find_safe_cutoff_point(messages, i) == 1
835
903
 
836
904
  # Cutoff at index 0, 1 (before tool messages) stays the same
837
905
  assert middleware._find_safe_cutoff_point(messages, 0) == 0
838
906
  assert middleware._find_safe_cutoff_point(messages, 1) == 1
839
907
 
840
908
 
841
- def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None:
842
- """Test _find_safe_cutoff advances past ToolMessages to find safe cutoff."""
909
+ def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None:
910
+ """Test `_find_safe_cutoff` preserves AI/Tool message pairs together."""
843
911
  middleware = SummarizationMiddleware(
844
912
  model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3)
845
913
  )
@@ -862,15 +930,15 @@ def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None
862
930
  ]
863
931
 
864
932
  # Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3
865
- # Index 3 is a ToolMessage, so we advance past the tool sequence to index 5
933
+ # Index 3 is a ToolMessage, we move back to index 1 to include AIMessage
866
934
  cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3)
867
- assert cutoff == 5
935
+ assert cutoff == 1
868
936
 
869
937
  # With messages_to_keep=2, target cutoff index is 6 - 2 = 4
870
- # Index 4 is a ToolMessage, so we advance past the tool sequence to index 5
871
- # This is aggressive - we keep only 1 message instead of 2
938
+ # Index 4 is a ToolMessage, we move back to index 1 to include AIMessage
939
+ # This preserves the AI + Tools + Human, more than requested but valid
872
940
  cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2)
873
- assert cutoff == 5
941
+ assert cutoff == 1
874
942
 
875
943
 
876
944
  def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
@@ -947,3 +1015,80 @@ def test_create_summary_uses_get_buffer_string_format() -> None:
947
1015
  f"str(messages) should produce significantly more tokens. "
948
1016
  f"Got ratio {str_ratio:.2f}x (expected > 1.5)"
949
1017
  )
1018
+
1019
+
1020
+ @pytest.mark.requires("langchain_anthropic")
1021
+ def test_usage_metadata_trigger() -> None:
1022
+ model = init_chat_model("anthropic:claude-sonnet-4-5")
1023
+ middleware = SummarizationMiddleware(
1024
+ model=model, trigger=("tokens", 10_000), keep=("messages", 4)
1025
+ )
1026
+ messages: list[AnyMessage] = [
1027
+ HumanMessage(content="msg1"),
1028
+ AIMessage(
1029
+ content="msg2",
1030
+ tool_calls=[{"name": "tool", "args": {}, "id": "call1"}],
1031
+ response_metadata={"model_provider": "anthropic"},
1032
+ usage_metadata={
1033
+ "input_tokens": 5000,
1034
+ "output_tokens": 1000,
1035
+ "total_tokens": 6000,
1036
+ },
1037
+ ),
1038
+ ToolMessage(content="result", tool_call_id="call1"),
1039
+ AIMessage(
1040
+ content="msg3",
1041
+ response_metadata={"model_provider": "anthropic"},
1042
+ usage_metadata={
1043
+ "input_tokens": 6100,
1044
+ "output_tokens": 900,
1045
+ "total_tokens": 7000,
1046
+ },
1047
+ ),
1048
+ HumanMessage(content="msg4"),
1049
+ AIMessage(
1050
+ content="msg5",
1051
+ response_metadata={"model_provider": "anthropic"},
1052
+ usage_metadata={
1053
+ "input_tokens": 7500,
1054
+ "output_tokens": 2501,
1055
+ "total_tokens": 10_001,
1056
+ },
1057
+ ),
1058
+ ]
1059
+ # reported token count should override count of zero
1060
+ assert middleware._should_summarize(messages, 0)
1061
+
1062
+ # don't engage unless model provider matches
1063
+ messages.extend(
1064
+ [
1065
+ HumanMessage(content="msg6"),
1066
+ AIMessage(
1067
+ content="msg7",
1068
+ response_metadata={"model_provider": "not-anthropic"},
1069
+ usage_metadata={
1070
+ "input_tokens": 7500,
1071
+ "output_tokens": 2501,
1072
+ "total_tokens": 10_001,
1073
+ },
1074
+ ),
1075
+ ]
1076
+ )
1077
+ assert not middleware._should_summarize(messages, 0)
1078
+
1079
+ # don't engage if subsequent message stays under threshold (e.g., after summarization)
1080
+ messages.extend(
1081
+ [
1082
+ HumanMessage(content="msg8"),
1083
+ AIMessage(
1084
+ content="msg9",
1085
+ response_metadata={"model_provider": "anthropic"},
1086
+ usage_metadata={
1087
+ "input_tokens": 7500,
1088
+ "output_tokens": 2499,
1089
+ "total_tokens": 9999,
1090
+ },
1091
+ ),
1092
+ ]
1093
+ )
1094
+ assert not middleware._should_summarize(messages, 0)
@@ -8,7 +8,7 @@ from langchain_core.runnables import RunnableConfig, RunnableSequence
8
8
  from pydantic import SecretStr
9
9
 
10
10
  from langchain.chat_models import __all__, init_chat_model
11
- from langchain.chat_models.base import _SUPPORTED_PROVIDERS
11
+ from langchain.chat_models.base import _SUPPORTED_PROVIDERS, _attempt_infer_model_provider
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from langchain_core.language_models import BaseChatModel
@@ -67,6 +67,34 @@ def test_supported_providers_is_sorted() -> None:
67
67
  assert list(_SUPPORTED_PROVIDERS) == sorted(_SUPPORTED_PROVIDERS.keys())
68
68
 
69
69
 
70
+ @pytest.mark.parametrize(
71
+ ("model_name", "expected_provider"),
72
+ [
73
+ ("gpt-4o", "openai"),
74
+ ("o1-mini", "openai"),
75
+ ("o3-mini", "openai"),
76
+ ("chatgpt-4o-latest", "openai"),
77
+ ("text-davinci-003", "openai"),
78
+ ("claude-3-haiku-20240307", "anthropic"),
79
+ ("command-r-plus", "cohere"),
80
+ ("accounts/fireworks/models/mixtral-8x7b-instruct", "fireworks"),
81
+ ("gemini-1.5-pro", "google_vertexai"),
82
+ ("gemini-2.5-pro", "google_vertexai"),
83
+ ("gemini-3-pro-preview", "google_vertexai"),
84
+ ("amazon.titan-text-express-v1", "bedrock"),
85
+ ("anthropic.claude-v2", "bedrock"),
86
+ ("mistral-small", "mistralai"),
87
+ ("mixtral-8x7b", "mistralai"),
88
+ ("deepseek-v3", "deepseek"),
89
+ ("grok-beta", "xai"),
90
+ ("sonar-small", "perplexity"),
91
+ ("solar-pro", "upstage"),
92
+ ],
93
+ )
94
+ def test_attempt_infer_model_provider(model_name: str, expected_provider: str) -> None:
95
+ assert _attempt_infer_model_provider(model_name) == expected_provider
96
+
97
+
70
98
  @pytest.mark.requires("langchain_openai")
71
99
  @mock.patch.dict(
72
100
  os.environ,
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes