daita_agents-0.2.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. daita/__init__.py +216 -0
  2. daita/agents/__init__.py +33 -0
  3. daita/agents/base.py +743 -0
  4. daita/agents/substrate.py +1141 -0
  5. daita/cli/__init__.py +145 -0
  6. daita/cli/__main__.py +7 -0
  7. daita/cli/ascii_art.py +44 -0
  8. daita/cli/core/__init__.py +0 -0
  9. daita/cli/core/create.py +254 -0
  10. daita/cli/core/deploy.py +473 -0
  11. daita/cli/core/deployments.py +309 -0
  12. daita/cli/core/import_detector.py +219 -0
  13. daita/cli/core/init.py +481 -0
  14. daita/cli/core/logs.py +239 -0
  15. daita/cli/core/managed_deploy.py +709 -0
  16. daita/cli/core/run.py +648 -0
  17. daita/cli/core/status.py +421 -0
  18. daita/cli/core/test.py +239 -0
  19. daita/cli/core/webhooks.py +172 -0
  20. daita/cli/main.py +588 -0
  21. daita/cli/utils.py +541 -0
  22. daita/config/__init__.py +62 -0
  23. daita/config/base.py +159 -0
  24. daita/config/settings.py +184 -0
  25. daita/core/__init__.py +262 -0
  26. daita/core/decision_tracing.py +701 -0
  27. daita/core/exceptions.py +480 -0
  28. daita/core/focus.py +251 -0
  29. daita/core/interfaces.py +76 -0
  30. daita/core/plugin_tracing.py +550 -0
  31. daita/core/relay.py +779 -0
  32. daita/core/reliability.py +381 -0
  33. daita/core/scaling.py +459 -0
  34. daita/core/tools.py +554 -0
  35. daita/core/tracing.py +770 -0
  36. daita/core/workflow.py +1144 -0
  37. daita/display/__init__.py +1 -0
  38. daita/display/console.py +160 -0
  39. daita/execution/__init__.py +58 -0
  40. daita/execution/client.py +856 -0
  41. daita/execution/exceptions.py +92 -0
  42. daita/execution/models.py +317 -0
  43. daita/llm/__init__.py +60 -0
  44. daita/llm/anthropic.py +291 -0
  45. daita/llm/base.py +530 -0
  46. daita/llm/factory.py +101 -0
  47. daita/llm/gemini.py +355 -0
  48. daita/llm/grok.py +219 -0
  49. daita/llm/mock.py +172 -0
  50. daita/llm/openai.py +220 -0
  51. daita/plugins/__init__.py +141 -0
  52. daita/plugins/base.py +37 -0
  53. daita/plugins/base_db.py +167 -0
  54. daita/plugins/elasticsearch.py +849 -0
  55. daita/plugins/mcp.py +481 -0
  56. daita/plugins/mongodb.py +520 -0
  57. daita/plugins/mysql.py +362 -0
  58. daita/plugins/postgresql.py +342 -0
  59. daita/plugins/redis_messaging.py +500 -0
  60. daita/plugins/rest.py +537 -0
  61. daita/plugins/s3.py +770 -0
  62. daita/plugins/slack.py +729 -0
  63. daita/utils/__init__.py +18 -0
  64. daita_agents-0.2.0.dist-info/METADATA +409 -0
  65. daita_agents-0.2.0.dist-info/RECORD +69 -0
  66. daita_agents-0.2.0.dist-info/WHEEL +5 -0
  67. daita_agents-0.2.0.dist-info/entry_points.txt +2 -0
  68. daita_agents-0.2.0.dist-info/licenses/LICENSE +56 -0
  69. daita_agents-0.2.0.dist-info/top_level.txt +1 -0
daita/core/relay.py ADDED
@@ -0,0 +1,779 @@
"""
Simplified Relay System for Daita Agents.

Provides simple message passing between agents in a workflow.
Focus on essential communication functionality without complex features.

Updated to pass only result data between agents for cleaner interfaces.

Example:
    ```python
    from daita.core.relay import RelayManager

    # Initialize relay manager
    relay = RelayManager()

    # Agent publishes data (full response gets stored, only result gets forwarded)
    await relay.publish("data_channel", {
        "status": "success",
        "result": {"processed": True},
        "agent_id": "agent_123"
    })

    # Another agent subscribes and receives only: {"processed": True}
    async def handle_data(result_data):
        print(f"Got result: {result_data}")

    await relay.subscribe("data_channel", handle_data)
    ```
"""
import asyncio
import logging
import random
import time
import uuid
from typing import Dict, Any, Optional, List, Callable, Union
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
import weakref

try:
    from ..core.exceptions import DaitaError, AcknowledgmentTimeoutError
except ImportError:
    # Fallback for direct execution or testing
    from core.exceptions import DaitaError, AcknowledgmentTimeoutError

logger = logging.getLogger(__name__)

class RelayError(DaitaError):
    """Exception raised for relay-related errors."""
    pass

class MessageStatus(str, Enum):
    """Status of a message in the relay system."""
    PENDING = "pending"
    ACKNOWLEDGED = "acknowledged"
    FAILED = "failed"
    TIMEOUT = "timeout"

@dataclass
class ReliableMessage:
    """A message with acknowledgment tracking and metadata."""
    id: str
    channel: str
    data: Any
    metadata: Dict[str, Any] = field(default_factory=dict)
    publisher: Optional[str] = None
    timestamp: float = field(default_factory=time.time)
    status: MessageStatus = MessageStatus.PENDING
    ack_timeout: float = 30.0  # Default 30 second timeout
    error: Optional[str] = None
    attempts: int = 0
    max_attempts: int = 3

class RelayManager:
    """
    Simple relay manager for agent-to-agent communication.

    Handles message publishing and subscription with minimal complexity.
    Now passes only result data to maintain clean interfaces.
    """

    def __init__(
        self,
        max_messages_per_channel: int = 10,
        enable_reliability: bool = False,
        default_ack_timeout: float = 30.0
    ):
        """
        Initialize the relay manager.

        Args:
            max_messages_per_channel: Maximum messages to keep per channel (reduced to 10 for memory efficiency)
            enable_reliability: Enable message acknowledgments and reliability features
            default_ack_timeout: Default acknowledgment timeout in seconds
        """
        self.max_messages_per_channel = max_messages_per_channel
        self.enable_reliability = enable_reliability
        self.default_ack_timeout = default_ack_timeout

        # Channel storage: channel_name -> deque of result data only
        self.channels: Dict[str, deque] = {}

        # Subscribers: channel_name -> set of callbacks
        # Using regular set instead of WeakSet to prevent premature garbage collection
        self.subscribers: Dict[str, set] = {}

        # Per-channel locks to prevent race conditions between publish and subscribe
        self._channel_locks: Dict[str, asyncio.Lock] = {}

        # Subscriber error tracking
        self.subscriber_errors: deque = deque(maxlen=100)

        # Reliability features (when enabled)
        self.pending_messages: Dict[str, ReliableMessage] = {} if enable_reliability else {}
        self.message_timeouts: Dict[str, asyncio.Task] = {} if enable_reliability else {}

        # Running state
        self._running = False

        logger.debug(f"RelayManager initialized (reliability: {enable_reliability})")

    async def start(self) -> None:
        """Start the relay manager."""
        if self._running:
            return

        self._running = True
        logger.info("RelayManager started")

    async def stop(self) -> None:
        """Stop the relay manager and cleanup."""
        if not self._running:
            return

        self._running = False

        # Clear all channels and subscribers
        self.channels.clear()
        self.subscribers.clear()

        logger.info("RelayManager stopped")

    def _ensure_channel(self, channel: str) -> None:
        """Ensure channel exists."""
        if channel not in self.channels:
            self.channels[channel] = deque(maxlen=self.max_messages_per_channel)
            self.subscribers[channel] = set()  # Regular set to prevent garbage collection
            self._channel_locks[channel] = asyncio.Lock()

    def _extract_metadata(self, agent_response: Dict[str, Any], publisher: Optional[str]) -> Dict[str, Any]:
        """
        Extract metadata from agent response for context propagation.

        Automatically captures metrics, token usage, confidence scores, and other
        context that downstream agents need to make intelligent decisions.
        """
        metadata = {
            'upstream_agent': publisher or agent_response.get('agent_name'),
            'upstream_agent_id': agent_response.get('agent_id'),
            'timestamp': agent_response.get('timestamp', time.time()),
        }

        # Extract processing metrics if available
        if 'processing_time_ms' in agent_response:
            metadata['processing_time_ms'] = agent_response['processing_time_ms']

        # Extract token usage if available
        if 'token_usage' in agent_response:
            metadata['token_usage'] = agent_response['token_usage']

        # Extract confidence score if available
        if 'confidence_score' in agent_response:
            metadata['confidence_score'] = agent_response['confidence_score']
        elif 'confidence' in agent_response:
            metadata['confidence_score'] = agent_response['confidence']

        # Extract record count if available
        if 'record_count' in agent_response:
            metadata['record_count'] = agent_response['record_count']

        # Extract error information if available
        if 'error_count' in agent_response:
            metadata['error_count'] = agent_response['error_count']

        # Extract correlation ID for distributed tracing
        if 'correlation_id' in agent_response:
            metadata['correlation_id'] = agent_response['correlation_id']
        elif 'context' in agent_response and isinstance(agent_response['context'], dict):
            metadata['correlation_id'] = agent_response['context'].get('correlation_id', str(uuid.uuid4()))
        else:
            metadata['correlation_id'] = str(uuid.uuid4())

        # Extract retry information if available
        if 'retry_info' in agent_response:
            metadata['retry_info'] = agent_response['retry_info']

        # Extract status (success/error)
        metadata['status'] = agent_response.get('status', 'unknown')

        return metadata

    async def publish(
        self,
        channel: str,
        data: Any,
        publisher: Optional[str] = None,
        require_ack: Optional[bool] = None
    ) -> Optional[str]:
        """
        Publish data to a channel with automatic metadata propagation.

        Args:
            channel: Channel name
            data: Data to publish (can be any type - dict, list, str, etc.)
                If dict with 'result' field, extracts result and metadata (backward compat)
                Otherwise, publishes data as-is
            publisher: Optional publisher identifier
            require_ack: Whether this message requires acknowledgment (overrides global setting)

        Returns:
            Message ID if reliability is enabled, None otherwise
        """
        if not self._running:
            await self.start()

        # Handle both new API (raw data) and old API (agent response dict)
        if isinstance(data, dict) and 'result' in data:
            # Old API: Extract result and metadata from agent response
            result_data = data['result']
            metadata = self._extract_metadata(data, publisher)
        else:
            # New API: Use data directly, create minimal metadata
            result_data = data
            metadata = {
                'publisher': publisher,
                'timestamp': time.time(),
                'correlation_id': str(uuid.uuid4()),
                'status': 'success'
            }

        self._ensure_channel(channel)

        # Determine if we need reliability
        needs_reliability = require_ack if require_ack is not None else self.enable_reliability

        if needs_reliability:
            return await self._publish_reliable(channel, result_data, metadata, publisher)
        else:
            return await self._publish_fire_and_forget(channel, result_data, metadata, publisher)

    async def _publish_fire_and_forget(
        self,
        channel: str,
        result_data: Any,
        metadata: Dict[str, Any],
        publisher: Optional[str]
    ) -> None:
        """Publish message without reliability features with metadata propagation."""
        # Use per-channel lock to make publish atomic
        async with self._channel_locks[channel]:
            # Create message with result data and metadata
            message = {
                'data': result_data,
                'metadata': metadata,
                'publisher': publisher,
                'timestamp': time.time()
            }

            # Store result message
            self.channels[channel].append(message)

            # Notify subscribers with result data and metadata (while holding lock)
            await self._notify_subscribers(channel, result_data, metadata)

        logger.debug(f"Published result to channel '{channel}' from {publisher}")
        return None

    async def _publish_reliable(
        self,
        channel: str,
        result_data: Any,
        metadata: Dict[str, Any],
        publisher: Optional[str]
    ) -> str:
        """Publish message with reliability features and metadata propagation."""
        # Create reliable message
        message_id = uuid.uuid4().hex
        reliable_message = ReliableMessage(
            id=message_id,
            channel=channel,
            data=result_data,
            metadata=metadata,
            publisher=publisher,
            ack_timeout=self.default_ack_timeout
        )

        # Store pending message
        self.pending_messages[message_id] = reliable_message

        # Use per-channel lock to make publish atomic
        async with self._channel_locks[channel]:
            # Create message for channel storage with metadata
            message = {
                'id': message_id,
                'data': result_data,
                'metadata': metadata,
                'publisher': publisher,
                'timestamp': time.time(),
                'requires_ack': True
            }

            # Store result message
            self.channels[channel].append(message)

            # Set up timeout task
            timeout_task = asyncio.create_task(
                self._handle_message_timeout(message_id, reliable_message.ack_timeout)
            )
            timeout_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None)
            self.message_timeouts[message_id] = timeout_task

            # Notify subscribers with message ID for acknowledgment (while holding lock)
            await self._notify_subscribers_reliable(channel, result_data, metadata, message_id)

        logger.debug(f"Published reliable message {message_id} to channel '{channel}' from {publisher}")
        return message_id

    async def _notify_subscribers(self, channel: str, result_data: Any, metadata: Dict[str, Any]) -> None:
        """Notify all subscribers of a channel with result data and metadata."""
        if channel not in self.subscribers:
            return

        # Get snapshot of subscribers to avoid modification during iteration
        subscriber_list = list(self.subscribers[channel])

        if not subscriber_list:
            return

        # Notify all subscribers concurrently
        tasks = []
        for subscriber in subscriber_list:
            task = asyncio.create_task(self._call_subscriber(subscriber, result_data, metadata))
            task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None)
            tasks.append(task)

        # Wait for all notifications to complete
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _notify_subscribers_reliable(self, channel: str, result_data: Any, metadata: Dict[str, Any], message_id: str) -> None:
        """Notify all subscribers of a reliable message with metadata."""
        if channel not in self.subscribers:
            return

        # Get snapshot of subscribers to avoid modification during iteration
        subscriber_list = list(self.subscribers[channel])

        if not subscriber_list:
            return

        # Notify all subscribers concurrently with message ID and metadata
        tasks = []
        for subscriber in subscriber_list:
            task = asyncio.create_task(
                self._call_subscriber_reliable(subscriber, result_data, metadata, message_id)
            )
            task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None)
            tasks.append(task)

        # Wait for all notifications to complete
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _call_subscriber_reliable(self, callback: Callable, result_data: Any, metadata: Dict[str, Any], message_id: str) -> None:
        """Safely call a subscriber callback for reliable message with metadata."""
        try:
            if asyncio.iscoroutinefunction(callback):
                # Check callback signature to determine what parameters to pass
                import inspect
                sig = inspect.signature(callback)
                params = list(sig.parameters.keys())

                # Call with appropriate parameters based on signature
                if 'message_id' in params and 'metadata' in params:
                    await callback(result_data, metadata=metadata, message_id=message_id)
                elif 'message_id' in params:
                    await callback(result_data, message_id=message_id)
                elif 'metadata' in params:
                    await callback(result_data, metadata=metadata)
                else:
                    await callback(result_data)
            else:
                # Run sync callback in thread pool
                loop = asyncio.get_event_loop()
                await loop.run_in_executor(None, callback, result_data)
        except Exception as e:
            logger.error(f"Error in reliable subscriber callback: {str(e)}")
            # NACK the message on callback error
            await self.nack_message(message_id, str(e))

    async def _call_subscriber(self, callback: Callable, result_data: Any, metadata: Dict[str, Any]) -> None:
        """Safely call a subscriber callback with result data and metadata."""
        try:
            if asyncio.iscoroutinefunction(callback):
                # Check callback signature to determine what parameters to pass
                import inspect
                sig = inspect.signature(callback)
                params = list(sig.parameters.keys())

                # Call with metadata if callback accepts it
                if 'metadata' in params:
                    await callback(result_data, metadata=metadata)
                else:
                    await callback(result_data)
            else:
                # Run sync callback in thread pool
                loop = asyncio.get_event_loop()
                await loop.run_in_executor(None, callback, result_data)
        except Exception as e:
            error_info = {
                'callback': str(callback),
                'error': str(e),
                'error_type': type(e).__name__,
                'timestamp': time.time(),
                'data_preview': str(result_data)[:100]
            }
            self.subscriber_errors.append(error_info)
            logger.error(f"Subscriber callback failed: {e}")
            # Don't re-raise - we want other subscribers to continue

    async def subscribe(self, channel: str, callback: Callable) -> None:
        """
        Subscribe to a channel.

        Args:
            channel: Channel name
            callback: Callback function to receive result data
        """
        self._ensure_channel(channel)
        self.subscribers[channel].add(callback)

        logger.debug(f"Subscribed to channel '{channel}'")

    def unsubscribe(self, channel: str, callback: Callable) -> bool:
        """
        Unsubscribe from a channel.

        Args:
            channel: Channel name
            callback: Callback to remove

        Returns:
            True if callback was removed
        """
        if channel not in self.subscribers:
            return False

        try:
            self.subscribers[channel].remove(callback)
            logger.debug(f"Unsubscribed from channel '{channel}'")
            return True
        except KeyError:
            return False

    async def get_latest(self, channel: str, count: int = 1) -> List[Any]:
        """
        Get latest result data from a channel.

        Args:
            channel: Channel name
            count: Number of messages to retrieve

        Returns:
            List of result data (newest first)
        """
        if channel not in self.channels:
            return []

        # Get latest messages
        messages = list(self.channels[channel])
        latest = messages[-count:] if count < len(messages) else messages
        latest.reverse()  # Newest first

        # Return just the result data
        return [msg['data'] for msg in latest]

    def list_channels(self) -> List[str]:
        """List all channels."""
        return list(self.channels.keys())

    def clear_channel(self, channel: str) -> bool:
        """
        Clear all messages from a channel.

        Args:
            channel: Channel name

        Returns:
            True if channel was cleared
        """
        if channel not in self.channels:
            return False

        self.channels[channel].clear()

        logger.debug(f"Cleared channel '{channel}'")
        return True

    def get_stats(self) -> Dict[str, Any]:
        """Get relay manager statistics."""
        total_messages = sum(len(ch) for ch in self.channels.values())
        total_subscribers = sum(len(subs) for subs in self.subscribers.values())

        stats = {
            'running': self._running,
            'total_channels': len(self.channels),
            'total_messages': total_messages,
            'total_subscribers': total_subscribers,
            'channels': list(self.channels.keys()),
            'subscriber_errors_count': len(self.subscriber_errors)
        }

        # Add per-channel error breakdown
        if self.subscriber_errors:
            stats['errors_by_channel'] = {}
            for channel in self.channels.keys():
                channel_errors = sum(
                    1 for err in self.subscriber_errors
                    if channel in err.get('callback', '')
                )
                if channel_errors > 0:
                    stats['errors_by_channel'][channel] = channel_errors

        # Add reliability stats if enabled
        if self.enable_reliability:
            stats.update({
                'reliability_enabled': True,
                'pending_messages': len(self.pending_messages),
                'active_timeouts': len(self.message_timeouts)
            })
        else:
            stats['reliability_enabled'] = False

        return stats

    # Reliability methods (acknowledgments and timeouts)

    async def ack_message(self, message_id: str) -> bool:
        """
        Acknowledge a message as successfully processed.

        Args:
            message_id: ID of the message to acknowledge

        Returns:
            True if message was acknowledged, False if not found
        """
        if not self.enable_reliability or message_id not in self.pending_messages:
            return False

        # Update message status
        message = self.pending_messages[message_id]
        message.status = MessageStatus.ACKNOWLEDGED

        # Cancel timeout task
        if message_id in self.message_timeouts:
            timeout_task = self.message_timeouts.pop(message_id)
            if not timeout_task.done():
                timeout_task.cancel()

        # Remove from pending
        del self.pending_messages[message_id]

        logger.debug(f"Acknowledged message {message_id}")
        return True

    async def nack_message(self, message_id: str, error: str) -> bool:
        """
        Negative acknowledge a message (processing failed).

        Args:
            message_id: ID of the message to NACK
            error: Error message

        Returns:
            True if message was NACKed, False if not found
        """
        if not self.enable_reliability or message_id not in self.pending_messages:
            return False

        # Update message status
        message = self.pending_messages[message_id]
        message.status = MessageStatus.FAILED
        message.error = error
        message.attempts += 1

        # Cancel timeout task
        if message_id in self.message_timeouts:
            timeout_task = self.message_timeouts.pop(message_id)
            if not timeout_task.done():
                timeout_task.cancel()

        # Check if we should retry
        if message.attempts < message.max_attempts:
            logger.warning(f"Message {message_id} failed (attempt {message.attempts}/{message.max_attempts}): {error}")
            await self._schedule_retry(message)
        else:
            logger.error(f"Message {message_id} failed permanently after {message.attempts} attempts: {error}")
            # Remove from pending messages
            del self.pending_messages[message_id]

        return True

    async def _handle_message_timeout(self, message_id: str, timeout_duration: float) -> None:
        """Handle message acknowledgment timeout."""
        try:
            await asyncio.sleep(timeout_duration)

            # Check if message is still pending
            if message_id in self.pending_messages:
                message = self.pending_messages[message_id]
                message.status = MessageStatus.TIMEOUT
                message.attempts += 1

                logger.warning(f"Message {message_id} timed out after {timeout_duration}s")

                # Check if we should retry
                if message.attempts < message.max_attempts:
                    await self._schedule_retry(message)
                else:
                    logger.error(f"Message {message_id} timed out permanently after {message.attempts} attempts")
                    # Remove from pending messages
                    del self.pending_messages[message_id]

            # Remove timeout task
            self.message_timeouts.pop(message_id, None)

        except asyncio.CancelledError:
            # Timeout was cancelled (message was acknowledged)
            pass
        except Exception as e:
            logger.error(f"Error handling timeout for message {message_id}: {e}")

    def _calculate_retry_delay(self, attempt: int, base_delay: float = 1.0) -> float:
        """Calculate retry delay with exponential backoff and jitter."""
        # Exponential backoff: 1s, 2s, 4s, 8s, ...
        delay = base_delay * (2 ** attempt)

        # Cap at 60 seconds
        delay = min(delay, 60.0)

        # Add ±50% jitter to prevent thundering herd
        jitter_factor = 0.5 + random.random()  # Random between 0.5 and 1.5
        delay *= jitter_factor

        return delay

    async def _schedule_retry(self, message: ReliableMessage) -> None:
        """Schedule a retry for a failed or timed out message."""
        # Calculate exponential backoff delay with jitter
        delay = self._calculate_retry_delay(message.attempts - 1)

        logger.debug(f"Scheduling retry for message {message.id} in {delay:.1f}s (attempt {message.attempts + 1})")

        # Reset message status for retry
        message.status = MessageStatus.PENDING
        message.timestamp = time.time()  # Update timestamp for retry

        # Schedule the retry task
        retry_task = asyncio.create_task(self._execute_retry(message, delay))
        retry_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None)
        # Use a different key to avoid conflicts with timeout tasks
        self.message_timeouts[f"{message.id}_retry"] = retry_task

    async def _execute_retry(self, message: ReliableMessage, delay: float) -> None:
        """Execute a message retry after the specified delay."""
        try:
            await asyncio.sleep(delay)

            # Check if message is still in pending (not cancelled)
            if message.id in self.pending_messages:
                logger.debug(f"Retrying message {message.id} (attempt {message.attempts + 1})")

                # Republish the message by re-adding it to the channel
                self._ensure_channel(message.channel)

                # Create new message entry in channel
                retry_message = {
                    'id': message.id,
                    'data': message.data,
                    'metadata': message.metadata,
                    'publisher': message.publisher,
                    'timestamp': message.timestamp,
                    'requires_ack': True
                }

                # Add to channel
                self.channels[message.channel].append(retry_message)

                # Set up new timeout task for the retry
                timeout_task = asyncio.create_task(
                    self._handle_message_timeout(message.id, message.ack_timeout)
                )
                timeout_task.add_done_callback(lambda t: t.exception() if not t.cancelled() else None)
                self.message_timeouts[message.id] = timeout_task

                # Notify subscribers again
                await self._notify_subscribers_reliable(message.channel, message.data, message.metadata, message.id)

            # Clean up retry task reference
            self.message_timeouts.pop(f"{message.id}_retry", None)

        except asyncio.CancelledError:
            # Retry was cancelled
            pass
        except Exception as e:
            logger.error(f"Error executing retry for message {message.id}: {e}")

    def get_pending_messages(self) -> List[Dict[str, Any]]:
        """Get list of pending messages waiting for acknowledgment."""
        if not self.enable_reliability:
            return []

        return [
            {
                'id': msg.id,
                'channel': msg.channel,
                'publisher': msg.publisher,
                'status': msg.status.value,
                'timestamp': msg.timestamp,
                'attempts': msg.attempts,
                'max_attempts': msg.max_attempts,
                'error': msg.error,
                'age': time.time() - msg.timestamp
            }
            for msg in self.pending_messages.values()
        ]

    def get_subscriber_errors(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Get recent subscriber errors for debugging."""
        errors = list(self.subscriber_errors)
        return errors[-limit:] if limit < len(errors) else errors

    # Context manager support
    async def __aenter__(self) -> "RelayManager":
        """Async context manager entry."""
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Async context manager exit."""
        await self.stop()

# Global relay manager instance
_global_relay = None

def get_global_relay() -> RelayManager:
    """Get the global relay manager instance."""
    global _global_relay
    if _global_relay is None:
        _global_relay = RelayManager()
    return _global_relay

# Convenience functions
async def publish(channel: str, agent_response: Dict[str, Any], publisher: Optional[str] = None) -> None:
    """Publish using the global relay manager."""
    relay = get_global_relay()
    await relay.publish(channel, agent_response, publisher)

async def subscribe(channel: str, callback: Callable) -> None:
    """Subscribe using the global relay manager."""
    relay = get_global_relay()
    await relay.subscribe(channel, callback)

async def get_latest(channel: str, count: int = 1) -> List[Any]:
    """Get latest result data using the global relay manager."""
    relay = get_global_relay()
    return await relay.get_latest(channel, count)
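
The module docstring above only demonstrates the fire-and-forget path. As a minimal sketch of the acknowledgment path in the same file (assuming the wheel is installed and `daita.core.relay` imports cleanly; the channel and publisher names are illustrative, not part of the package):

```python
import asyncio

from daita.core.relay import RelayManager


async def main() -> None:
    # enable_reliability=True makes publish() return a message ID and track it
    # until a subscriber calls ack_message() (or the default 30s timeout fires).
    async with RelayManager(enable_reliability=True) as relay:

        async def handle_data(result_data, metadata=None, message_id=None):
            # Because this coroutine declares 'metadata' and 'message_id',
            # _call_subscriber_reliable passes both along with the result payload.
            print(f"Got {result_data} from {metadata.get('upstream_agent')}")
            await relay.ack_message(message_id)

        await relay.subscribe("data_channel", handle_data)

        # Old-style agent response dict: only the 'result' value is forwarded;
        # the remaining fields are folded into metadata for downstream context.
        message_id = await relay.publish(
            "data_channel",
            {"status": "success", "result": {"processed": True}, "agent_id": "agent_123"},
            publisher="upstream_agent_example",  # illustrative name
            require_ack=True,
        )

        print(f"Published {message_id}")
        print(relay.get_pending_messages())  # empty once the subscriber has acked
        print(relay.get_stats())


asyncio.run(main())
```

If the subscriber raised instead of acking, `_call_subscriber_reliable` would NACK the message and the retry/backoff logic above would republish it up to `max_attempts` times.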