aiecs 1.3.8__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiecs might be problematic. Click here for more details.
- aiecs/__init__.py +1 -1
- aiecs/domain/__init__.py +120 -0
- aiecs/domain/agent/__init__.py +184 -0
- aiecs/domain/agent/base_agent.py +691 -0
- aiecs/domain/agent/exceptions.py +99 -0
- aiecs/domain/agent/hybrid_agent.py +495 -0
- aiecs/domain/agent/integration/__init__.py +23 -0
- aiecs/domain/agent/integration/context_compressor.py +219 -0
- aiecs/domain/agent/integration/context_engine_adapter.py +258 -0
- aiecs/domain/agent/integration/retry_policy.py +228 -0
- aiecs/domain/agent/integration/role_config.py +217 -0
- aiecs/domain/agent/lifecycle.py +298 -0
- aiecs/domain/agent/llm_agent.py +309 -0
- aiecs/domain/agent/memory/__init__.py +13 -0
- aiecs/domain/agent/memory/conversation.py +216 -0
- aiecs/domain/agent/migration/__init__.py +15 -0
- aiecs/domain/agent/migration/conversion.py +171 -0
- aiecs/domain/agent/migration/legacy_wrapper.py +97 -0
- aiecs/domain/agent/models.py +263 -0
- aiecs/domain/agent/observability.py +443 -0
- aiecs/domain/agent/persistence.py +287 -0
- aiecs/domain/agent/prompts/__init__.py +25 -0
- aiecs/domain/agent/prompts/builder.py +164 -0
- aiecs/domain/agent/prompts/formatters.py +192 -0
- aiecs/domain/agent/prompts/template.py +264 -0
- aiecs/domain/agent/registry.py +261 -0
- aiecs/domain/agent/tool_agent.py +267 -0
- aiecs/domain/agent/tools/__init__.py +13 -0
- aiecs/domain/agent/tools/schema_generator.py +222 -0
- aiecs/main.py +2 -2
- {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/METADATA +1 -1
- {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/RECORD +36 -9
- {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/WHEEL +0 -0
- {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/entry_points.txt +0 -0
- {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/licenses/LICENSE +0 -0
- {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Context Compression
|
|
3
|
+
|
|
4
|
+
Smart context compression for token limits.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from typing import List, Dict, Any, Optional
|
|
9
|
+
from enum import Enum
|
|
10
|
+
|
|
11
|
+
from aiecs.llm import LLMMessage
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CompressionStrategy(Enum):
    """Context compression strategies.

    Selects how ``ContextCompressor`` drops or rewrites messages when the
    estimated token count exceeds the configured budget.
    """
    TRUNCATE_MIDDLE = "truncate_middle"  # keep both ends, replace the middle with a marker
    TRUNCATE_START = "truncate_start"    # keep the most recent messages, drop the oldest
    PRESERVE_RECENT = "preserve_recent"  # keep system + priority + recent messages
    SUMMARIZE = "summarize"              # NOTE(review): no summarize path exists in this module — falls back to PRESERVE_RECENT; TODO confirm intended
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ContextCompressor:
    """
    Smart context compression for managing token limits.

    Reduces a message list until its estimated token count fits within
    ``max_tokens``, according to the configured ``CompressionStrategy``.

    Example:
        compressor = ContextCompressor(max_tokens=4000)
        compressed = compressor.compress_messages(messages)
    """

    def __init__(
        self,
        max_tokens: int = 4000,
        strategy: CompressionStrategy = CompressionStrategy.PRESERVE_RECENT,
        preserve_system: bool = True,
    ):
        """
        Initialize context compressor.

        Args:
            max_tokens: Maximum token limit
            strategy: Compression strategy
            preserve_system: Always preserve system messages
        """
        self.max_tokens = max_tokens
        self.strategy = strategy
        self.preserve_system = preserve_system

    def compress_messages(
        self,
        messages: List[LLMMessage],
        priority_indices: Optional[List[int]] = None
    ) -> List[LLMMessage]:
        """
        Compress message list to fit within token limit.

        Args:
            messages: List of messages
            priority_indices: Optional indices of high-priority messages
                (out-of-range indices are ignored)

        Returns:
            Compressed message list; the original list is returned
            unchanged when it already fits the budget.
        """
        total_tokens = self._estimate_tokens(messages)
        if total_tokens <= self.max_tokens:
            return messages

        logger.debug(
            f"Compressing {len(messages)} messages from ~{total_tokens} to ~{self.max_tokens} tokens"
        )

        if self.strategy == CompressionStrategy.TRUNCATE_MIDDLE:
            return self._compress_truncate_middle(messages, priority_indices)
        if self.strategy == CompressionStrategy.TRUNCATE_START:
            return self._compress_truncate_start(messages)
        # PRESERVE_RECENT, SUMMARIZE (no summarizer is implemented here),
        # and any future strategy fall back to preserving recent messages.
        return self._compress_preserve_recent(messages, priority_indices)

    def _compress_preserve_recent(
        self,
        messages: List[LLMMessage],
        priority_indices: Optional[List[int]]
    ) -> List[LLMMessage]:
        """
        Preserve system, priority, and the most recent messages.

        Selected messages are returned in original chronological order.
        (The previous implementation appended while walking the list
        backwards, emitting the recent messages newest-first, and its
        ``msg not in compressed`` membership test silently dropped
        duplicate messages.)

        Args:
            messages: List of messages
            priority_indices: Optional high-priority indices; values
                outside ``range(len(messages))`` are ignored

        Returns:
            Compressed message list in original order
        """
        n = len(messages)
        # Track kept messages by index so equal/duplicate message objects
        # cannot be accidentally skipped.
        kept: set = set()

        # Always include system messages if enabled
        if self.preserve_system:
            kept.update(i for i in range(n) if messages[i].role == "system")

        remaining_tokens = self.max_tokens - self._estimate_tokens(
            [messages[i] for i in kept]
        )

        # Add priority messages in deterministic (ascending) order;
        # iterating a raw set here would make output order nondeterministic.
        for idx in sorted({i for i in (priority_indices or []) if 0 <= i < n}):
            if idx in kept:
                continue
            msg_tokens = self._estimate_tokens([messages[idx]])
            if msg_tokens <= remaining_tokens:
                kept.add(idx)
                remaining_tokens -= msg_tokens

        # Add recent messages (from the end) until the budget runs out
        for idx in range(n - 1, -1, -1):
            if idx in kept:
                continue
            msg_tokens = self._estimate_tokens([messages[idx]])
            if msg_tokens > remaining_tokens:
                break
            kept.add(idx)
            remaining_tokens -= msg_tokens

        # Emit in original chronological order
        return [messages[i] for i in sorted(kept)]

    def _compress_truncate_middle(
        self,
        messages: List[LLMMessage],
        priority_indices: Optional[List[int]]
    ) -> List[LLMMessage]:
        """
        Keep start and end messages, replace the middle with a marker.

        NOTE(review): this strategy currently ignores ``priority_indices``
        and does not re-check the token budget after truncation — confirm
        whether that is acceptable for large head/tail messages.

        Args:
            messages: List of messages
            priority_indices: Unused by this strategy

        Returns:
            Compressed message list
        """
        if len(messages) <= 4:
            return messages

        # Keep first 2 and last 2 by default
        keep_start = 2
        keep_end = 2

        start_msgs = messages[:keep_start]
        end_msgs = messages[-keep_end:]

        return start_msgs + [
            LLMMessage(role="system", content="[... conversation history compressed ...]")
        ] + end_msgs

    def _compress_truncate_start(self, messages: List[LLMMessage]) -> List[LLMMessage]:
        """
        Keep the most recent messages, truncating the start.

        Walks the list from the end, accumulating messages until the
        token budget is exhausted; the result keeps original ordering.

        Args:
            messages: List of messages

        Returns:
            Compressed message list
        """
        compressed = []
        remaining_tokens = self.max_tokens

        # Process from the end so the newest messages survive
        for msg in reversed(messages):
            msg_tokens = self._estimate_tokens([msg])
            if msg_tokens > remaining_tokens:
                break
            compressed.insert(0, msg)
            remaining_tokens -= msg_tokens

        return compressed

    def _estimate_tokens(self, messages: List[LLMMessage]) -> int:
        """
        Estimate token count for messages.

        Args:
            messages: List of messages

        Returns:
            Estimated token count (rough heuristic: 4 chars ≈ 1 token)
        """
        total_chars = sum(len(msg.content) for msg in messages)
        return total_chars // 4

    def compress_text(self, text: str, max_tokens: int) -> str:
        """
        Compress text to fit within token limit.

        Args:
            text: Text to compress
            max_tokens: Maximum tokens

        Returns:
            The original text when it fits, otherwise a truncated copy
            ending with a "... [truncated]" marker.
        """
        # Rough estimate: 4 chars ≈ 1 token
        if len(text) // 4 <= max_tokens:
            return text

        max_chars = max_tokens * 4
        # Reserve room for the marker; clamp so a tiny budget cannot make
        # the slice bound negative (text[:-k] would keep the wrong end).
        return text[:max(0, max_chars - 20)] + "... [truncated]"
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def compress_messages(
    messages: List[LLMMessage],
    max_tokens: int = 4000,
    strategy: CompressionStrategy = CompressionStrategy.PRESERVE_RECENT
) -> List[LLMMessage]:
    """
    Compress *messages* with a one-off ``ContextCompressor``.

    Thin convenience wrapper for callers that do not want to keep a
    compressor instance around.

    Args:
        messages: List of messages
        max_tokens: Maximum token limit
        strategy: Compression strategy

    Returns:
        Compressed message list
    """
    one_shot = ContextCompressor(max_tokens=max_tokens, strategy=strategy)
    return one_shot.compress_messages(messages)
|
|
219
|
+
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
"""
|
|
2
|
+
ContextEngine Adapter
|
|
3
|
+
|
|
4
|
+
Adapter for integrating agent persistence with AIECS ContextEngine.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import json
|
|
9
|
+
import uuid
|
|
10
|
+
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from aiecs.domain.context.context_engine import ContextEngine
|
|
15
|
+
|
|
16
|
+
from aiecs.domain.agent.persistence import AgentPersistence
|
|
17
|
+
from aiecs.domain.agent.base_agent import BaseAIAgent
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ContextEngineAdapter:
    """
    Adapter for persisting agent state to ContextEngine.

    Uses ContextEngine's checkpoint system for versioned state storage
    and TaskContext for session-based state management.  The final four
    methods (``save``/``load``/``exists``/``delete``) implement the
    ``AgentPersistence`` protocol so this adapter can be plugged in
    wherever that protocol is expected.
    """

    def __init__(self, context_engine: "ContextEngine", user_id: str = "system"):
        """
        Initialize adapter.

        Args:
            context_engine: ContextEngine instance (must not be None)
            user_id: User identifier for session management

        Raises:
            ValueError: If ``context_engine`` is None.
        """
        if context_engine is None:
            raise ValueError("ContextEngine instance is required")

        self.context_engine = context_engine
        self.user_id = user_id
        # Namespace prefixes reserved for keys written to ContextEngine.
        self._agent_state_prefix = "agent_state"
        self._agent_conversation_prefix = "agent_conversation"
        logger.info("ContextEngineAdapter initialized")

    async def save_agent_state(
        self,
        agent_id: str,
        state: Dict[str, Any],
        version: Optional[str] = None
    ) -> str:
        """
        Save agent state to ContextEngine using the checkpoint system.

        Args:
            agent_id: Agent identifier (doubles as the checkpoint thread_id)
            state: Agent state dictionary
            version: Optional version identifier (auto-generated if None)

        Returns:
            Version identifier
        """
        if version is None:
            version = str(uuid.uuid4())

        checkpoint_data = {
            "agent_id": agent_id,
            "state": state,
            # NOTE(review): utcnow() is naive and deprecated since 3.12;
            # kept as-is because stored timestamps are compared as strings
            # (see list_agent_versions) — confirm before migrating.
            "timestamp": datetime.utcnow().isoformat(),
            "version": version
        }

        # Store as checkpoint (thread_id = agent_id)
        await self.context_engine.store_checkpoint(
            thread_id=agent_id,
            checkpoint_id=version,
            checkpoint_data=checkpoint_data,
            metadata={"type": "agent_state", "agent_id": agent_id}
        )

        logger.debug(f"Saved agent {agent_id} state version {version} to ContextEngine")
        return version

    async def load_agent_state(
        self,
        agent_id: str,
        version: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Load agent state from ContextEngine.

        Args:
            agent_id: Agent identifier
            version: Optional version identifier (loads latest if None)

        Returns:
            Agent state dictionary, or None when no matching (or a
            malformed) checkpoint is found.
        """
        checkpoint = await self.context_engine.get_checkpoint(
            thread_id=agent_id,
            checkpoint_id=version
        )

        if checkpoint and "data" in checkpoint:
            checkpoint_data = checkpoint["data"]
            if isinstance(checkpoint_data, dict) and "state" in checkpoint_data:
                logger.debug(f"Loaded agent {agent_id} state version {version or 'latest'}")
                return checkpoint_data["state"]

        logger.debug(f"No state found for agent {agent_id} version {version or 'latest'}")
        return None

    async def list_agent_versions(self, agent_id: str) -> List[Dict[str, Any]]:
        """
        List all versions of an agent's state, newest first.

        Args:
            agent_id: Agent identifier

        Returns:
            List of dicts with "version", "timestamp", and "metadata"
            keys, sorted by timestamp descending.
        """
        checkpoints = await self.context_engine.list_checkpoints(thread_id=agent_id)
        if not checkpoints:
            return []

        versions = []
        for checkpoint in checkpoints:
            # list_checkpoints returns dicts with "data" key containing checkpoint_data
            if isinstance(checkpoint, dict):
                data = checkpoint.get("data", {})
                if isinstance(data, dict) and "version" in data:
                    versions.append({
                        "version": data["version"],
                        "timestamp": data.get("timestamp"),
                        "metadata": checkpoint.get("metadata", {})
                    })

        # Sort by timestamp descending.  A checkpoint without a timestamp
        # stores None; coerce it to "" so sorting cannot raise TypeError
        # by comparing None against a string.
        versions.sort(key=lambda v: v.get("timestamp") or "", reverse=True)
        return versions

    async def save_conversation_history(
        self,
        session_id: str,
        messages: List[Dict[str, Any]]
    ) -> None:
        """
        Save conversation history to ContextEngine.

        Creates the session on first use, then appends each message via
        ContextEngine's conversation API.

        Args:
            session_id: Session identifier
            messages: List of message dictionaries with 'role' and
                'content' (and an optional 'metadata' dict)
        """
        # Ensure session exists
        session = await self.context_engine.get_session(session_id)
        if not session:
            await self.context_engine.create_session(
                session_id=session_id,
                user_id=self.user_id,
                metadata={"type": "agent_conversation"}
            )

        # Store messages using ContextEngine's conversation API
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            metadata = msg.get("metadata", {})

            await self.context_engine.add_conversation_message(
                session_id=session_id,
                role=role,
                content=content,
                metadata=metadata
            )

        logger.debug(f"Saved {len(messages)} messages to session {session_id}")

    async def load_conversation_history(
        self,
        session_id: str,
        limit: int = 50
    ) -> List[Dict[str, Any]]:
        """
        Load conversation history from ContextEngine.

        Args:
            session_id: Session identifier
            limit: Maximum number of messages to retrieve

        Returns:
            List of message dicts with "role", "content", "timestamp"
            (ISO string when the source timestamp supports it), and
            "metadata" keys.
        """
        messages = await self.context_engine.get_conversation_history(
            session_id=session_id,
            limit=limit
        )

        # Convert ConversationMessage objects to plain dictionaries
        result = []
        for msg in messages:
            result.append({
                "role": msg.role,
                "content": msg.content,
                "timestamp": msg.timestamp.isoformat() if hasattr(msg.timestamp, 'isoformat') else str(msg.timestamp),
                "metadata": msg.metadata
            })

        logger.debug(f"Loaded {len(result)} messages from session {session_id}")
        return result

    async def delete_agent_state(
        self,
        agent_id: str,
        version: Optional[str] = None
    ) -> None:
        """
        Mark agent state as deleted in ContextEngine.

        ContextEngine has no explicit checkpoint deletion, so a tombstone
        checkpoint is stored instead; actual removal relies on TTL.
        (Previously, nothing at all was written when ``version`` was None
        even though the docstring promised "deletes all" — a tombstone is
        now stored in that case as well.)

        Args:
            agent_id: Agent identifier
            version: Optional version identifier (marks all state deleted
                if None)
        """
        marker_id = f"{version}_deleted" if version else "all_deleted"
        await self.context_engine.store_checkpoint(
            thread_id=agent_id,
            checkpoint_id=marker_id,
            checkpoint_data={"deleted": True, "original_version": version},
            metadata={"type": "deletion_marker"}
        )
        logger.debug(f"Marked agent {agent_id} state version {version or 'all'} for deletion")

    # AgentPersistence Protocol implementation

    async def save(self, agent: BaseAIAgent) -> None:
        """Save agent state (implements AgentPersistence protocol)."""
        state = agent.to_dict()
        await self.save_agent_state(agent.agent_id, state)

    async def load(self, agent_id: str) -> Dict[str, Any]:
        """
        Load agent state (implements AgentPersistence protocol).

        Raises:
            KeyError: If no state exists for ``agent_id``.
        """
        state = await self.load_agent_state(agent_id)
        if state is None:
            raise KeyError(f"Agent {agent_id} not found in storage")
        return state

    async def exists(self, agent_id: str) -> bool:
        """Check if agent state exists (implements AgentPersistence protocol)."""
        # Loads the full state to test existence; acceptable while states
        # stay small, since no lighter-weight probe is exposed.
        state = await self.load_agent_state(agent_id)
        return state is not None

    async def delete(self, agent_id: str) -> None:
        """Delete agent state (implements AgentPersistence protocol)."""
        await self.delete_agent_state(agent_id)
|
|
258
|
+
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Enhanced Retry Policy
|
|
3
|
+
|
|
4
|
+
Sophisticated retry logic with exponential backoff and error classification.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import random
|
|
9
|
+
import logging
|
|
10
|
+
from typing import Optional, Callable, Any
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from enum import Enum
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ErrorType(Enum):
    """Error categories used to pick a retry strategy."""
    RATE_LIMIT = "rate_limit"        # "rate limit" / HTTP 429 in the message
    TIMEOUT = "timeout"              # request or connection timed out
    SERVER_ERROR = "server_error"    # HTTP 5xx in the message
    CLIENT_ERROR = "client_error"    # HTTP 4xx in the message (except 429)
    NETWORK_ERROR = "network_error"  # connection/network/socket exception types
    UNKNOWN = "unknown"              # anything unclassified


class ErrorClassifier:
    """Classifies errors so the retry policy can decide whether to retry."""

    @staticmethod
    def classify(error: Exception) -> ErrorType:
        """
        Classify error type.

        Classification is heuristic: it substring-matches the stringified
        error message and the exception class name.

        Args:
            error: Exception to classify

        Returns:
            ErrorType
        """
        error_str = str(error).lower()
        error_type_name = type(error).__name__.lower()

        # Rate limit errors
        if "rate limit" in error_str or "429" in error_str:
            return ErrorType.RATE_LIMIT

        # Timeout errors.  Also match on the exception class name:
        # TimeoutError / asyncio.TimeoutError instances frequently carry
        # an empty message, so checking str(error) alone misclassified
        # them as UNKNOWN (and therefore non-retryable).
        if (
            "timeout" in error_str
            or "timed out" in error_str
            or "timeout" in error_type_name
        ):
            return ErrorType.TIMEOUT

        # Server errors (5xx)
        if any(code in error_str for code in ["500", "502", "503", "504"]):
            return ErrorType.SERVER_ERROR

        # Client errors (4xx)
        if any(code in error_str for code in ["400", "401", "403", "404"]):
            return ErrorType.CLIENT_ERROR

        # Network errors (matched on the exception class name)
        if any(term in error_type_name for term in ["connection", "network", "socket"]):
            return ErrorType.NETWORK_ERROR

        return ErrorType.UNKNOWN

    @staticmethod
    def is_retryable(error_type: ErrorType) -> bool:
        """
        Determine if error type should be retried.

        Args:
            error_type: Error type

        Returns:
            True for transient failures (rate limit, timeout, server,
            network); False for client errors and unknowns.
        """
        retryable_types = {
            ErrorType.RATE_LIMIT,
            ErrorType.TIMEOUT,
            ErrorType.SERVER_ERROR,
            ErrorType.NETWORK_ERROR,
        }
        return error_type in retryable_types
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class EnhancedRetryPolicy:
    """
    Enhanced retry policy with exponential backoff and jitter.

    Example:
        policy = EnhancedRetryPolicy(max_retries=5, base_delay=1.0)
        result = await policy.execute_with_retry(my_async_function, arg1, arg2)
    """

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
        jitter: bool = True,
    ):
        """
        Initialize retry policy.

        Args:
            max_retries: Maximum number of retry attempts (must be >= 0)
            base_delay: Base delay in seconds
            max_delay: Maximum delay in seconds (hard ceiling on any wait)
            exponential_base: Base for exponential backoff
            jitter: Whether to add random jitter

        Raises:
            ValueError: If ``max_retries`` is negative — a negative value
                would skip the attempt loop entirely and end in
                ``raise None`` (a TypeError) inside execute_with_retry.
        """
        if max_retries < 0:
            raise ValueError("max_retries must be >= 0")
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        self.jitter = jitter

    def calculate_delay(self, attempt: int, error_type: ErrorType) -> float:
        """
        Calculate delay for retry attempt.

        Args:
            attempt: Retry attempt number (0-indexed)
            error_type: Type of error

        Returns:
            Delay in seconds, never exceeding ``max_delay``.
        """
        # Base exponential backoff, capped at max_delay
        delay = min(
            self.base_delay * (self.exponential_base ** attempt),
            self.max_delay
        )

        # Back off harder when the server is throttling us
        if error_type == ErrorType.RATE_LIMIT:
            delay *= 2

        # Add jitter (50%-150%) to prevent thundering herd
        if self.jitter:
            delay *= (0.5 + random.random())

        # Enforce the documented ceiling even after the rate-limit
        # multiplier and jitter; previously the combined adjustments
        # could produce up to 3 * max_delay.
        return min(delay, self.max_delay)

    async def execute_with_retry(
        self,
        func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """
        Execute function with retry logic.

        Args:
            func: Async function to execute
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Function result

        Raises:
            Exception: The last error raised by ``func``, once retries
                are exhausted or a non-retryable error occurs.
        """
        last_error = None

        # max_retries + 1 total attempts: the initial call plus retries.
        for attempt in range(self.max_retries + 1):
            try:
                result = await func(*args, **kwargs)

                # Log success after retries
                if attempt > 0:
                    logger.info(f"Succeeded after {attempt} retries")

                return result

            except Exception as e:
                last_error = e

                # Classify error
                error_type = ErrorClassifier.classify(e)

                # Check if we should retry
                if attempt >= self.max_retries:
                    logger.error(f"Max retries ({self.max_retries}) exhausted")
                    break

                if not ErrorClassifier.is_retryable(error_type):
                    logger.error(f"Non-retryable error: {error_type.value}")
                    break

                # Calculate delay and wait
                delay = self.calculate_delay(attempt, error_type)
                logger.warning(
                    f"Attempt {attempt + 1} failed ({error_type.value}). "
                    f"Retrying in {delay:.2f}s..."
                )
                await asyncio.sleep(delay)

        # All retries exhausted (or error was non-retryable)
        raise last_error
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
async def with_retry(
    func: Callable,
    max_retries: int = 3,
    base_delay: float = 1.0,
    *args,
    **kwargs
) -> Any:
    """
    Run ``func`` under a fresh ``EnhancedRetryPolicy``.

    NOTE: because ``max_retries`` and ``base_delay`` precede ``*args``,
    positional arguments intended for ``func`` must be supplied after
    those two values (or passed as keywords via ``**kwargs``).

    Args:
        func: Async function to execute
        max_retries: Maximum number of retries
        base_delay: Base delay in seconds
        *args: Positional arguments forwarded to ``func``
        **kwargs: Keyword arguments forwarded to ``func``

    Returns:
        Whatever ``func`` returns.
    """
    retry_policy = EnhancedRetryPolicy(
        max_retries=max_retries,
        base_delay=base_delay,
    )
    return await retry_policy.execute_with_retry(func, *args, **kwargs)
|
|
228
|
+
|