dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +342 -58
- dao_ai/config.py +1610 -380
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +158 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +233 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +240 -161
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +279 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +584 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai/vector_search.py +37 -0
- dao_ai-0.1.5.dist-info/METADATA +489 -0
- dao_ai-0.1.5.dist-info/RECORD +70 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
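
Most of this release is the new dao_ai.middleware package, which appears to replace the removed dao_ai/guardrails.py and dao_ai/tools/human_in_the_loop.py with middleware equivalents. The factory modules shown below each return a LangChain middleware instance, so they compose into a single list. A minimal sketch of that composition, assuming a LangChain 1.x create_agent that accepts a middleware parameter (consistent with the langchain.agents.middleware imports below); the model id and empty tool list are placeholders:

# Sketch: composing the new dao_ai.middleware factories into one agent.
# Assumes langchain >= 1.0, where create_agent takes a `middleware`
# parameter; the model id and empty tool list are placeholders.
from langchain.agents import create_agent

from dao_ai.middleware import (
    create_model_retry_middleware,
    create_pii_middleware,
    create_tool_call_limit_middleware,
)

agent = create_agent(
    model="openai:gpt-4o-mini",  # placeholder model id
    tools=[],                    # your tools here
    middleware=[
        create_model_retry_middleware(max_retries=3, jitter=True),
        create_pii_middleware(pii_type="email", strategy="redact"),
        create_tool_call_limit_middleware(run_limit=10),
    ],
)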
dao_ai/middleware/model_retry.py
ADDED

@@ -0,0 +1,121 @@
"""
Model retry middleware for DAO AI agents.

Automatically retries failed model (LLM) calls with configurable exponential backoff.

Example:
    from dao_ai.middleware import create_model_retry_middleware

    # Retry failed model calls with exponential backoff
    middleware = create_model_retry_middleware(
        max_retries=3,
        backoff_factor=2.0,
        initial_delay=1.0,
    )
"""

from __future__ import annotations

from typing import Any, Callable, Literal

from langchain.agents.middleware import ModelRetryMiddleware
from loguru import logger

__all__ = [
    "ModelRetryMiddleware",
    "create_model_retry_middleware",
]


def create_model_retry_middleware(
    max_retries: int = 3,
    backoff_factor: float = 2.0,
    initial_delay: float = 1.0,
    max_delay: float | None = None,
    jitter: bool = False,
    retry_on: tuple[type[Exception], ...] | Callable[[Exception], bool] | None = None,
    on_failure: Literal["continue", "error"] | Callable[[Exception], str] = "continue",
) -> ModelRetryMiddleware:
    """
    Create a ModelRetryMiddleware for automatic model call retries.

    Handles transient failures in model API calls with exponential backoff.
    Useful for handling rate limits, network issues, and temporary outages.

    Args:
        max_retries: Max retry attempts after the initial call. Default 3.
        backoff_factor: Multiplier for exponential backoff. Default 2.0.
            Delay = initial_delay * (backoff_factor ** retry_number)
            Set to 0.0 for a constant delay.
        initial_delay: Initial delay in seconds before the first retry. Default 1.0.
        max_delay: Max delay in seconds (caps exponential growth). None = no cap.
        jitter: Add ±25% random jitter to avoid thundering herd. Default False.
        retry_on: When to retry:
            - None: Retry on all errors (default)
            - tuple of Exception types: Retry only on these
            - callable: Function(exception) -> bool for custom logic
        on_failure: Behavior when all retries are exhausted:
            - "continue": Return an AIMessage with the error, let the agent continue (default)
            - "error": Re-raise the exception, stop execution
            - callable: Function(exception) -> str for a custom error message

    Returns:
        A configured ModelRetryMiddleware instance

    Example:
        # Basic retry with defaults
        retry = create_model_retry_middleware()

        # Custom backoff for rate limits
        retry = create_model_retry_middleware(
            max_retries=5,
            backoff_factor=2.0,
            initial_delay=1.0,
            max_delay=60.0,
            jitter=True,
        )

        # Retry only on specific exceptions, fail hard
        retry = create_model_retry_middleware(
            max_retries=3,
            retry_on=(RateLimitError, TimeoutError),
            on_failure="error",
        )

        # Custom retry logic
        def should_retry(error: Exception) -> bool:
            return "rate_limit" in str(error).lower()

        retry = create_model_retry_middleware(
            max_retries=5,
            retry_on=should_retry,
        )
    """
    logger.debug(
        "Creating model retry middleware",
        max_retries=max_retries,
        backoff_factor=backoff_factor,
        initial_delay=initial_delay,
        max_delay=max_delay,
        jitter=jitter,
        on_failure=on_failure if isinstance(on_failure, str) else "custom",
    )

    # Build kwargs
    kwargs: dict[str, Any] = {
        "max_retries": max_retries,
        "backoff_factor": backoff_factor,
        "initial_delay": initial_delay,
        "on_failure": on_failure,
    }

    if max_delay is not None:
        kwargs["max_delay"] = max_delay

    if jitter:
        kwargs["jitter"] = jitter

    if retry_on is not None:
        kwargs["retry_on"] = retry_on

    return ModelRetryMiddleware(**kwargs)
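
The documented delay formula is easy to sanity-check offline. A small, dependency-free sketch that computes the retry schedule implied by a configuration, mirroring the docstring's formula (delay = initial_delay * backoff_factor ** retry_number, capped at max_delay, with ±25% jitter) rather than the middleware internals:

# Sketch: the retry schedule implied by the documented backoff formula.
# Mirrors the docstring, not the middleware's implementation.
import random


def retry_delays(
    max_retries: int = 5,
    backoff_factor: float = 2.0,
    initial_delay: float = 1.0,
    max_delay: float | None = 60.0,
    jitter: bool = True,
) -> list[float]:
    delays = []
    for retry_number in range(max_retries):
        delay = initial_delay * backoff_factor**retry_number
        if max_delay is not None:
            delay = min(delay, max_delay)  # cap exponential growth
        if jitter:
            delay *= random.uniform(0.75, 1.25)  # +/-25% jitter band
        delays.append(delay)
    return delays


print(retry_delays(jitter=False))  # [1.0, 2.0, 4.0, 8.0, 16.0]

With backoff_factor=2.0 the uncapped delays double each attempt, so max_delay matters only once initial_delay * 2 ** retry_number exceeds the cap.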
dao_ai/middleware/pii.py
ADDED
@@ -0,0 +1,157 @@
"""
PII detection middleware for DAO AI agents.

Detects and handles Personally Identifiable Information (PII) in conversations
using configurable strategies (redact, mask, hash, block).

Example:
    from dao_ai.middleware import create_pii_middleware

    # Redact emails in user input
    middleware = create_pii_middleware(
        pii_type="email",
        strategy="redact",
        apply_to_input=True,
    )
"""

from __future__ import annotations

from typing import Any, Callable, Literal, Pattern

from langchain.agents.middleware import PIIMiddleware
from loguru import logger

__all__ = [
    "PIIMiddleware",
    "create_pii_middleware",
]

# Type alias for PII detectors
PIIDetector = str | Pattern[str] | Callable[[str], list[dict[str, str | int]]]

# Built-in PII types
BUILTIN_PII_TYPES = frozenset({"email", "credit_card", "ip", "mac_address", "url"})


def create_pii_middleware(
    pii_type: str,
    strategy: Literal["redact", "mask", "hash", "block"] = "redact",
    detector: PIIDetector | None = None,
    apply_to_input: bool = True,
    apply_to_output: bool = False,
    apply_to_tool_results: bool = False,
) -> PIIMiddleware:
    """
    Create a PIIMiddleware for detecting and handling PII.

    Detects Personally Identifiable Information in conversations and handles
    it according to the specified strategy. Useful for compliance, privacy,
    and sanitizing logs.

    Built-in PII types:
        - email: Email addresses
        - credit_card: Credit card numbers (Luhn validated)
        - ip: IP addresses
        - mac_address: MAC addresses
        - url: URLs

    Args:
        pii_type: Type of PII to detect. Use built-in types (email, credit_card,
            ip, mac_address, url) or custom type names with a detector.
        strategy: How to handle detected PII:
            - "redact": Replace with [REDACTED_{TYPE}] (default)
            - "mask": Partially obscure (e.g., ****-****-****-1234)
            - "hash": Replace with a deterministic hash
            - "block": Raise an exception when detected
        detector: Custom detector for non-built-in types. Can be:
            - str: Regex pattern string
            - re.Pattern: Compiled regex pattern
            - Callable: Function(content: str) -> list[dict] with keys:
                - text: The matched text
                - start: Start index
                - end: End index
            Default None (uses the built-in detector for built-in types).
        apply_to_input: Check user messages before the model call. Default True.
        apply_to_output: Check AI messages after the model call. Default False.
        apply_to_tool_results: Check tool results after execution. Default False.

    Returns:
        A configured PIIMiddleware instance

    Raises:
        ValueError: If a custom pii_type is given without a detector, or the
            strategy is invalid

    Example:
        # Redact emails in input
        email_redactor = create_pii_middleware(
            pii_type="email",
            strategy="redact",
            apply_to_input=True,
        )

        # Mask credit cards
        card_masker = create_pii_middleware(
            pii_type="credit_card",
            strategy="mask",
            apply_to_input=True,
            apply_to_output=True,
        )

        # Block API keys with a custom regex
        api_key_blocker = create_pii_middleware(
            pii_type="api_key",
            detector=r"sk-[a-zA-Z0-9]{32}",
            strategy="block",
        )

        # Custom SSN detector with validation
        import re

        def detect_ssn(content: str) -> list[dict]:
            matches = []
            pattern = r"\d{3}-\d{2}-\d{4}"
            for match in re.finditer(pattern, content):
                ssn = match.group(0)
                first_three = int(ssn[:3])
                if first_three not in [0, 666] and not (900 <= first_three <= 999):
                    matches.append({
                        "text": ssn,
                        "start": match.start(),
                        "end": match.end(),
                    })
            return matches

        ssn_hasher = create_pii_middleware(
            pii_type="ssn",
            detector=detect_ssn,
            strategy="hash",
        )
    """
    # Validate: custom types require a detector
    if pii_type not in BUILTIN_PII_TYPES and detector is None:
        raise ValueError(
            f"Custom PII type '{pii_type}' requires a detector. "
            f"Built-in types are: {', '.join(sorted(BUILTIN_PII_TYPES))}"
        )

    logger.debug(
        "Creating PII middleware",
        pii_type=pii_type,
        strategy=strategy,
        has_custom_detector=detector is not None,
        apply_to_input=apply_to_input,
        apply_to_output=apply_to_output,
        apply_to_tool_results=apply_to_tool_results,
    )

    # Build kwargs
    kwargs: dict[str, Any] = {
        "strategy": strategy,
        "apply_to_input": apply_to_input,
        "apply_to_output": apply_to_output,
        "apply_to_tool_results": apply_to_tool_results,
    }

    if detector is not None:
        kwargs["detector"] = detector

    return PIIMiddleware(pii_type, **kwargs)
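
The callable-detector contract (return a list of dicts with text, start, and end keys) keeps custom detectors short. A minimal sketch for a hypothetical "us_phone" type; the type name and pattern are illustrative, not part of the package:

# Sketch: a callable detector for a hypothetical "us_phone" PII type,
# following the documented contract: return [{"text", "start", "end"}].
import re

from dao_ai.middleware import create_pii_middleware

PHONE_RE = re.compile(r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")


def detect_us_phone(content: str) -> list[dict]:
    return [
        {"text": m.group(0), "start": m.start(), "end": m.end()}
        for m in PHONE_RE.finditer(content)
    ]


phone_redactor = create_pii_middleware(
    pii_type="us_phone",      # custom type, so a detector is required
    detector=detect_us_phone,
    strategy="redact",
    apply_to_output=True,     # also scrub model responses
)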
dao_ai/middleware/summarization.py
ADDED

@@ -0,0 +1,197 @@
"""
Summarization middleware for DAO AI agents.

This module provides a LoggingSummarizationMiddleware that extends LangChain's
built-in SummarizationMiddleware with logging capabilities, and provides
helper utilities for creating summarization middleware from DAO AI configuration.

The middleware automatically:
- Summarizes older messages using a separate LLM call when thresholds are exceeded
- Replaces them with a summary message in State (permanently)
- Keeps recent messages intact for context
- Logs when summarization is triggered and completed

Example:
    from dao_ai.middleware import create_summarization_middleware
    from dao_ai.config import ChatHistoryModel, LLMModel

    chat_history = ChatHistoryModel(
        model=LLMModel(name="gpt-4o-mini"),
        max_tokens=256,
        max_tokens_before_summary=4000,
    )

    middleware = create_summarization_middleware(chat_history)
"""

from typing import Any, Tuple

from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import BaseMessage
from langgraph.runtime import Runtime
from loguru import logger

from dao_ai.config import ChatHistoryModel

__all__ = [
    "SummarizationMiddleware",
    "LoggingSummarizationMiddleware",
    "create_summarization_middleware",
]


class LoggingSummarizationMiddleware(SummarizationMiddleware):
    """
    SummarizationMiddleware with logging for when summarization occurs.

    This extends LangChain's SummarizationMiddleware to add logging at INFO level
    when summarization is triggered and completed, providing visibility into
    when conversation history is being summarized.

    Logs include:
    - Original message count and approximate token count (before summarization)
    - New message count and approximate token count (after summarization)
    - Number of messages that were summarized
    """

    def _log_summarization(
        self,
        original_message_count: int,
        original_token_count: int,
        result_messages: list[Any],
    ) -> None:
        """Log summarization details with before/after metrics."""
        # Result messages: [RemoveMessage, summary_message, ...preserved_messages]
        # New message count excludes RemoveMessage (index 0)
        new_messages = [
            msg for msg in result_messages if not self._is_remove_message(msg)
        ]
        new_message_count = len(new_messages)
        new_token_count = self.token_counter(new_messages) if new_messages else 0

        # Calculate how many messages were summarized
        # preserved = new_messages - 1 (the summary message)
        preserved_count = max(0, new_message_count - 1)
        summarized_count = original_message_count - preserved_count

        logger.info(
            "Conversation summarized",
            before_messages=original_message_count,
            before_tokens=original_token_count,
            after_messages=new_message_count,
            after_tokens=new_token_count,
            summarized_messages=summarized_count,
        )
        logger.debug(
            "Summarization details",
            trigger=self.trigger,
            keep=self.keep,
            preserved_messages=preserved_count,
            token_reduction=original_token_count - new_token_count,
        )

    def _is_remove_message(self, msg: Any) -> bool:
        """Check if a message is a RemoveMessage."""
        return type(msg).__name__ == "RemoveMessage"

    def before_model(
        self, state: dict[str, Any], runtime: Runtime
    ) -> dict[str, Any] | None:
        """Process messages before model invocation, logging when summarization occurs."""
        messages: list[BaseMessage] = state.get("messages", [])
        original_message_count = len(messages)
        original_token_count = self.token_counter(messages) if messages else 0

        result = super().before_model(state, runtime)

        if result is not None:
            result_messages = result.get("messages", [])
            self._log_summarization(
                original_message_count,
                original_token_count,
                result_messages,
            )

        return result

    async def abefore_model(
        self, state: dict[str, Any], runtime: Runtime
    ) -> dict[str, Any] | None:
        """Process messages before model invocation (async), logging when summarization occurs."""
        messages: list[BaseMessage] = state.get("messages", [])
        original_message_count = len(messages)
        original_token_count = self.token_counter(messages) if messages else 0

        result = await super().abefore_model(state, runtime)

        if result is not None:
            result_messages = result.get("messages", [])
            self._log_summarization(
                original_message_count,
                original_token_count,
                result_messages,
            )

        return result


def create_summarization_middleware(
    chat_history: ChatHistoryModel,
) -> LoggingSummarizationMiddleware:
    """
    Create a LoggingSummarizationMiddleware from DAO AI ChatHistoryModel configuration.

    This factory function creates a LoggingSummarizationMiddleware instance
    configured according to the DAO AI ChatHistoryModel settings. The middleware
    includes logging at INFO level when summarization is triggered.

    Args:
        chat_history: ChatHistoryModel configuration for summarization

    Returns:
        A LoggingSummarizationMiddleware configured with the specified parameters

    Example:
        from dao_ai.config import ChatHistoryModel, LLMModel

        chat_history = ChatHistoryModel(
            model=LLMModel(name="gpt-4o-mini"),
            max_tokens=256,
            max_tokens_before_summary=4000,
        )

        middleware = create_summarization_middleware(chat_history)
    """
    logger.debug(
        "Creating summarization middleware",
        max_tokens=chat_history.max_tokens,
        max_tokens_before_summary=chat_history.max_tokens_before_summary,
        max_messages_before_summary=chat_history.max_messages_before_summary,
    )

    # Get the LLM model
    model: LanguageModelLike = chat_history.model.as_chat_model()

    # Determine trigger condition
    # LangChain uses ("tokens", value) or ("messages", value) tuples
    trigger: Tuple[str, int]
    if chat_history.max_tokens_before_summary:
        trigger = ("tokens", chat_history.max_tokens_before_summary)
    elif chat_history.max_messages_before_summary:
        trigger = ("messages", chat_history.max_messages_before_summary)
    else:
        # Default to a reasonable token threshold
        trigger = ("tokens", chat_history.max_tokens * 10)

    # Determine keep condition - how many recent messages/tokens to preserve
    # Default to keeping enough for context
    keep: Tuple[str, int] = ("tokens", chat_history.max_tokens)

    logger.info("Summarization middleware configured", trigger=trigger, keep=keep)

    return LoggingSummarizationMiddleware(
        model=model,
        trigger=trigger,
        keep=keep,
    )
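
The trigger/keep resolution in create_summarization_middleware has one non-obvious fallback: a token threshold wins over a message threshold, and with neither configured it triggers at ten times max_tokens. A small standalone sketch restating that decision as a pure function (mirroring the factory logic above, not calling into dao_ai):

# Sketch: the trigger/keep resolution performed by
# create_summarization_middleware, restated as a pure function so the
# fallbacks are easy to trace.
def resolve_thresholds(
    max_tokens: int,
    max_tokens_before_summary: int | None = None,
    max_messages_before_summary: int | None = None,
) -> tuple[tuple[str, int], tuple[str, int]]:
    if max_tokens_before_summary:
        trigger = ("tokens", max_tokens_before_summary)
    elif max_messages_before_summary:
        trigger = ("messages", max_messages_before_summary)
    else:
        trigger = ("tokens", max_tokens * 10)  # default threshold
    keep = ("tokens", max_tokens)  # always keep a max_tokens-sized tail
    return trigger, keep


assert resolve_thresholds(256, 4000) == (("tokens", 4000), ("tokens", 256))
assert resolve_thresholds(256) == (("tokens", 2560), ("tokens", 256))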
dao_ai/middleware/tool_call_limit.py
ADDED

@@ -0,0 +1,210 @@
"""
Tool call limit middleware for DAO AI agents.

This module provides a factory for creating LangChain's ToolCallLimitMiddleware
from DAO AI configuration.

Example:
    from dao_ai.middleware import create_tool_call_limit_middleware

    # Global limit across all tools
    middleware = create_tool_call_limit_middleware(
        thread_limit=20,
        run_limit=10,
    )

    # Limit specific tool by name
    search_limiter = create_tool_call_limit_middleware(
        tool="search_web",
        run_limit=3,
        exit_behavior="continue",
    )
"""

from __future__ import annotations

from typing import Any, Literal

from langchain.agents.middleware import ToolCallLimitMiddleware
from langchain_core.tools import BaseTool
from loguru import logger

from dao_ai.config import BaseFunctionModel, ToolModel

__all__ = [
    "ToolCallLimitMiddleware",
    "create_tool_call_limit_middleware",
]


def _resolve_tool(tool: str | ToolModel | dict[str, Any]) -> list[str]:
    """
    Resolve a tool argument to a list of actual tool names.

    Args:
        tool: String name, ToolModel, or dict to resolve

    Returns:
        List of tool name strings

    Raises:
        ValueError: If the dict cannot be converted to a ToolModel
        TypeError: If tool is not a supported type
    """
    # String: return as single-item list
    if isinstance(tool, str):
        return [tool]

    # Dict: convert to ToolModel first
    if isinstance(tool, dict):
        try:
            tool_model = ToolModel(**tool)
        except Exception as e:
            raise ValueError(
                f"Failed to construct ToolModel from dict: {e}\n"
                "Dict must have 'name' and 'function' keys."
            ) from e
    elif isinstance(tool, ToolModel):
        tool_model = tool
    else:
        raise TypeError(
            f"tool must be str, ToolModel, or dict, got {type(tool).__name__}"
        )

    # Extract tool names from ToolModel
    return _extract_tool_names(tool_model)


def _extract_tool_names(tool_model: ToolModel) -> list[str]:
    """
    Extract actual tool names from a ToolModel.

    A single ToolModel can produce multiple tools (e.g., UC functions).
    Falls back to ToolModel.name if extraction fails.
    """
    function = tool_model.function

    # String function references can't be introspected
    if not isinstance(function, BaseFunctionModel):
        logger.debug(
            "Cannot extract names from string function, using ToolModel.name",
            tool_model_name=tool_model.name,
        )
        return [tool_model.name]

    # Try to extract names from created tools
    try:
        tool_names = [
            tool.name
            for tool in function.as_tools()
            if isinstance(tool, BaseTool) and tool.name
        ]
        if tool_names:
            logger.trace(
                "Extracted tool names",
                tool_model_name=tool_model.name,
                tool_names=tool_names,
            )
            return tool_names
    except Exception as e:
        logger.warning(
            "Error extracting tool names from ToolModel",
            tool_model_name=tool_model.name,
            error=str(e),
        )

    # Fallback to ToolModel.name
    logger.debug(
        "Falling back to ToolModel.name",
        tool_model_name=tool_model.name,
    )
    return [tool_model.name]


def create_tool_call_limit_middleware(
    tool: str | ToolModel | dict[str, Any] | None = None,
    thread_limit: int | None = None,
    run_limit: int | None = None,
    exit_behavior: Literal["continue", "error", "end"] = "continue",
) -> ToolCallLimitMiddleware:
    """
    Create a ToolCallLimitMiddleware with graceful termination support.

    Factory for LangChain's ToolCallLimitMiddleware that supports DAO AI
    configuration types.

    Args:
        tool: Tool to limit. Can be:
            - None: Global limit on all tools
            - str: Limit a specific tool by name
            - ToolModel: Limit tool(s) from DAO AI config
            - dict: Tool config dict (converted to ToolModel)
        thread_limit: Max calls per thread (conversation). Requires a checkpointer.
        run_limit: Max calls per run (single invocation).
        exit_behavior: What to do when the limit is hit:
            - "continue": Block the tool with an error message, let the agent continue
            - "error": Raise ToolCallLimitExceededError immediately
            - "end": Stop execution gracefully (single-tool only)

    Returns:
        A ToolCallLimitMiddleware instance. If the ToolModel produces multiple
        tools, only the first tool is used (with a warning logged).

    Raises:
        ValueError: If no limits are specified, or the dict is invalid
        TypeError: If tool is an unsupported type

    Example:
        # Global limit
        limiter = create_tool_call_limit_middleware(run_limit=10)

        # Tool-specific limit
        limiter = create_tool_call_limit_middleware(
            tool="search_web",
            run_limit=3,
            exit_behavior="continue",
        )
    """
    if thread_limit is None and run_limit is None:
        raise ValueError("At least one of thread_limit or run_limit must be specified.")

    # Global limit: no tool parameter
    if tool is None:
        logger.debug(
            "Creating global tool call limit",
            thread_limit=thread_limit,
            run_limit=run_limit,
            exit_behavior=exit_behavior,
        )
        return ToolCallLimitMiddleware(
            thread_limit=thread_limit,
            run_limit=run_limit,
            exit_behavior=exit_behavior,
        )

    # Resolve to list of tool names
    names = _resolve_tool(tool)

    # Use first tool name (warn if multiple)
    tool_name = names[0]
    if len(names) > 1:
        logger.warning(
            "ToolModel resolved to multiple tool names, using first only",
            tool_names=names,
            using=tool_name,
        )

    logger.debug(
        "Creating tool call limit middleware",
        tool_name=tool_name,
        thread_limit=thread_limit,
        run_limit=run_limit,
        exit_behavior=exit_behavior,
    )

    return ToolCallLimitMiddleware(
        tool_name=tool_name,
        thread_limit=thread_limit,
        run_limit=run_limit,
        exit_behavior=exit_behavior,
    )
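
Since each instance limits either all tools or exactly one (only the first name is used when a ToolModel expands to several), per-tool budgets come from stacking instances. A minimal sketch; the tool names are placeholders:

# Sketch: one global budget plus per-tool budgets. Each factory call
# returns a single ToolCallLimitMiddleware; tool names are placeholders.
from dao_ai.middleware import create_tool_call_limit_middleware

limits = [
    # Hard ceiling across all tools in a single run
    create_tool_call_limit_middleware(run_limit=25, exit_behavior="error"),
    # Per-tool budgets; the "continue" default blocks only that tool
    create_tool_call_limit_middleware(tool="search_web", run_limit=3),
    create_tool_call_limit_middleware(tool="execute_sql", run_limit=5),
]

With this stacking, exit_behavior="error" on the global ceiling fails the run fast, while the per-tool "continue" default just blocks further calls to that tool and lets the agent try another.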