langchain 1.0.0a3.tar.gz → 1.0.0a4.tar.gz

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in that registry.

Potentially problematic release: this version of langchain might be problematic.

Files changed (90)
  1. {langchain-1.0.0a3 → langchain-1.0.0a4}/PKG-INFO +2 -2
  2. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/__init__.py +1 -1
  3. langchain-1.0.0a4/langchain/agents/middleware/__init__.py +15 -0
  4. langchain-1.0.0a4/langchain/agents/middleware/_utils.py +11 -0
  5. langchain-1.0.0a4/langchain/agents/middleware/human_in_the_loop.py +128 -0
  6. langchain-1.0.0a4/langchain/agents/middleware/prompt_caching.py +57 -0
  7. langchain-1.0.0a4/langchain/agents/middleware/summarization.py +248 -0
  8. langchain-1.0.0a4/langchain/agents/middleware/types.py +78 -0
  9. langchain-1.0.0a4/langchain/agents/middleware_agent.py +554 -0
  10. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/agents/react_agent.py +28 -0
  11. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/chat_models/__init__.py +2 -0
  12. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/chat_models/base.py +2 -0
  13. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/documents/__init__.py +2 -0
  14. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/embeddings/__init__.py +2 -0
  15. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/embeddings/base.py +2 -0
  16. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/storage/encoder_backed.py +2 -0
  17. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/storage/exceptions.py +2 -0
  18. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/tools/__init__.py +2 -0
  19. {langchain-1.0.0a3 → langchain-1.0.0a4}/pyproject.toml +8 -16
  20. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/cache/fake_embeddings.py +20 -14
  21. langchain-1.0.0a4/tests/unit_tests/agents/__snapshots__/test_middleware_agent.ambr +533 -0
  22. langchain-1.0.0a4/tests/unit_tests/agents/test_middleware_agent.py +712 -0
  23. {langchain-1.0.0a3 → langchain-1.0.0a4}/LICENSE +0 -0
  24. {langchain-1.0.0a3 → langchain-1.0.0a4}/README.md +0 -0
  25. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/_internal/__init__.py +0 -0
  26. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/_internal/_documents.py +0 -0
  27. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/_internal/_lazy_import.py +0 -0
  28. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/_internal/_prompts.py +0 -0
  29. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/_internal/_typing.py +0 -0
  30. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/_internal/_utils.py +0 -0
  31. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/agents/__init__.py +0 -0
  32. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/agents/_internal/__init__.py +0 -0
  33. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/agents/_internal/_typing.py +0 -0
  34. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/agents/interrupt.py +0 -0
  35. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/agents/structured_output.py +0 -0
  36. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/agents/tool_node.py +0 -0
  37. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/embeddings/cache.py +0 -0
  38. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/globals.py +0 -0
  39. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/py.typed +0 -0
  40. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/storage/__init__.py +0 -0
  41. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/storage/in_memory.py +0 -0
  42. {langchain-1.0.0a3 → langchain-1.0.0a4}/langchain/text_splitter.py +0 -0
  43. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/__init__.py +0 -0
  44. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/__init__.py +0 -0
  45. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/agents/__init__.py +0 -0
  46. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/agents/test_response_format.py +0 -0
  47. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/cache/__init__.py +0 -0
  48. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/chat_models/__init__.py +0 -0
  49. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/chat_models/test_base.py +0 -0
  50. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/conftest.py +0 -0
  51. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/embeddings/__init__.py +0 -0
  52. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/embeddings/test_base.py +0 -0
  53. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/integration_tests/test_compile.py +0 -0
  54. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/__init__.py +0 -0
  55. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/__init__.py +0 -0
  56. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/__snapshots__/test_react_agent_graph.ambr +0 -0
  57. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/any_str.py +0 -0
  58. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/compose-postgres.yml +0 -0
  59. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/compose-redis.yml +0 -0
  60. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/conftest.py +0 -0
  61. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/conftest_checkpointer.py +0 -0
  62. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/conftest_store.py +0 -0
  63. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/memory_assert.py +0 -0
  64. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/messages.py +0 -0
  65. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/model.py +0 -0
  66. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/specifications/responses.json +0 -0
  67. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/specifications/return_direct.json +0 -0
  68. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/test_react_agent.py +0 -0
  69. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/test_react_agent_graph.py +0 -0
  70. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/test_response_format.py +0 -0
  71. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/test_responses.py +0 -0
  72. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/test_responses_spec.py +0 -0
  73. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/test_return_direct_spec.py +0 -0
  74. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/test_tool_node.py +0 -0
  75. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/agents/utils.py +0 -0
  76. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/chat_models/__init__.py +0 -0
  77. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/chat_models/test_chat_models.py +0 -0
  78. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/conftest.py +0 -0
  79. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/embeddings/__init__.py +0 -0
  80. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/embeddings/test_base.py +0 -0
  81. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/embeddings/test_caching.py +0 -0
  82. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/embeddings/test_imports.py +0 -0
  83. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/storage/__init__.py +0 -0
  84. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/storage/test_imports.py +0 -0
  85. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/stubs.py +0 -0
  86. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/test_dependencies.py +0 -0
  87. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/test_imports.py +0 -0
  88. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/test_pytest_config.py +0 -0
  89. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/tools/__init__.py +0 -0
  90. {langchain-1.0.0a3 → langchain-1.0.0a4}/tests/unit_tests/tools/test_imports.py +0 -0
--- langchain-1.0.0a3/PKG-INFO
+++ langchain-1.0.0a4/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langchain
-Version: 1.0.0a3
+Version: 1.0.0a4
 Summary: Building applications with LLMs through composability
 License: MIT
 Project-URL: Source Code, https://github.com/langchain-ai/langchain/tree/master/libs/langchain
@@ -9,7 +9,7 @@ Project-URL: repository, https://github.com/langchain-ai/langchain
 Requires-Python: >=3.10
 Requires-Dist: langchain-core<2.0.0,>=0.3.75
 Requires-Dist: langchain-text-splitters<1.0.0,>=0.3.11
-Requires-Dist: langgraph>=0.6.0
+Requires-Dist: langgraph>=0.6.7
 Requires-Dist: pydantic>=2.7.4
 Provides-Extra: anthropic
 Requires-Dist: langchain-anthropic; extra == "anthropic"
--- langchain-1.0.0a3/langchain/__init__.py
+++ langchain-1.0.0a4/langchain/__init__.py
@@ -2,7 +2,7 @@
 
 from typing import Any
 
-__version__ = "1.0.0a1"
+__version__ = "1.0.0a3"
 
 
 def __getattr__(name: str) -> Any:  # noqa: ANN401
--- /dev/null
+++ langchain-1.0.0a4/langchain/agents/middleware/__init__.py
@@ -0,0 +1,15 @@
+"""Middleware plugins for agents."""
+
+from .human_in_the_loop import HumanInTheLoopMiddleware
+from .prompt_caching import AnthropicPromptCachingMiddleware
+from .summarization import SummarizationMiddleware
+from .types import AgentMiddleware, AgentState, ModelRequest
+
+__all__ = [
+    "AgentMiddleware",
+    "AgentState",
+    "AnthropicPromptCachingMiddleware",
+    "HumanInTheLoopMiddleware",
+    "ModelRequest",
+    "SummarizationMiddleware",
+]
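
The new subpackage gives the middleware a single public import surface. A sketch of a downstream import, using only the names re-exported in `__all__` above:

    from langchain.agents.middleware import (
        AgentMiddleware,
        AgentState,
        AnthropicPromptCachingMiddleware,
        HumanInTheLoopMiddleware,
        ModelRequest,
        SummarizationMiddleware,
    )
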
--- /dev/null
+++ langchain-1.0.0a4/langchain/agents/middleware/_utils.py
@@ -0,0 +1,11 @@
+"""Utility functions for middleware."""
+
+from typing import Any
+
+
+def _generate_correction_tool_messages(content: str, tool_calls: list) -> list[dict[str, Any]]:
+    """Generate tool messages for model behavior correction."""
+    return [
+        {"role": "tool", "content": content, "tool_call_id": tool_call["id"]}
+        for tool_call in tool_calls
+    ]
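
The helper answers every outstanding tool call with the same correction text, so the model sees each call resolved before it is re-prompted. A small illustration with made-up tool calls:

    from langchain.agents.middleware._utils import _generate_correction_tool_messages

    # Hypothetical tool calls, purely for illustration.
    calls = [
        {"name": "search", "args": {"q": "x"}, "id": "call_1"},
        {"name": "fetch", "args": {"url": "y"}, "id": "call_2"},
    ]
    msgs = _generate_correction_tool_messages("Please retry with one tool.", calls)
    # msgs == [
    #     {"role": "tool", "content": "Please retry with one tool.", "tool_call_id": "call_1"},
    #     {"role": "tool", "content": "Please retry with one tool.", "tool_call_id": "call_2"},
    # ]
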
--- /dev/null
+++ langchain-1.0.0a4/langchain/agents/middleware/human_in_the_loop.py
@@ -0,0 +1,128 @@
+"""Human in the loop middleware."""
+
+from typing import Any
+
+from langgraph.prebuilt.interrupt import (
+    ActionRequest,
+    HumanInterrupt,
+    HumanInterruptConfig,
+    HumanResponse,
+)
+from langgraph.types import interrupt
+
+from langchain.agents.middleware._utils import _generate_correction_tool_messages
+from langchain.agents.middleware.types import AgentMiddleware, AgentState
+
+ToolInterruptConfig = dict[str, HumanInterruptConfig]
+
+
+class HumanInTheLoopMiddleware(AgentMiddleware):
+    """Human in the loop middleware."""
+
+    def __init__(
+        self,
+        tool_configs: ToolInterruptConfig,
+        message_prefix: str = "Tool execution requires approval",
+    ) -> None:
+        """Initialize the human in the loop middleware.
+
+        Args:
+            tool_configs: The tool interrupt configs to use for the middleware.
+            message_prefix: The message prefix to use when constructing interrupt content.
+        """
+        super().__init__()
+        self.tool_configs = tool_configs
+        self.message_prefix = message_prefix
+
+    def after_model(self, state: AgentState) -> dict[str, Any] | None:
+        """Trigger HITL flows for relevant tool calls after an AIMessage."""
+        messages = state["messages"]
+        if not messages:
+            return None
+
+        last_message = messages[-1]
+
+        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
+            return None
+
+        # Separate tool calls that need interrupts from those that don't
+        interrupt_tool_calls = []
+        auto_approved_tool_calls = []
+
+        for tool_call in last_message.tool_calls:
+            tool_name = tool_call["name"]
+            if tool_name in self.tool_configs:
+                interrupt_tool_calls.append(tool_call)
+            else:
+                auto_approved_tool_calls.append(tool_call)
+
+        # If no interrupts needed, return early
+        if not interrupt_tool_calls:
+            return None
+
+        approved_tool_calls = auto_approved_tool_calls.copy()
+
+        # Right now, we do not support multiple tool calls with interrupts
+        if len(interrupt_tool_calls) > 1:
+            tool_names = [t["name"] for t in interrupt_tool_calls]
+            msg = f"Called the following tools which require interrupts: {tool_names}\n\nYou may only call ONE tool that requires an interrupt at a time"
+            return {
+                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
+                "jump_to": "model",
+            }
+
+        # Right now, we do not support interrupting a tool call if other tool calls exist
+        if auto_approved_tool_calls:
+            tool_names = [t["name"] for t in interrupt_tool_calls]
+            msg = f"Called the following tools which require interrupts: {tool_names}. You also called other tools that do not require interrupts. If you call a tool that requires an interrupt, you may ONLY call that tool."
+            return {
+                "messages": _generate_correction_tool_messages(msg, last_message.tool_calls),
+                "jump_to": "model",
+            }
+
+        # Only one tool call will need interrupts
+        tool_call = interrupt_tool_calls[0]
+        tool_name = tool_call["name"]
+        tool_args = tool_call["args"]
+        description = f"{self.message_prefix}\n\nTool: {tool_name}\nArgs: {tool_args}"
+        tool_config = self.tool_configs[tool_name]
+
+        request: HumanInterrupt = {
+            "action_request": ActionRequest(
+                action=tool_name,
+                args=tool_args,
+            ),
+            "config": tool_config,
+            "description": description,
+        }
+
+        responses: list[HumanResponse] = interrupt([request])
+        response = responses[0]
+
+        if response["type"] == "accept":
+            approved_tool_calls.append(tool_call)
+        elif response["type"] == "edit":
+            edited: ActionRequest = response["args"]  # type: ignore[assignment]
+            new_tool_call = {
+                "type": "tool_call",
+                "name": tool_call["name"],
+                "args": edited["args"],
+                "id": tool_call["id"],
+            }
+            approved_tool_calls.append(new_tool_call)
+        elif response["type"] == "ignore":
+            return {"jump_to": "__end__"}
+        elif response["type"] == "response":
+            tool_message = {
+                "role": "tool",
+                "tool_call_id": tool_call["id"],
+                "content": response["args"],
+            }
+            return {"messages": [tool_message], "jump_to": "model"}
+        else:
+            msg = f"Unknown response type: {response['type']}"
+            raise ValueError(msg)
+
+        last_message.tool_calls = approved_tool_calls
+
+        return {"messages": [last_message]}
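
Only tools named in `tool_configs` trigger an interrupt; all other calls are auto-approved, and the hook bounces batches that mix the two or contain more than one interruptible call back to the model. A minimal construction sketch (the `allow_*` keys follow langgraph's `HumanInterruptConfig`, and the tool name is illustrative):

    from langchain.agents.middleware import HumanInTheLoopMiddleware
    from langgraph.prebuilt.interrupt import HumanInterruptConfig

    hitl = HumanInTheLoopMiddleware(
        tool_configs={
            # Only "delete_file" pauses for human review; other tools run unchecked.
            "delete_file": HumanInterruptConfig(
                allow_accept=True, allow_edit=True, allow_ignore=True, allow_respond=True
            ),
        },
        message_prefix="Dangerous tool call requires approval",
    )
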
--- /dev/null
+++ langchain-1.0.0a4/langchain/agents/middleware/prompt_caching.py
@@ -0,0 +1,57 @@
+"""Anthropic prompt caching middleware."""
+
+from typing import Literal
+
+from langchain.agents.middleware.types import AgentMiddleware, AgentState, ModelRequest
+
+
+class AnthropicPromptCachingMiddleware(AgentMiddleware):
+    """Prompt caching middleware - optimizes API usage by caching conversation prefixes for Anthropic models.
+
+    Learn more about Anthropic prompt caching [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).
+    """
+
+    def __init__(
+        self,
+        type: Literal["ephemeral"] = "ephemeral",
+        ttl: Literal["5m", "1h"] = "5m",
+        min_messages_to_cache: int = 0,
+    ) -> None:
+        """Initialize the middleware with cache control settings.
+
+        Args:
+            type: The type of cache to use; only "ephemeral" is supported.
+            ttl: The time to live for the cache; only "5m" and "1h" are supported.
+            min_messages_to_cache: The minimum number of messages until the cache is used, default is 0.
+        """
+        self.type = type
+        self.ttl = ttl
+        self.min_messages_to_cache = min_messages_to_cache
+
+    def modify_model_request(self, request: ModelRequest, state: AgentState) -> ModelRequest:  # noqa: ARG002
+        """Modify the model request to add cache control blocks."""
+        try:
+            from langchain_anthropic import ChatAnthropic
+        except ImportError:
+            msg = (
+                "AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models. "
+                "Please install langchain-anthropic."
+            )
+            raise ValueError(msg)
+
+        if not isinstance(request.model, ChatAnthropic):
+            msg = (
+                "AnthropicPromptCachingMiddleware caching middleware only supports Anthropic models, "
+                f"not instances of {type(request.model)}"
+            )
+            raise ValueError(msg)
+
+        messages_count = (
+            len(request.messages) + 1 if request.system_prompt else len(request.messages)
+        )
+        if messages_count < self.min_messages_to_cache:
+            return request
+
+        request.model_settings["cache_control"] = {"type": self.type, "ttl": self.ttl}
+
+        return request
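
Because the hook only annotates `request.model_settings`, it is cheap to stack with other middleware. How a middleware is attached to an agent lives in the new `middleware_agent.py`, which is not shown in this section, so this sketch only constructs the middleware object itself:

    from langchain.agents.middleware import AnthropicPromptCachingMiddleware

    # Cache conversation prefixes for an hour, but only once the
    # conversation (plus system prompt) reaches ten messages.
    caching = AnthropicPromptCachingMiddleware(ttl="1h", min_messages_to_cache=10)
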
--- /dev/null
+++ langchain-1.0.0a4/langchain/agents/middleware/summarization.py
@@ -0,0 +1,248 @@
+"""Summarization middleware."""
+
+import uuid
+from collections.abc import Callable, Iterable
+from typing import Any, cast
+
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    MessageLikeRepresentation,
+    RemoveMessage,
+    ToolMessage,
+)
+from langchain_core.messages.human import HumanMessage
+from langchain_core.messages.utils import count_tokens_approximately, trim_messages
+from langgraph.graph.message import (
+    REMOVE_ALL_MESSAGES,
+)
+
+from langchain.agents.middleware.types import AgentMiddleware, AgentState
+from langchain.chat_models import BaseChatModel, init_chat_model
+
+TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int]
+
+DEFAULT_SUMMARY_PROMPT = """<role>
+Context Extraction Assistant
+</role>
+
+<primary_objective>
+Your sole objective in this task is to extract the highest quality/most relevant context from the conversation history below.
+</primary_objective>
+
+<objective_information>
+You're nearing the total number of input tokens you can accept, so you must extract the highest quality/most relevant pieces of information from your conversation history.
+This context will then overwrite the conversation history presented below. Because of this, ensure the context you extract is only the most important information to your overall goal.
+</objective_information>
+
+<instructions>
+The conversation history below will be replaced with the context you extract in this step. Because of this, you must do your very best to extract and record all of the most important context from the conversation history.
+You want to ensure that you don't repeat any actions you've already completed, so the context you extract from the conversation history should be focused on the most important information to your overall goal.
+</instructions>
+
+The user will message you with the full message history you'll be extracting context from, to then replace. Carefully read over it all, and think deeply about what information is most important to your overall goal that should be saved:
+
+With all of this in mind, please carefully read over the entire conversation history, and extract the most important and relevant context to replace it so that you can free up space in the conversation history.
+Respond ONLY with the extracted context. Do not include any additional information, or text before or after the extracted context.
+
+<messages>
+Messages to summarize:
+{messages}
+</messages>"""
+
+SUMMARY_PREFIX = "## Previous conversation summary:"
+
+_DEFAULT_MESSAGES_TO_KEEP = 20
+_DEFAULT_TRIM_TOKEN_LIMIT = 4000
+_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
+_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
+
+
+class SummarizationMiddleware(AgentMiddleware):
+    """Middleware that summarizes conversation history when token limits are approached.
+
+    This middleware monitors message token counts and automatically summarizes older
+    messages when a threshold is reached, preserving recent messages and maintaining
+    context continuity by ensuring AI/Tool message pairs remain together.
+    """
+
+    def __init__(
+        self,
+        model: str | BaseChatModel,
+        max_tokens_before_summary: int | None = None,
+        messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
+        token_counter: TokenCounter = count_tokens_approximately,
+        summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
+        summary_prefix: str = SUMMARY_PREFIX,
+    ) -> None:
+        """Initialize the summarization middleware.
+
+        Args:
+            model: The language model to use for generating summaries.
+            max_tokens_before_summary: Token threshold to trigger summarization.
+                If None, summarization is disabled.
+            messages_to_keep: Number of recent messages to preserve after summarization.
+            token_counter: Function to count tokens in messages.
+            summary_prompt: Prompt template for generating summaries.
+            summary_prefix: Prefix added to system message when including summary.
+        """
+        super().__init__()
+
+        if isinstance(model, str):
+            model = init_chat_model(model)
+
+        self.model = model
+        self.max_tokens_before_summary = max_tokens_before_summary
+        self.messages_to_keep = messages_to_keep
+        self.token_counter = token_counter
+        self.summary_prompt = summary_prompt
+        self.summary_prefix = summary_prefix
+
+    def before_model(self, state: AgentState) -> dict[str, Any] | None:
+        """Process messages before model invocation, potentially triggering summarization."""
+        messages = state["messages"]
+        self._ensure_message_ids(messages)
+
+        total_tokens = self.token_counter(messages)
+        if (
+            self.max_tokens_before_summary is not None
+            and total_tokens < self.max_tokens_before_summary
+        ):
+            return None
+
+        cutoff_index = self._find_safe_cutoff(messages)
+
+        if cutoff_index <= 0:
+            return None
+
+        messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
+
+        summary = self._create_summary(messages_to_summarize)
+        new_messages = self._build_new_messages(summary)
+
+        return {
+            "messages": [
+                RemoveMessage(id=REMOVE_ALL_MESSAGES),
+                *new_messages,
+                *preserved_messages,
+            ]
+        }
+
+    def _build_new_messages(self, summary: str) -> list[HumanMessage]:
+        return [
+            HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
+        ]
+
+    def _ensure_message_ids(self, messages: list[AnyMessage]) -> None:
+        """Ensure all messages have unique IDs for the add_messages reducer."""
+        for msg in messages:
+            if msg.id is None:
+                msg.id = str(uuid.uuid4())
+
+    def _partition_messages(
+        self,
+        conversation_messages: list[AnyMessage],
+        cutoff_index: int,
+    ) -> tuple[list[AnyMessage], list[AnyMessage]]:
+        """Partition messages into those to summarize and those to preserve."""
+        messages_to_summarize = conversation_messages[:cutoff_index]
+        preserved_messages = conversation_messages[cutoff_index:]
+
+        return messages_to_summarize, preserved_messages
+
+    def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
+        """Find safe cutoff point that preserves AI/Tool message pairs.
+
+        Returns the index where messages can be safely cut without separating
+        related AI and Tool messages. Returns 0 if no safe cutoff is found.
+        """
+        if len(messages) <= self.messages_to_keep:
+            return 0
+
+        target_cutoff = len(messages) - self.messages_to_keep
+
+        for i in range(target_cutoff, -1, -1):
+            if self._is_safe_cutoff_point(messages, i):
+                return i
+
+        return 0
+
+    def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool:
+        """Check if cutting at index would separate AI/Tool message pairs."""
+        if cutoff_index >= len(messages):
+            return True
+
+        search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS)
+        search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS)
+
+        for i in range(search_start, search_end):
+            if not self._has_tool_calls(messages[i]):
+                continue
+
+            tool_call_ids = self._extract_tool_call_ids(cast("AIMessage", messages[i]))
+            if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids):
+                return False
+
+        return True
+
+    def _has_tool_calls(self, message: AnyMessage) -> bool:
+        """Check if message is an AI message with tool calls."""
+        return (
+            isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls  # type: ignore[return-value]
+        )
+
+    def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]:
+        """Extract tool call IDs from an AI message."""
+        tool_call_ids = set()
+        for tc in ai_message.tool_calls:
+            call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
+            if call_id is not None:
+                tool_call_ids.add(call_id)
+        return tool_call_ids
+
+    def _cutoff_separates_tool_pair(
+        self,
+        messages: list[AnyMessage],
+        ai_message_index: int,
+        cutoff_index: int,
+        tool_call_ids: set[str],
+    ) -> bool:
+        """Check if cutoff separates an AI message from its corresponding tool messages."""
+        for j in range(ai_message_index + 1, len(messages)):
+            message = messages[j]
+            if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids:
+                ai_before_cutoff = ai_message_index < cutoff_index
+                tool_before_cutoff = j < cutoff_index
+                if ai_before_cutoff != tool_before_cutoff:
+                    return True
+        return False
+
+    def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
+        """Generate summary for the given messages."""
+        if not messages_to_summarize:
+            return "No previous conversation history."
+
+        trimmed_messages = self._trim_messages_for_summary(messages_to_summarize)
+        if not trimmed_messages:
+            return "Previous conversation was too long to summarize."
+
+        try:
+            response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
+            return cast("str", response.content).strip()
+        except Exception as e:  # noqa: BLE001
+            return f"Error generating summary: {e!s}"
+
+    def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
+        """Trim messages to fit within summary generation limits."""
+        try:
+            return trim_messages(
+                messages,
+                max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
+                token_counter=self.token_counter,
+                start_on="human",
+                strategy="last",
+                allow_partial=True,
+                include_system=True,
+            )
+        except Exception:  # noqa: BLE001
+            return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]
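
Summarization triggers only when `max_tokens_before_summary` is set and reached, and the cutoff search walks backward from `len(messages) - messages_to_keep` so an AIMessage is never separated from its ToolMessages. A hedged, self-contained sketch that forces a trigger with a tiny threshold and uses langchain-core's `FakeListChatModel` as a stand-in summarizer (neither the threshold nor the fake model come from this diff):

    from langchain_core.language_models import FakeListChatModel
    from langchain_core.messages import AIMessage, HumanMessage

    from langchain.agents.middleware import SummarizationMiddleware

    summarizer = SummarizationMiddleware(
        model=FakeListChatModel(responses=["- user asked about X; agent answered Y"]),
        max_tokens_before_summary=1,  # tiny threshold so the sketch always triggers
        messages_to_keep=2,           # keep the two most recent messages verbatim
    )
    history = [
        HumanMessage("hi"),
        AIMessage("hello"),
        HumanMessage("question"),
        AIMessage("answer"),
    ]
    update = summarizer.before_model({"messages": history})
    # update["messages"] holds RemoveMessage(REMOVE_ALL_MESSAGES), the summary
    # HumanMessage, then the two preserved tail messages; below the threshold,
    # before_model would return None instead.
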
--- /dev/null
+++ langchain-1.0.0a4/langchain/agents/middleware/types.py
@@ -0,0 +1,78 @@
+"""Types for middleware and agents."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, cast
+
+# needed as top level import for pydantic schema generation on AgentState
+from langchain_core.messages import AnyMessage  # noqa: TC002
+from langgraph.channels.ephemeral_value import EphemeralValue
+from langgraph.graph.message import Messages, add_messages
+from typing_extensions import NotRequired, Required, TypedDict, TypeVar
+
+if TYPE_CHECKING:
+    from langchain_core.language_models.chat_models import BaseChatModel
+    from langchain_core.tools import BaseTool
+
+    from langchain.agents.structured_output import ResponseFormat
+
+JumpTo = Literal["tools", "model", "__end__"]
+"""Destination to jump to when a middleware node returns."""
+
+ResponseT = TypeVar("ResponseT")
+
+
+@dataclass
+class ModelRequest:
+    """Model request information for the agent."""
+
+    model: BaseChatModel
+    system_prompt: str | None
+    messages: list[AnyMessage]  # excluding system prompt
+    tool_choice: Any | None
+    tools: list[BaseTool]
+    response_format: ResponseFormat | None
+    model_settings: dict[str, Any] = field(default_factory=dict)
+
+
+class AgentState(TypedDict, Generic[ResponseT]):
+    """State schema for the agent."""
+
+    messages: Required[Annotated[list[AnyMessage], add_messages]]
+    model_request: NotRequired[Annotated[ModelRequest | None, EphemeralValue]]
+    jump_to: NotRequired[Annotated[JumpTo | None, EphemeralValue]]
+    response: NotRequired[ResponseT]
+
+
+class PublicAgentState(TypedDict, Generic[ResponseT]):
+    """Input / output schema for the agent."""
+
+    messages: Required[Messages]
+    response: NotRequired[ResponseT]
+
+
+StateT = TypeVar("StateT", bound=AgentState)
+
+
+class AgentMiddleware(Generic[StateT]):
+    """Base middleware class for an agent.
+
+    Subclass this and implement any of the defined methods to customize agent behavior
+    between steps in the main agent loop.
+    """
+
+    state_schema: type[StateT] = cast("type[StateT]", AgentState)
+    """The schema for state passed to the middleware nodes."""
+
+    tools: list[BaseTool]
+    """Additional tools registered by the middleware."""
+
+    def before_model(self, state: StateT) -> dict[str, Any] | None:
+        """Logic to run before the model is called."""
+
+    def modify_model_request(self, request: ModelRequest, state: StateT) -> ModelRequest:  # noqa: ARG002
+        """Logic to modify request kwargs before the model is called."""
+        return request
+
+    def after_model(self, state: StateT) -> dict[str, Any] | None:
+        """Logic to run after the model is called."""