aiecs 1.3.8__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aiecs might be problematic. Click here for more details.

Files changed (36)
  1. aiecs/__init__.py +1 -1
  2. aiecs/domain/__init__.py +120 -0
  3. aiecs/domain/agent/__init__.py +184 -0
  4. aiecs/domain/agent/base_agent.py +691 -0
  5. aiecs/domain/agent/exceptions.py +99 -0
  6. aiecs/domain/agent/hybrid_agent.py +495 -0
  7. aiecs/domain/agent/integration/__init__.py +23 -0
  8. aiecs/domain/agent/integration/context_compressor.py +219 -0
  9. aiecs/domain/agent/integration/context_engine_adapter.py +258 -0
  10. aiecs/domain/agent/integration/retry_policy.py +228 -0
  11. aiecs/domain/agent/integration/role_config.py +217 -0
  12. aiecs/domain/agent/lifecycle.py +298 -0
  13. aiecs/domain/agent/llm_agent.py +309 -0
  14. aiecs/domain/agent/memory/__init__.py +13 -0
  15. aiecs/domain/agent/memory/conversation.py +216 -0
  16. aiecs/domain/agent/migration/__init__.py +15 -0
  17. aiecs/domain/agent/migration/conversion.py +171 -0
  18. aiecs/domain/agent/migration/legacy_wrapper.py +97 -0
  19. aiecs/domain/agent/models.py +263 -0
  20. aiecs/domain/agent/observability.py +443 -0
  21. aiecs/domain/agent/persistence.py +287 -0
  22. aiecs/domain/agent/prompts/__init__.py +25 -0
  23. aiecs/domain/agent/prompts/builder.py +164 -0
  24. aiecs/domain/agent/prompts/formatters.py +192 -0
  25. aiecs/domain/agent/prompts/template.py +264 -0
  26. aiecs/domain/agent/registry.py +261 -0
  27. aiecs/domain/agent/tool_agent.py +267 -0
  28. aiecs/domain/agent/tools/__init__.py +13 -0
  29. aiecs/domain/agent/tools/schema_generator.py +222 -0
  30. aiecs/main.py +2 -2
  31. {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/METADATA +1 -1
  32. {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/RECORD +36 -9
  33. {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/WHEEL +0 -0
  34. {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/entry_points.txt +0 -0
  35. {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/licenses/LICENSE +0 -0
  36. {aiecs-1.3.8.dist-info → aiecs-1.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,219 @@
1
+ """
2
+ Context Compression
3
+
4
+ Smart context compression for token limits.
5
+ """
6
+
7
+ import logging
8
+ from typing import List, Dict, Any, Optional
9
+ from enum import Enum
10
+
11
+ from aiecs.llm import LLMMessage
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class CompressionStrategy(Enum):
    """Context compression strategies.

    Selects how ``ContextCompressor.compress_messages`` reduces a message
    list whose estimated token count exceeds the compressor's budget.
    """
    # Keep the first two and last two messages and replace everything in
    # between with a placeholder system message.
    TRUNCATE_MIDDLE = "truncate_middle"
    # Keep the most recent messages that fit the budget, dropping older ones.
    TRUNCATE_START = "truncate_start"
    # Keep system messages, priority messages, then as many recent messages
    # as fit the budget.
    PRESERVE_RECENT = "preserve_recent"
    # NOTE: ContextCompressor has no dedicated implementation for this value;
    # compress_messages falls back to PRESERVE_RECENT.
    SUMMARIZE = "summarize"
22
+
23
+
24
class ContextCompressor:
    """
    Smart context compression for managing token limits.

    Token counts are estimated with a simple heuristic (about 4 characters
    per token, see ``_estimate_tokens``); they are not tokenizer-exact.

    Example:
        compressor = ContextCompressor(max_tokens=4000)
        compressed = compressor.compress_messages(messages)
    """

    def __init__(
        self,
        max_tokens: int = 4000,
        strategy: CompressionStrategy = CompressionStrategy.PRESERVE_RECENT,
        preserve_system: bool = True,
    ):
        """
        Initialize context compressor.

        Args:
            max_tokens: Maximum (estimated) token budget
            strategy: Compression strategy used when the budget is exceeded
            preserve_system: Keep system messages even when other messages
                are dropped (honoured by PRESERVE_RECENT and TRUNCATE_START)
        """
        self.max_tokens = max_tokens
        self.strategy = strategy
        self.preserve_system = preserve_system

    def compress_messages(
        self,
        messages: List[LLMMessage],
        priority_indices: Optional[List[int]] = None
    ) -> List[LLMMessage]:
        """
        Compress message list to fit within token limit.

        The input list is returned unchanged when it already fits. Otherwise
        a new list is built according to ``self.strategy``.

        Args:
            messages: List of messages
            priority_indices: Optional indices (negative indices allowed) of
                high-priority messages; only used by PRESERVE_RECENT and the
                SUMMARIZE fallback

        Returns:
            Compressed message list
        """
        total_tokens = self._estimate_tokens(messages)

        if total_tokens <= self.max_tokens:
            return messages

        logger.debug(
            "Compressing %s messages from ~%s to ~%s tokens",
            len(messages), total_tokens, self.max_tokens,
        )

        if self.strategy == CompressionStrategy.TRUNCATE_MIDDLE:
            return self._compress_truncate_middle(messages, priority_indices)
        if self.strategy == CompressionStrategy.TRUNCATE_START:
            return self._compress_truncate_start(messages)
        # PRESERVE_RECENT, and SUMMARIZE (no implementation yet) as a fallback.
        return self._compress_preserve_recent(messages, priority_indices)

    def _compress_preserve_recent(
        self,
        messages: List[LLMMessage],
        priority_indices: Optional[List[int]]
    ) -> List[LLMMessage]:
        """
        Keep system messages (if enabled), then priority messages, then the
        most recent messages that fit the remaining budget.

        Selection works on indices, so equal-valued messages are not
        confused. The result keeps the original chronological order.
        """
        keep = self._system_indices(messages)
        remaining_tokens = self.max_tokens - self._estimate_tokens(
            [messages[i] for i in keep]
        )

        count = len(messages)
        # Normalise negative indices and silently ignore out-of-range ones.
        priority = sorted(
            {idx % count for idx in (priority_indices or []) if -count <= idx < count}
        )
        for idx in priority:
            if idx in keep:
                continue
            msg_tokens = self._estimate_tokens([messages[idx]])
            if msg_tokens <= remaining_tokens:
                keep.add(idx)
                remaining_tokens -= msg_tokens

        return self._fill_newest_first(messages, keep, remaining_tokens)

    def _compress_truncate_middle(
        self,
        messages: List[LLMMessage],
        priority_indices: Optional[List[int]]
    ) -> List[LLMMessage]:
        """
        Keep the first two and last two messages, replacing the rest with a
        placeholder system message.

        NOTE: this strategy consults neither the token budget nor
        ``priority_indices``; the result can still exceed ``max_tokens``
        when the kept messages are large.
        """
        if len(messages) <= 4:
            return messages

        keep_start = 2
        keep_end = 2
        marker = LLMMessage(role="system", content="[... conversation history compressed ...]")
        return messages[:keep_start] + [marker] + messages[-keep_end:]

    def _compress_truncate_start(self, messages: List[LLMMessage]) -> List[LLMMessage]:
        """
        Keep the most recent messages that fit the budget, dropping older
        ones. System messages are always kept when ``preserve_system`` is set.
        The result keeps the original chronological order.
        """
        keep = self._system_indices(messages)
        remaining_tokens = self.max_tokens - self._estimate_tokens(
            [messages[i] for i in keep]
        )
        return self._fill_newest_first(messages, keep, remaining_tokens)

    def _system_indices(self, messages: List[LLMMessage]) -> set:
        """Return indices of system messages if ``preserve_system`` is set, else an empty set."""
        if not self.preserve_system:
            return set()
        return {i for i, msg in enumerate(messages) if msg.role == "system"}

    def _fill_newest_first(
        self,
        messages: List[LLMMessage],
        keep: set,
        remaining_tokens: int,
    ) -> List[LLMMessage]:
        """
        Add messages newest-first to ``keep`` while they fit the budget.

        Returns the kept messages in their original order.
        """
        for idx in range(len(messages) - 1, -1, -1):
            if idx in keep:
                continue
            msg_tokens = self._estimate_tokens([messages[idx]])
            if msg_tokens > remaining_tokens:
                # Stop at the first message that does not fit, so older
                # messages are never picked from beyond a dropped one.
                break
            keep.add(idx)
            remaining_tokens -= msg_tokens
        return [messages[i] for i in sorted(keep)]

    def _estimate_tokens(self, messages: List[LLMMessage]) -> int:
        """
        Estimate token count for messages.

        Args:
            messages: List of messages

        Returns:
            Estimated token count (total characters // 4)
        """
        # Treat missing (None) content as empty instead of raising TypeError.
        total_chars = sum(len(msg.content or "") for msg in messages)
        # Rough estimate: 4 chars ≈ 1 token
        return total_chars // 4

    def compress_text(self, text: str, max_tokens: int) -> str:
        """
        Compress text to fit within token limit.

        Text whose estimate (``len(text) // 4``) is within ``max_tokens`` is
        returned unchanged. Otherwise it is cut to roughly ``max_tokens * 4``
        characters and suffixed with ``"... [truncated]"``.

        Args:
            text: Text to compress
            max_tokens: Maximum tokens

        Returns:
            Compressed text
        """
        if len(text) // 4 <= max_tokens:
            return text

        max_chars = max_tokens * 4
        # Reserve room for the suffix. Clamp at 0: for tiny budgets a negative
        # slice end would keep almost the entire text instead of truncating it.
        keep_chars = max(max_chars - 20, 0)
        return text[:keep_chars] + "... [truncated]"
199
+
200
+
201
def compress_messages(
    messages: List[LLMMessage],
    max_tokens: int = 4000,
    strategy: CompressionStrategy = CompressionStrategy.PRESERVE_RECENT,
    priority_indices: Optional[List[int]] = None,
    preserve_system: bool = True,
) -> List[LLMMessage]:
    """
    Convenience function for compressing messages.

    Builds a one-off ``ContextCompressor`` and runs it on ``messages``.

    Args:
        messages: List of messages
        max_tokens: Maximum token limit
        strategy: Compression strategy
        priority_indices: Optional indices of high-priority messages,
            forwarded to ``ContextCompressor.compress_messages``
        preserve_system: Whether system messages are always kept

    Returns:
        Compressed message list
    """
    compressor = ContextCompressor(
        max_tokens=max_tokens,
        strategy=strategy,
        preserve_system=preserve_system,
    )
    return compressor.compress_messages(messages, priority_indices)
219
+
@@ -0,0 +1,258 @@
1
+ """
2
+ ContextEngine Adapter
3
+
4
+ Adapter for integrating agent persistence with AIECS ContextEngine.
5
+ """
6
+
7
+ import logging
8
+ import json
9
+ import uuid
10
+ from typing import Dict, Any, Optional, List, TYPE_CHECKING
11
+ from datetime import datetime
12
+
13
+ if TYPE_CHECKING:
14
+ from aiecs.domain.context.context_engine import ContextEngine
15
+
16
+ from aiecs.domain.agent.persistence import AgentPersistence
17
+ from aiecs.domain.agent.base_agent import BaseAIAgent
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class ContextEngineAdapter:
    """
    Adapter for persisting agent state to ContextEngine.

    Uses ContextEngine's checkpoint system for versioned state storage
    (checkpoint thread = agent id, checkpoint id = state version) and its
    session/conversation API for conversation history. The ``save``,
    ``load``, ``exists`` and ``delete`` methods implement the
    AgentPersistence protocol on top of these.
    """

    def __init__(self, context_engine: "ContextEngine", user_id: str = "system"):
        """
        Initialize adapter.

        Args:
            context_engine: ContextEngine instance
            user_id: User identifier used when creating conversation sessions

        Raises:
            ValueError: If ``context_engine`` is None
        """
        if context_engine is None:
            raise ValueError("ContextEngine instance is required")

        self.context_engine = context_engine
        self.user_id = user_id
        # Currently not referenced by any method of this class.
        self._agent_state_prefix = "agent_state"
        self._agent_conversation_prefix = "agent_conversation"
        logger.info("ContextEngineAdapter initialized")

    async def save_agent_state(
        self,
        agent_id: str,
        state: Dict[str, Any],
        version: Optional[str] = None
    ) -> str:
        """
        Save agent state to ContextEngine using checkpoint system.

        Args:
            agent_id: Agent identifier (used as the checkpoint thread id)
            state: Agent state dictionary
            version: Optional version identifier (random UUID if None)

        Returns:
            Version identifier
        """
        if version is None:
            version = str(uuid.uuid4())

        checkpoint_data = {
            "agent_id": agent_id,
            "state": state,
            "timestamp": datetime.utcnow().isoformat(),
            "version": version
        }

        # Store as checkpoint (thread_id = agent_id)
        await self.context_engine.store_checkpoint(
            thread_id=agent_id,
            checkpoint_id=version,
            checkpoint_data=checkpoint_data,
            metadata={"type": "agent_state", "agent_id": agent_id}
        )

        logger.debug(f"Saved agent {agent_id} state version {version} to ContextEngine")
        return version

    async def load_agent_state(
        self,
        agent_id: str,
        version: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Load agent state from ContextEngine.

        Args:
            agent_id: Agent identifier
            version: Optional version identifier (loads latest if None)

        Returns:
            Agent state dictionary, or None if the checkpoint is missing or
            holds no ``"state"`` (e.g. a deletion marker)
        """
        checkpoint = await self.context_engine.get_checkpoint(
            thread_id=agent_id,
            checkpoint_id=version
        )

        if checkpoint and "data" in checkpoint:
            checkpoint_data = checkpoint["data"]
            # Deletion markers written by delete_agent_state carry no "state"
            # key and therefore fall through to the "not found" path.
            if isinstance(checkpoint_data, dict) and "state" in checkpoint_data:
                logger.debug(f"Loaded agent {agent_id} state version {version or 'latest'}")
                return checkpoint_data["state"]

        logger.debug(f"No state found for agent {agent_id} version {version or 'latest'}")
        return None

    async def list_agent_versions(self, agent_id: str) -> List[Dict[str, Any]]:
        """
        List all versions of an agent's state.

        Args:
            agent_id: Agent identifier

        Returns:
            List of version metadata dicts (``version``, ``timestamp``,
            ``metadata``), newest timestamp first
        """
        checkpoints = await self.context_engine.list_checkpoints(thread_id=agent_id)
        if not checkpoints:
            return []

        versions = []
        for checkpoint in checkpoints:
            # list_checkpoints returns dicts with "data" key containing checkpoint_data
            if not isinstance(checkpoint, dict):
                continue
            data = checkpoint.get("data", {})
            # Deletion markers have no "version" key and are skipped.
            if isinstance(data, dict) and "version" in data:
                versions.append({
                    "version": data["version"],
                    "timestamp": data.get("timestamp"),
                    "metadata": checkpoint.get("metadata", {})
                })

        # Sort by timestamp descending. ``or ""`` also covers an explicit None
        # timestamp, which would otherwise make the sort raise TypeError.
        versions.sort(key=lambda v: v.get("timestamp") or "", reverse=True)
        return versions

    async def save_conversation_history(
        self,
        session_id: str,
        messages: List[Dict[str, Any]]
    ) -> None:
        """
        Save conversation history to ContextEngine.

        Creates the session first if it does not exist, then appends every
        message in order.

        Args:
            session_id: Session identifier
            messages: List of message dictionaries with 'role' and 'content'
                (and optional 'metadata')
        """
        # Ensure session exists
        session = await self.context_engine.get_session(session_id)
        if not session:
            await self.context_engine.create_session(
                session_id=session_id,
                user_id=self.user_id,
                metadata={"type": "agent_conversation"}
            )

        # Store messages using ContextEngine's conversation API
        for msg in messages:
            await self.context_engine.add_conversation_message(
                session_id=session_id,
                role=msg.get("role", "user"),
                content=msg.get("content", ""),
                metadata=msg.get("metadata", {})
            )

        logger.debug(f"Saved {len(messages)} messages to session {session_id}")

    async def load_conversation_history(
        self,
        session_id: str,
        limit: int = 50
    ) -> List[Dict[str, Any]]:
        """
        Load conversation history from ContextEngine.

        Args:
            session_id: Session identifier
            limit: Maximum number of messages to retrieve

        Returns:
            List of message dictionaries (role, content, timestamp, metadata)
        """
        messages = await self.context_engine.get_conversation_history(
            session_id=session_id,
            limit=limit
        )

        # Convert ConversationMessage objects to dictionaries; tolerate a
        # None result as "no history".
        result = []
        for msg in messages or []:
            result.append({
                "role": msg.role,
                "content": msg.content,
                "timestamp": msg.timestamp.isoformat() if hasattr(msg.timestamp, 'isoformat') else str(msg.timestamp),
                "metadata": msg.metadata
            })

        logger.debug(f"Loaded {len(result)} messages from session {session_id}")
        return result

    async def delete_agent_state(
        self,
        agent_id: str,
        version: Optional[str] = None
    ) -> None:
        """
        Delete agent state from ContextEngine.

        No checkpoint-deletion call is made. Instead a tombstone checkpoint
        (data without a ``"state"`` key) is written:

        - ``version`` given: a marker stored under ``"<version>_deleted"``.
          NOTE(review): loading that exact version still returns the old
          state. If ContextEngine treats the marker as the newest
          checkpoint, a latest-version load will return None. Confirm the
          intended semantics.
        - ``version`` None: a marker stored as a new checkpoint. Assuming
          ``get_checkpoint(checkpoint_id=None)`` returns the most recently
          stored checkpoint (as ``load_agent_state`` documents), loading the
          latest state then yields None and ``exists`` is False until the
          agent is saved again. Older versions stay loadable by explicit id.

        Args:
            agent_id: Agent identifier
            version: Optional version identifier (deletes all if None)
        """
        if version:
            # Store empty state as deletion marker
            await self.context_engine.store_checkpoint(
                thread_id=agent_id,
                checkpoint_id=f"{version}_deleted",
                checkpoint_data={"deleted": True, "original_version": version},
                metadata={"type": "deletion_marker"}
            )
        else:
            # Previously nothing was stored here, so delete()/exists() kept
            # reporting the agent as present.
            await self.context_engine.store_checkpoint(
                thread_id=agent_id,
                checkpoint_id=f"deleted_{uuid.uuid4()}",
                checkpoint_data={
                    "deleted": True,
                    "original_version": None,
                    "timestamp": datetime.utcnow().isoformat(),
                },
                metadata={"type": "deletion_marker", "agent_id": agent_id}
            )
        logger.debug(f"Marked agent {agent_id} state version {version or 'all'} for deletion")

    # AgentPersistence Protocol implementation
    async def save(self, agent: BaseAIAgent) -> None:
        """Save agent state (implements AgentPersistence protocol)."""
        state = agent.to_dict()
        await self.save_agent_state(agent.agent_id, state)

    async def load(self, agent_id: str) -> Dict[str, Any]:
        """
        Load agent state (implements AgentPersistence protocol).

        Raises:
            KeyError: If no state is stored for ``agent_id``
        """
        state = await self.load_agent_state(agent_id)
        if state is None:
            raise KeyError(f"Agent {agent_id} not found in storage")
        return state

    async def exists(self, agent_id: str) -> bool:
        """Check if agent state exists (implements AgentPersistence protocol)."""
        state = await self.load_agent_state(agent_id)
        return state is not None

    async def delete(self, agent_id: str) -> None:
        """Delete agent state (implements AgentPersistence protocol)."""
        await self.delete_agent_state(agent_id)
258
+
@@ -0,0 +1,228 @@
1
+ """
2
+ Enhanced Retry Policy
3
+
4
+ Sophisticated retry logic with exponential backoff and error classification.
5
+ """
6
+
7
+ import asyncio
8
+ import random
9
+ import logging
10
+ from typing import Optional, Callable, Any
11
+ from datetime import datetime
12
+ from enum import Enum
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class ErrorType(Enum):
    """Error types for classification.

    Produced by ``ErrorClassifier.classify``. ``ErrorClassifier.is_retryable``
    treats RATE_LIMIT, TIMEOUT, SERVER_ERROR and NETWORK_ERROR as retryable;
    CLIENT_ERROR and UNKNOWN are not retried.
    """
    # Message mentions "rate limit" or HTTP 429.
    RATE_LIMIT = "rate_limit"
    # Message mentions "timeout" / "timed out".
    TIMEOUT = "timeout"
    # Message mentions HTTP 500/502/503/504.
    SERVER_ERROR = "server_error"
    # Message mentions HTTP 400/401/403/404.
    CLIENT_ERROR = "client_error"
    # Exception class name contains "connection", "network" or "socket".
    NETWORK_ERROR = "network_error"
    # Anything not matched above.
    UNKNOWN = "unknown"
25
+
26
+
27
class ErrorClassifier:
    """
    Classifies errors for retry strategy.

    Classification is heuristic: it inspects only the lower-cased exception
    message and the exception class name.
    """

    # HTTP status codes recognised in error messages.
    _RATE_LIMIT_CODES = ("429",)
    _SERVER_ERROR_CODES = ("500", "502", "503", "504")
    _CLIENT_ERROR_CODES = ("400", "401", "403", "404")

    @staticmethod
    def _mentions_status_code(text: str, codes) -> bool:
        """
        Return True if any of ``codes`` occurs in ``text`` as a standalone number.

        A plain substring test also matches inside longer numbers (e.g.
        "500" in "15000 tokens"), which misclassifies unrelated errors.
        """
        import re

        return any(re.search(rf"(?<!\d){code}(?!\d)", text) for code in codes)

    @staticmethod
    def classify(error: Exception) -> ErrorType:
        """
        Classify error type.

        Checks run in order: rate limit, timeout, server error, client
        error, network error; the first match wins.

        Args:
            error: Exception to classify

        Returns:
            ErrorType
        """
        error_str = str(error).lower()
        error_type_name = type(error).__name__.lower()

        # Rate limit errors (also match class names such as "RateLimitError")
        if (
            "rate limit" in error_str
            or "ratelimit" in error_type_name
            or ErrorClassifier._mentions_status_code(error_str, ErrorClassifier._RATE_LIMIT_CODES)
        ):
            return ErrorType.RATE_LIMIT

        # Timeout errors. Built-in TimeoutError / asyncio.TimeoutError usually
        # have an empty message, so the class name is checked too.
        if (
            "timeout" in error_str
            or "timed out" in error_str
            or "timeout" in error_type_name
        ):
            return ErrorType.TIMEOUT

        # Server errors (5xx)
        if ErrorClassifier._mentions_status_code(error_str, ErrorClassifier._SERVER_ERROR_CODES):
            return ErrorType.SERVER_ERROR

        # Client errors (4xx)
        if ErrorClassifier._mentions_status_code(error_str, ErrorClassifier._CLIENT_ERROR_CODES):
            return ErrorType.CLIENT_ERROR

        # Network errors
        if any(term in error_type_name for term in ["connection", "network", "socket"]):
            return ErrorType.NETWORK_ERROR

        return ErrorType.UNKNOWN

    @staticmethod
    def is_retryable(error_type: ErrorType) -> bool:
        """
        Determine if error type should be retried.

        Args:
            error_type: Error type

        Returns:
            True for RATE_LIMIT, TIMEOUT, SERVER_ERROR and NETWORK_ERROR
        """
        retryable_types = {
            ErrorType.RATE_LIMIT,
            ErrorType.TIMEOUT,
            ErrorType.SERVER_ERROR,
            ErrorType.NETWORK_ERROR,
        }
        return error_type in retryable_types
84
+
85
+
86
class EnhancedRetryPolicy:
    """
    Enhanced retry policy with exponential backoff and jitter.

    Failures are classified with ``ErrorClassifier``. Only retryable error
    types are retried, and the last exception is re-raised when the policy
    gives up.

    Example:
        policy = EnhancedRetryPolicy(max_retries=5, base_delay=1.0)
        result = await policy.execute_with_retry(my_async_function, arg1, arg2)
    """

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        exponential_base: float = 2.0,
        jitter: bool = True,
    ):
        """
        Initialize retry policy.

        Args:
            max_retries: Maximum number of retry attempts (after the first call)
            base_delay: Base delay in seconds
            max_delay: Maximum delay in seconds (upper bound of every delay)
            exponential_base: Base for exponential backoff
            jitter: Whether to add random jitter
        """
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        self.jitter = jitter

    def calculate_delay(self, attempt: int, error_type: ErrorType) -> float:
        """
        Calculate delay for retry attempt.

        The delay is computed in these steps:

        1. Start from ``base_delay * exponential_base ** attempt``.
        2. Double it for rate-limit errors.
        3. Cap it at ``max_delay``.
        4. If jitter is enabled, scale it by a random factor in [0.5, 1.5).
        5. Cap it at ``max_delay`` again, so the documented maximum always holds.

        Args:
            attempt: Retry attempt number (0-indexed)
            error_type: Type of error

        Returns:
            Delay in seconds
        """
        try:
            delay = self.base_delay * (self.exponential_base ** attempt)
        except OverflowError:
            # Very large attempt numbers overflow float math; treat as capped.
            delay = self.max_delay

        # Longer delay for rate limits
        if error_type == ErrorType.RATE_LIMIT:
            delay *= 2

        # Cap before jitter so clients already at the cap are still spread out.
        delay = min(delay, self.max_delay)

        # Add jitter to prevent thundering herd
        if self.jitter:
            delay *= (0.5 + random.random())

        # Jitter can scale up to 1.5x; never exceed the configured maximum.
        return min(delay, self.max_delay)

    async def execute_with_retry(
        self,
        func: Callable,
        *args,
        **kwargs
    ) -> Any:
        """
        Execute function with retry logic.

        Args:
            func: Async function to execute
            *args: Function arguments
            **kwargs: Function keyword arguments

        Returns:
            Function result

        Raises:
            Exception: The last error, if it is non-retryable or all retries
                are exhausted
        """
        last_error: Optional[Exception] = None
        # Always make at least one attempt; a negative max_retries used to skip
        # the loop entirely and end in ``raise None`` (TypeError).
        total_attempts = max(self.max_retries, 0) + 1

        for attempt in range(total_attempts):
            try:
                result = await func(*args, **kwargs)
            except Exception as e:
                last_error = e
                error_type = ErrorClassifier.classify(e)

                # Check if we should retry
                if attempt >= self.max_retries:
                    logger.error(f"Max retries ({self.max_retries}) exhausted")
                    break

                if not ErrorClassifier.is_retryable(error_type):
                    logger.error(f"Non-retryable error: {error_type.value}")
                    break

                delay = self.calculate_delay(attempt, error_type)
                logger.warning(
                    f"Attempt {attempt + 1} failed ({error_type.value}). "
                    f"Retrying in {delay:.2f}s..."
                )
                await asyncio.sleep(delay)
            else:
                if attempt > 0:
                    logger.info(f"Succeeded after {attempt} retries")
                return result

        # All retries exhausted (or a non-retryable error occurred)
        raise last_error
204
+
205
+
206
async def with_retry(
    func: Callable,
    max_retries: int = 3,
    base_delay: float = 1.0,
    *args,
    **kwargs
) -> Any:
    """
    Await ``func(*args, **kwargs)`` under a freshly built EnhancedRetryPolicy.

    Only ``max_retries`` and ``base_delay`` can be tuned here; every other
    policy setting keeps the EnhancedRetryPolicy defaults.

    Note:
        ``max_retries`` and ``base_delay`` come before ``*args``. The 2nd and
        3rd positional arguments therefore bind to them and are not forwarded
        to ``func``. Pass both explicitly before any positional arguments for
        ``func`` (``with_retry(f, 3, 1.0, x)``), or give ``func`` keyword
        arguments.

    Args:
        func: Async function to execute
        max_retries: Maximum number of retries
        base_delay: Base delay in seconds
        *args: Function arguments
        **kwargs: Function keyword arguments

    Returns:
        Function result
    """
    return await EnhancedRetryPolicy(
        max_retries=max_retries,
        base_delay=base_delay,
    ).execute_with_retry(func, *args, **kwargs)
228
+