dao-ai 0.0.25-py3-none-any.whl → 0.1.2-py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +5 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1863 -338
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -228
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +261 -166
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +645 -172
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -295
- dao_ai/tools/mcp.py +220 -133
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +360 -40
- dao_ai/utils.py +218 -16
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.25.dist-info/METADATA +0 -1165
- dao_ai-0.0.25.dist-info/RECORD +0 -41
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/middleware/summarization.py (new file)

@@ -0,0 +1,197 @@
+"""
+Summarization middleware for DAO AI agents.
+
+This module provides a LoggingSummarizationMiddleware that extends LangChain's
+built-in SummarizationMiddleware with logging capabilities, and provides
+helper utilities for creating summarization middleware from DAO AI configuration.
+
+The middleware automatically:
+- Summarizes older messages using a separate LLM call when thresholds are exceeded
+- Replaces them with a summary message in State (permanently)
+- Keeps recent messages intact for context
+- Logs when summarization is triggered and completed
+
+Example:
+    from dao_ai.middleware import create_summarization_middleware
+    from dao_ai.config import ChatHistoryModel, LLMModel
+
+    chat_history = ChatHistoryModel(
+        model=LLMModel(name="gpt-4o-mini"),
+        max_tokens=256,
+        max_tokens_before_summary=4000,
+    )
+
+    middleware = create_summarization_middleware(chat_history)
+"""
+
+from typing import Any, Tuple
+
+from langchain.agents.middleware import SummarizationMiddleware
+from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import BaseMessage
+from langgraph.runtime import Runtime
+from loguru import logger
+
+from dao_ai.config import ChatHistoryModel
+
+__all__ = [
+    "SummarizationMiddleware",
+    "LoggingSummarizationMiddleware",
+    "create_summarization_middleware",
+]
+
+
+class LoggingSummarizationMiddleware(SummarizationMiddleware):
+    """
+    SummarizationMiddleware with logging for when summarization occurs.
+
+    This extends LangChain's SummarizationMiddleware to add logging at INFO level
+    when summarization is triggered and completed, providing visibility into
+    when conversation history is being summarized.
+
+    Logs include:
+    - Original message count and approximate token count (before summarization)
+    - New message count and approximate token count (after summarization)
+    - Number of messages that were summarized
+    """
+
+    def _log_summarization(
+        self,
+        original_message_count: int,
+        original_token_count: int,
+        result_messages: list[Any],
+    ) -> None:
+        """Log summarization details with before/after metrics."""
+        # Result messages: [RemoveMessage, summary_message, ...preserved_messages]
+        # New message count excludes RemoveMessage (index 0)
+        new_messages = [
+            msg for msg in result_messages if not self._is_remove_message(msg)
+        ]
+        new_message_count = len(new_messages)
+        new_token_count = self.token_counter(new_messages) if new_messages else 0
+
+        # Calculate how many messages were summarized
+        # preserved = new_messages - 1 (the summary message)
+        preserved_count = max(0, new_message_count - 1)
+        summarized_count = original_message_count - preserved_count
+
+        logger.info(
+            "Conversation summarized",
+            before_messages=original_message_count,
+            before_tokens=original_token_count,
+            after_messages=new_message_count,
+            after_tokens=new_token_count,
+            summarized_messages=summarized_count,
+        )
+        logger.debug(
+            "Summarization details",
+            trigger=self.trigger,
+            keep=self.keep,
+            preserved_messages=preserved_count,
+            token_reduction=original_token_count - new_token_count,
+        )
+
+    def _is_remove_message(self, msg: Any) -> bool:
+        """Check if a message is a RemoveMessage."""
+        return type(msg).__name__ == "RemoveMessage"
+
+    def before_model(
+        self, state: dict[str, Any], runtime: Runtime
+    ) -> dict[str, Any] | None:
+        """Process messages before model invocation, logging when summarization occurs."""
+        messages: list[BaseMessage] = state.get("messages", [])
+        original_message_count = len(messages)
+        original_token_count = self.token_counter(messages) if messages else 0
+
+        result = super().before_model(state, runtime)
+
+        if result is not None:
+            result_messages = result.get("messages", [])
+            self._log_summarization(
+                original_message_count,
+                original_token_count,
+                result_messages,
+            )
+
+        return result
+
+    async def abefore_model(
+        self, state: dict[str, Any], runtime: Runtime
+    ) -> dict[str, Any] | None:
+        """Process messages before model invocation (async), logging when summarization occurs."""
+        messages: list[BaseMessage] = state.get("messages", [])
+        original_message_count = len(messages)
+        original_token_count = self.token_counter(messages) if messages else 0
+
+        result = await super().abefore_model(state, runtime)
+
+        if result is not None:
+            result_messages = result.get("messages", [])
+            self._log_summarization(
+                original_message_count,
+                original_token_count,
+                result_messages,
+            )
+
+        return result
+
+
+def create_summarization_middleware(
+    chat_history: ChatHistoryModel,
+) -> LoggingSummarizationMiddleware:
+    """
+    Create a LoggingSummarizationMiddleware from DAO AI ChatHistoryModel configuration.
+
+    This factory function creates a LoggingSummarizationMiddleware instance
+    configured according to the DAO AI ChatHistoryModel settings. The middleware
+    includes logging at INFO level when summarization is triggered.
+
+    Args:
+        chat_history: ChatHistoryModel configuration for summarization
+
+    Returns:
+        LoggingSummarizationMiddleware configured with the specified parameters
+
+    Example:
+        from dao_ai.config import ChatHistoryModel, LLMModel
+
+        chat_history = ChatHistoryModel(
+            model=LLMModel(name="gpt-4o-mini"),
+            max_tokens=256,
+            max_tokens_before_summary=4000,
+        )
+
+        middleware = create_summarization_middleware(chat_history)
+    """
+    logger.debug(
+        "Creating summarization middleware",
+        max_tokens=chat_history.max_tokens,
+        max_tokens_before_summary=chat_history.max_tokens_before_summary,
+        max_messages_before_summary=chat_history.max_messages_before_summary,
+    )
+
+    # Get the LLM model
+    model: LanguageModelLike = chat_history.model.as_chat_model()
+
+    # Determine trigger condition
+    # LangChain uses ("tokens", value) or ("messages", value) tuples
+    trigger: Tuple[str, int]
+    if chat_history.max_tokens_before_summary:
+        trigger = ("tokens", chat_history.max_tokens_before_summary)
+    elif chat_history.max_messages_before_summary:
+        trigger = ("messages", chat_history.max_messages_before_summary)
+    else:
+        # Default to a reasonable token threshold
+        trigger = ("tokens", chat_history.max_tokens * 10)
+
+    # Determine keep condition - how many recent messages/tokens to preserve
+    # Default to keeping enough for context
+    keep: Tuple[str, int] = ("tokens", chat_history.max_tokens)
+
+    logger.info("Summarization middleware configured", trigger=trigger, keep=keep)
+
+    return LoggingSummarizationMiddleware(
+        model=model,
+        trigger=trigger,
+        keep=keep,
+    )