daita-agents 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of daita-agents might be problematic. Click here for more details.
- daita/__init__.py +208 -0
- daita/agents/__init__.py +33 -0
- daita/agents/base.py +722 -0
- daita/agents/substrate.py +895 -0
- daita/cli/__init__.py +145 -0
- daita/cli/__main__.py +7 -0
- daita/cli/ascii_art.py +44 -0
- daita/cli/core/__init__.py +0 -0
- daita/cli/core/create.py +254 -0
- daita/cli/core/deploy.py +473 -0
- daita/cli/core/deployments.py +309 -0
- daita/cli/core/import_detector.py +219 -0
- daita/cli/core/init.py +382 -0
- daita/cli/core/logs.py +239 -0
- daita/cli/core/managed_deploy.py +709 -0
- daita/cli/core/run.py +648 -0
- daita/cli/core/status.py +421 -0
- daita/cli/core/test.py +239 -0
- daita/cli/core/webhooks.py +172 -0
- daita/cli/main.py +588 -0
- daita/cli/utils.py +541 -0
- daita/config/__init__.py +62 -0
- daita/config/base.py +159 -0
- daita/config/settings.py +184 -0
- daita/core/__init__.py +262 -0
- daita/core/decision_tracing.py +701 -0
- daita/core/exceptions.py +480 -0
- daita/core/focus.py +251 -0
- daita/core/interfaces.py +76 -0
- daita/core/plugin_tracing.py +550 -0
- daita/core/relay.py +695 -0
- daita/core/reliability.py +381 -0
- daita/core/scaling.py +444 -0
- daita/core/tools.py +402 -0
- daita/core/tracing.py +770 -0
- daita/core/workflow.py +1084 -0
- daita/display/__init__.py +1 -0
- daita/display/console.py +160 -0
- daita/execution/__init__.py +58 -0
- daita/execution/client.py +856 -0
- daita/execution/exceptions.py +92 -0
- daita/execution/models.py +317 -0
- daita/llm/__init__.py +60 -0
- daita/llm/anthropic.py +166 -0
- daita/llm/base.py +373 -0
- daita/llm/factory.py +101 -0
- daita/llm/gemini.py +152 -0
- daita/llm/grok.py +114 -0
- daita/llm/mock.py +135 -0
- daita/llm/openai.py +109 -0
- daita/plugins/__init__.py +141 -0
- daita/plugins/base.py +37 -0
- daita/plugins/base_db.py +167 -0
- daita/plugins/elasticsearch.py +844 -0
- daita/plugins/mcp.py +481 -0
- daita/plugins/mongodb.py +510 -0
- daita/plugins/mysql.py +351 -0
- daita/plugins/postgresql.py +331 -0
- daita/plugins/redis_messaging.py +500 -0
- daita/plugins/rest.py +529 -0
- daita/plugins/s3.py +761 -0
- daita/plugins/slack.py +729 -0
- daita/utils/__init__.py +18 -0
- daita_agents-0.1.0.dist-info/METADATA +350 -0
- daita_agents-0.1.0.dist-info/RECORD +69 -0
- daita_agents-0.1.0.dist-info/WHEEL +5 -0
- daita_agents-0.1.0.dist-info/entry_points.txt +2 -0
- daita_agents-0.1.0.dist-info/licenses/LICENSE +56 -0
- daita_agents-0.1.0.dist-info/top_level.txt +1 -0
daita/core/relay.py
ADDED
|
@@ -0,0 +1,695 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Simplified Relay System for Daita Agents.
|
|
3
|
+
|
|
4
|
+
Provides simple message passing between agents in a workflow.
|
|
5
|
+
Focus on essential communication functionality without complex features.
|
|
6
|
+
|
|
7
|
+
Updated to pass only result data between agents for cleaner interfaces.
|
|
8
|
+
|
|
9
|
+
Example:
|
|
10
|
+
```python
|
|
11
|
+
from daita.core.relay import RelayManager
|
|
12
|
+
|
|
13
|
+
# Initialize relay manager
|
|
14
|
+
relay = RelayManager()
|
|
15
|
+
|
|
16
|
+
# Agent publishes data (full response gets stored, only result gets forwarded)
|
|
17
|
+
await relay.publish("data_channel", {
|
|
18
|
+
"status": "success",
|
|
19
|
+
"result": {"processed": True},
|
|
20
|
+
"agent_id": "agent_123"
|
|
21
|
+
})
|
|
22
|
+
|
|
23
|
+
# Another agent subscribes and receives only: {"processed": True}
|
|
24
|
+
async def handle_data(result_data):
|
|
25
|
+
print(f"Got result: {result_data}")
|
|
26
|
+
|
|
27
|
+
await relay.subscribe("data_channel", handle_data)
|
|
28
|
+
```
|
|
29
|
+
"""
|
|
30
|
+
import asyncio
|
|
31
|
+
import logging
|
|
32
|
+
import random
|
|
33
|
+
import time
|
|
34
|
+
import uuid
|
|
35
|
+
from typing import Dict, Any, Optional, List, Callable, Union
|
|
36
|
+
from collections import deque
|
|
37
|
+
from dataclasses import dataclass, field
|
|
38
|
+
from enum import Enum
|
|
39
|
+
import weakref
|
|
40
|
+
|
|
41
|
+
try:
|
|
42
|
+
from ..core.exceptions import DaitaError, AcknowledgmentTimeoutError
|
|
43
|
+
except ImportError:
|
|
44
|
+
# Fallback for direct execution or testing
|
|
45
|
+
from core.exceptions import DaitaError, AcknowledgmentTimeoutError
|
|
46
|
+
|
|
47
|
+
logger = logging.getLogger(__name__)
|
|
48
|
+
|
|
49
|
+
class RelayError(DaitaError):
    """Raised for relay-specific failures in the agent messaging layer."""
|
|
52
|
+
|
|
53
|
+
class MessageStatus(str, Enum):
    """Lifecycle states for a message tracked by the relay's reliability layer."""

    PENDING = "pending"            # published, waiting for an acknowledgment
    ACKNOWLEDGED = "acknowledged"  # consumer confirmed successful processing
    FAILED = "failed"              # consumer reported a processing error (NACK)
    TIMEOUT = "timeout"            # no acknowledgment arrived before the deadline
|
|
59
|
+
|
|
60
|
+
@dataclass
class ReliableMessage:
    """A message with acknowledgment tracking.

    Created only on the reliability path of :class:`RelayManager`; kept in
    ``pending_messages`` until it is ACKed, NACKed past ``max_attempts``,
    or times out past ``max_attempts``.
    """
    # Unique message identifier (uuid4 hex) used for ack/nack lookups.
    id: str
    # Channel the message was published to (needed to republish on retry).
    channel: str
    # The forwarded payload — only the 'result' portion of the agent response.
    data: Any
    # Optional identifier of the publishing agent.
    publisher: Optional[str] = None
    # Publish (or last retry) time; used to compute message age.
    timestamp: float = field(default_factory=time.time)
    # Current lifecycle state; starts PENDING until ACK/NACK/timeout.
    status: MessageStatus = MessageStatus.PENDING
    ack_timeout: float = 30.0  # Default 30 second timeout
    # Last error reported via nack_message, if any.
    error: Optional[str] = None
    # Delivery attempts so far (incremented on NACK and on timeout).
    attempts: int = 0
    # Give up (drop from pending) once attempts reaches this.
    max_attempts: int = 3
|
|
73
|
+
|
|
74
|
+
class RelayManager:
    """
    Simple relay manager for agent-to-agent communication.

    Publishers hand in a full agent response dict; only its 'result' field is
    stored and forwarded to subscribers, keeping downstream interfaces clean.

    When reliability is enabled (globally, or per message via ``require_ack``),
    each published message gets an ID and should be acknowledged with
    :meth:`ack_message` / :meth:`nack_message`; unacknowledged messages time
    out and are retried with exponential backoff up to ``max_attempts``.
    """

    def __init__(
        self,
        max_messages_per_channel: int = 10,
        enable_reliability: bool = False,
        default_ack_timeout: float = 30.0
    ):
        """
        Initialize the relay manager.

        Args:
            max_messages_per_channel: Maximum messages to keep per channel
                (kept small for memory efficiency).
            enable_reliability: Enable message acknowledgments and reliability features.
            default_ack_timeout: Default acknowledgment timeout in seconds.
        """
        self.max_messages_per_channel = max_messages_per_channel
        self.enable_reliability = enable_reliability
        self.default_ack_timeout = default_ack_timeout

        # Channel storage: channel_name -> bounded deque of message dicts
        self.channels: Dict[str, deque] = {}

        # Subscribers: channel_name -> weak set of callbacks.
        # NOTE(review): WeakSet drops bound methods / lambdas as soon as the
        # caller releases them — callers must hold a strong reference.
        self.subscribers: Dict[str, weakref.WeakSet] = {}

        # Per-channel locks to prevent race conditions between publish and subscribe
        self._channel_locks: Dict[str, asyncio.Lock] = {}

        # Ring buffer of recent subscriber callback failures (for debugging)
        self.subscriber_errors: deque = deque(maxlen=100)

        # Reliability bookkeeping. Plain dicts in both modes — the original
        # `{} if enable_reliability else {}` chose the same value either way.
        self.pending_messages: Dict[str, ReliableMessage] = {}
        self.message_timeouts: Dict[str, asyncio.Task] = {}

        # Running state
        self._running = False

        logger.debug(f"RelayManager initialized (reliability: {enable_reliability})")

    async def start(self) -> None:
        """Start the relay manager (idempotent)."""
        if self._running:
            return

        self._running = True
        logger.info("RelayManager started")

    async def stop(self) -> None:
        """Stop the relay manager and cleanup (idempotent)."""
        if not self._running:
            return

        self._running = False

        # BUGFIX: cancel outstanding timeout/retry tasks so they don't keep
        # running (and holding message references) after shutdown.
        for task in self.message_timeouts.values():
            if not task.done():
                task.cancel()
        self.message_timeouts.clear()
        self.pending_messages.clear()

        # Clear all channels and subscribers
        self.channels.clear()
        self.subscribers.clear()

        logger.info("RelayManager stopped")

    def _ensure_channel(self, channel: str) -> None:
        """Lazily create storage, subscriber set and lock for a channel."""
        if channel not in self.channels:
            self.channels[channel] = deque(maxlen=self.max_messages_per_channel)
            self.subscribers[channel] = weakref.WeakSet()
            self._channel_locks[channel] = asyncio.Lock()

    async def publish(
        self,
        channel: str,
        agent_response: Dict[str, Any],
        publisher: Optional[str] = None,
        require_ack: Optional[bool] = None
    ) -> Optional[str]:
        """
        Publish data to a channel.

        Only the 'result' field of *agent_response* is forwarded; if it is
        missing, the full response is forwarded as a fallback.

        Args:
            channel: Channel name
            agent_response: Full agent response (we extract 'result' field)
            publisher: Optional publisher identifier
            require_ack: Whether this message requires acknowledgment (overrides global setting)

        Returns:
            Message ID if reliability is enabled, None otherwise
        """
        if not self._running:
            await self.start()

        # Extract just the result data
        result_data = agent_response.get('result')

        if result_data is None:
            logger.warning(f"No 'result' field found in agent response for channel '{channel}'")
            result_data = agent_response  # Fallback to full response if no result field

        self._ensure_channel(channel)

        # Per-message override takes precedence over the global setting
        needs_reliability = require_ack if require_ack is not None else self.enable_reliability

        if needs_reliability:
            return await self._publish_reliable(channel, result_data, agent_response, publisher)
        else:
            return await self._publish_fire_and_forget(channel, result_data, agent_response, publisher)

    async def _publish_fire_and_forget(
        self,
        channel: str,
        result_data: Any,
        agent_response: Dict[str, Any],
        publisher: Optional[str]
    ) -> None:
        """Publish message without reliability features (original behavior)."""
        # Use per-channel lock to make publish atomic
        async with self._channel_locks[channel]:
            # Create message with result data only
            message = {
                'data': result_data,
                'publisher': publisher,
                'timestamp': time.time()
            }

            # Store result message
            self.channels[channel].append(message)

            # Notify subscribers with just the result data (while holding lock)
            await self._notify_subscribers(channel, result_data)

        logger.debug(f"Published result to channel '{channel}' from {publisher}")
        return None

    async def _publish_reliable(
        self,
        channel: str,
        result_data: Any,
        agent_response: Dict[str, Any],
        publisher: Optional[str]
    ) -> str:
        """Publish message with reliability features (ack tracking + timeout)."""
        # Create reliable message
        message_id = uuid.uuid4().hex
        reliable_message = ReliableMessage(
            id=message_id,
            channel=channel,
            data=result_data,
            publisher=publisher,
            ack_timeout=self.default_ack_timeout
        )

        # Store pending message so ack/nack can find it
        self.pending_messages[message_id] = reliable_message

        # Use per-channel lock to make publish atomic
        async with self._channel_locks[channel]:
            # Create message for channel storage
            message = {
                'id': message_id,
                'data': result_data,
                'publisher': publisher,
                'timestamp': time.time(),
                'requires_ack': True
            }

            # Store result message
            self.channels[channel].append(message)

            # Set up timeout task; the done-callback consumes any stray
            # exception so it isn't logged as "never retrieved".
            timeout_task = asyncio.create_task(
                self._handle_message_timeout(message_id, reliable_message.ack_timeout)
            )
            timeout_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None)
            self.message_timeouts[message_id] = timeout_task

            # Notify subscribers with message ID for acknowledgment (while holding lock)
            await self._notify_subscribers_reliable(channel, result_data, message_id)

        logger.debug(f"Published reliable message {message_id} to channel '{channel}' from {publisher}")
        return message_id

    async def _notify_subscribers(self, channel: str, result_data: Any) -> None:
        """Notify all subscribers of a channel with result data only."""
        if channel not in self.subscribers:
            return

        # Get snapshot of subscribers to avoid modification during iteration
        subscriber_list = list(self.subscribers[channel])

        if not subscriber_list:
            return

        # Notify all subscribers concurrently
        tasks = []
        for subscriber in subscriber_list:
            task = asyncio.create_task(self._call_subscriber(subscriber, result_data))
            task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None)
            tasks.append(task)

        # Wait for all notifications to complete
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _notify_subscribers_reliable(self, channel: str, result_data: Any, message_id: str) -> None:
        """Notify all subscribers of a reliable message (passing the message ID)."""
        if channel not in self.subscribers:
            return

        # Get snapshot of subscribers to avoid modification during iteration
        subscriber_list = list(self.subscribers[channel])

        if not subscriber_list:
            return

        # Notify all subscribers concurrently with message ID
        tasks = []
        for subscriber in subscriber_list:
            task = asyncio.create_task(
                self._call_subscriber_reliable(subscriber, result_data, message_id)
            )
            task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None)
            tasks.append(task)

        # Wait for all notifications to complete
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _call_subscriber_reliable(self, callback: Callable, result_data: Any, message_id: str) -> None:
        """Safely call a subscriber callback for a reliable message; NACK on failure."""
        try:
            if asyncio.iscoroutinefunction(callback):
                # Pass message_id only to callbacks that declare the parameter
                import inspect
                sig = inspect.signature(callback)
                if 'message_id' in sig.parameters:
                    await callback(result_data, message_id=message_id)
                else:
                    await callback(result_data)
            else:
                # Run sync callback in thread pool.
                # BUGFIX: get_event_loop() is deprecated inside coroutines;
                # a loop is guaranteed to be running here.
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(None, callback, result_data)
        except Exception as e:
            logger.error(f"Error in reliable subscriber callback: {str(e)}")
            # NACK the message on callback error
            await self.nack_message(message_id, str(e))

    async def _call_subscriber(self, callback: Callable, result_data: Any) -> None:
        """Safely call a subscriber callback with result data, recording failures."""
        try:
            if asyncio.iscoroutinefunction(callback):
                await callback(result_data)
            else:
                # Run sync callback in thread pool (see BUGFIX note above re: get_running_loop)
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(None, callback, result_data)
        except Exception as e:
            error_info = {
                'callback': str(callback),
                'error': str(e),
                'error_type': type(e).__name__,
                'timestamp': time.time(),
                'data_preview': str(result_data)[:100]
            }
            self.subscriber_errors.append(error_info)
            logger.error(f"Subscriber callback failed: {e}")
            # Don't re-raise - we want other subscribers to continue

    async def subscribe(self, channel: str, callback: Callable) -> None:
        """
        Subscribe to a channel.

        Args:
            channel: Channel name
            callback: Callback function to receive result data. Held only
                weakly — the caller must keep a strong reference to it.
        """
        self._ensure_channel(channel)
        self.subscribers[channel].add(callback)

        logger.debug(f"Subscribed to channel '{channel}'")

    def unsubscribe(self, channel: str, callback: Callable) -> bool:
        """
        Unsubscribe from a channel.

        Args:
            channel: Channel name
            callback: Callback to remove

        Returns:
            True if callback was removed
        """
        if channel not in self.subscribers:
            return False

        try:
            self.subscribers[channel].remove(callback)
            logger.debug(f"Unsubscribed from channel '{channel}'")
            return True
        except KeyError:
            return False

    async def get_latest(self, channel: str, count: int = 1) -> List[Any]:
        """
        Get latest result data from a channel.

        Args:
            channel: Channel name
            count: Number of messages to retrieve; non-positive returns [].

        Returns:
            List of result data (newest first)
        """
        if channel not in self.channels:
            return []

        # BUGFIX: count=0 used to return ALL messages (messages[-0:] is the
        # whole list); treat non-positive counts as "nothing requested".
        if count <= 0:
            return []

        # Slicing handles count >= len naturally (returns everything)
        latest = list(self.channels[channel])[-count:]
        latest.reverse()  # Newest first

        # Return just the result data
        return [msg['data'] for msg in latest]

    def list_channels(self) -> List[str]:
        """List all channels."""
        return list(self.channels.keys())

    def clear_channel(self, channel: str) -> bool:
        """
        Clear all messages from a channel.

        Args:
            channel: Channel name

        Returns:
            True if channel was cleared
        """
        if channel not in self.channels:
            return False

        self.channels[channel].clear()

        logger.debug(f"Cleared channel '{channel}'")
        return True

    def get_stats(self) -> Dict[str, Any]:
        """Get relay manager statistics."""
        total_messages = sum(len(ch) for ch in self.channels.values())
        total_subscribers = sum(len(subs) for subs in self.subscribers.values())

        stats = {
            'running': self._running,
            'total_channels': len(self.channels),
            'total_messages': total_messages,
            'total_subscribers': total_subscribers,
            'channels': list(self.channels.keys()),
            'subscriber_errors_count': len(self.subscriber_errors)
        }

        # Add per-channel error breakdown.
        # NOTE(review): matches by channel-name substring of str(callback) —
        # only works when the callback repr happens to contain the channel.
        if self.subscriber_errors:
            stats['errors_by_channel'] = {}
            for channel in self.channels.keys():
                channel_errors = sum(
                    1 for err in self.subscriber_errors
                    if channel in err.get('callback', '')
                )
                if channel_errors > 0:
                    stats['errors_by_channel'][channel] = channel_errors

        # Add reliability stats if enabled
        if self.enable_reliability:
            stats.update({
                'reliability_enabled': True,
                'pending_messages': len(self.pending_messages),
                'active_timeouts': len(self.message_timeouts)
            })
        else:
            stats['reliability_enabled'] = False

        return stats

    # Reliability methods (acknowledgments and timeouts)

    async def ack_message(self, message_id: str) -> bool:
        """
        Acknowledge a message as successfully processed.

        Args:
            message_id: ID of the message to acknowledge

        Returns:
            True if message was acknowledged, False if not found
        """
        if not self.enable_reliability or message_id not in self.pending_messages:
            return False

        # Update message status
        message = self.pending_messages[message_id]
        message.status = MessageStatus.ACKNOWLEDGED

        # Cancel timeout task
        if message_id in self.message_timeouts:
            timeout_task = self.message_timeouts.pop(message_id)
            if not timeout_task.done():
                timeout_task.cancel()

        # Remove from pending
        del self.pending_messages[message_id]

        logger.debug(f"Acknowledged message {message_id}")
        return True

    async def nack_message(self, message_id: str, error: str) -> bool:
        """
        Negative acknowledge a message (processing failed).

        Schedules a retry with backoff while attempts remain; otherwise the
        message is dropped permanently.

        Args:
            message_id: ID of the message to NACK
            error: Error message

        Returns:
            True if message was NACKed, False if not found
        """
        if not self.enable_reliability or message_id not in self.pending_messages:
            return False

        # Update message status
        message = self.pending_messages[message_id]
        message.status = MessageStatus.FAILED
        message.error = error
        message.attempts += 1

        # Cancel timeout task
        if message_id in self.message_timeouts:
            timeout_task = self.message_timeouts.pop(message_id)
            if not timeout_task.done():
                timeout_task.cancel()

        # Check if we should retry
        if message.attempts < message.max_attempts:
            logger.warning(f"Message {message_id} failed (attempt {message.attempts}/{message.max_attempts}): {error}")
            await self._schedule_retry(message)
        else:
            logger.error(f"Message {message_id} failed permanently after {message.attempts} attempts: {error}")
            # Remove from pending messages
            del self.pending_messages[message_id]

        return True

    async def _handle_message_timeout(self, message_id: str, timeout_duration: float) -> None:
        """Wait for the ack deadline; on expiry retry or drop the message."""
        try:
            await asyncio.sleep(timeout_duration)

            # Check if message is still pending (i.e. never ACKed/NACKed)
            if message_id in self.pending_messages:
                message = self.pending_messages[message_id]
                message.status = MessageStatus.TIMEOUT
                message.attempts += 1

                logger.warning(f"Message {message_id} timed out after {timeout_duration}s")

                # Check if we should retry
                if message.attempts < message.max_attempts:
                    await self._schedule_retry(message)
                else:
                    logger.error(f"Message {message_id} timed out permanently after {message.attempts} attempts")
                    # Remove from pending messages
                    del self.pending_messages[message_id]

            # Remove timeout task
            self.message_timeouts.pop(message_id, None)

        except asyncio.CancelledError:
            # Timeout was cancelled (message was acknowledged)
            pass
        except Exception as e:
            logger.error(f"Error handling timeout for message {message_id}: {e}")

    def _calculate_retry_delay(self, attempt: int, base_delay: float = 1.0) -> float:
        """Calculate retry delay with exponential backoff and jitter."""
        # Exponential backoff: 1s, 2s, 4s, 8s, ... capped at 60s
        delay = min(base_delay * (2 ** attempt), 60.0)

        # Add ±50% jitter to prevent thundering herd
        jitter_factor = 0.5 + random.random()  # Random between 0.5 and 1.5
        delay *= jitter_factor

        return delay

    async def _schedule_retry(self, message: ReliableMessage) -> None:
        """Schedule a retry for a failed or timed out message."""
        # Calculate exponential backoff delay with jitter
        delay = self._calculate_retry_delay(message.attempts - 1)

        logger.debug(f"Scheduling retry for message {message.id} in {delay:.1f}s (attempt {message.attempts + 1})")

        # Reset message status for retry
        message.status = MessageStatus.PENDING
        message.timestamp = time.time()  # Update timestamp for retry

        # Schedule the retry task
        retry_task = asyncio.create_task(self._execute_retry(message, delay))
        retry_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None)
        # Use a different key to avoid conflicts with timeout tasks
        self.message_timeouts[f"{message.id}_retry"] = retry_task

    async def _execute_retry(self, message: ReliableMessage, delay: float) -> None:
        """Execute a message retry after the specified delay."""
        try:
            await asyncio.sleep(delay)

            # Check if message is still in pending (not cancelled)
            if message.id in self.pending_messages:
                logger.debug(f"Retrying message {message.id} (attempt {message.attempts + 1})")

                # Republish the message by re-adding it to the channel
                self._ensure_channel(message.channel)

                # Create new message entry in channel
                retry_message = {
                    'id': message.id,
                    'data': message.data,
                    'publisher': message.publisher,
                    'timestamp': message.timestamp,
                    'requires_ack': True
                }

                # Add to channel
                self.channels[message.channel].append(retry_message)

                # Set up new timeout task for the retry
                timeout_task = asyncio.create_task(
                    self._handle_message_timeout(message.id, message.ack_timeout)
                )
                timeout_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None)
                self.message_timeouts[message.id] = timeout_task

                # Notify subscribers again
                await self._notify_subscribers_reliable(message.channel, message.data, message.id)

            # Clean up retry task reference
            self.message_timeouts.pop(f"{message.id}_retry", None)

        except asyncio.CancelledError:
            # Retry was cancelled
            pass
        except Exception as e:
            logger.error(f"Error executing retry for message {message.id}: {e}")

    def get_pending_messages(self) -> List[Dict[str, Any]]:
        """Get list of pending messages waiting for acknowledgment."""
        if not self.enable_reliability:
            return []

        return [
            {
                'id': msg.id,
                'channel': msg.channel,
                'publisher': msg.publisher,
                'status': msg.status.value,
                'timestamp': msg.timestamp,
                'attempts': msg.attempts,
                'max_attempts': msg.max_attempts,
                'error': msg.error,
                'age': time.time() - msg.timestamp
            }
            for msg in self.pending_messages.values()
        ]

    def get_subscriber_errors(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get recent subscriber errors for debugging (oldest first, up to *limit*)."""
        errors = list(self.subscriber_errors)
        return errors[-limit:] if limit < len(errors) else errors

    # Context manager support
    async def __aenter__(self) -> "RelayManager":
        """Async context manager entry."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit."""
        await self.stop()
|
|
670
|
+
|
|
671
|
+
# Global relay manager instance
|
|
672
|
+
_global_relay = None

def get_global_relay() -> RelayManager:
    """Return the process-wide RelayManager, creating it lazily on first use."""
    global _global_relay
    relay = _global_relay
    if relay is None:
        relay = RelayManager()
        _global_relay = relay
    return relay
|
|
680
|
+
|
|
681
|
+
# Convenience functions
|
|
682
|
+
async def publish(channel: str, agent_response: Dict[str, Any], publisher: Optional[str] = None) -> Optional[str]:
    """Publish using the global relay manager.

    Args:
        channel: Channel name.
        agent_response: Full agent response; only its 'result' field is forwarded.
        publisher: Optional publisher identifier.

    Returns:
        The message ID when reliability is enabled on the global relay,
        otherwise None. (BUGFIX: the ID was previously discarded, making this
        wrapper unusable with acknowledgments.)
    """
    relay = get_global_relay()
    return await relay.publish(channel, agent_response, publisher)
|
|
686
|
+
|
|
687
|
+
async def subscribe(channel: str, callback: Callable) -> None:
    """Subscribe *callback* to *channel* on the global relay manager."""
    await get_global_relay().subscribe(channel, callback)
|
|
691
|
+
|
|
692
|
+
async def get_latest(channel: str, count: int = 1) -> List[Any]:
    """Fetch the latest result data for *channel* via the global relay manager."""
    manager = get_global_relay()
    return await manager.get_latest(channel, count)