daita-agents 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- daita/__init__.py +216 -0
- daita/agents/__init__.py +33 -0
- daita/agents/base.py +743 -0
- daita/agents/substrate.py +1141 -0
- daita/cli/__init__.py +145 -0
- daita/cli/__main__.py +7 -0
- daita/cli/ascii_art.py +44 -0
- daita/cli/core/__init__.py +0 -0
- daita/cli/core/create.py +254 -0
- daita/cli/core/deploy.py +473 -0
- daita/cli/core/deployments.py +309 -0
- daita/cli/core/import_detector.py +219 -0
- daita/cli/core/init.py +481 -0
- daita/cli/core/logs.py +239 -0
- daita/cli/core/managed_deploy.py +709 -0
- daita/cli/core/run.py +648 -0
- daita/cli/core/status.py +421 -0
- daita/cli/core/test.py +239 -0
- daita/cli/core/webhooks.py +172 -0
- daita/cli/main.py +588 -0
- daita/cli/utils.py +541 -0
- daita/config/__init__.py +62 -0
- daita/config/base.py +159 -0
- daita/config/settings.py +184 -0
- daita/core/__init__.py +262 -0
- daita/core/decision_tracing.py +701 -0
- daita/core/exceptions.py +480 -0
- daita/core/focus.py +251 -0
- daita/core/interfaces.py +76 -0
- daita/core/plugin_tracing.py +550 -0
- daita/core/relay.py +779 -0
- daita/core/reliability.py +381 -0
- daita/core/scaling.py +459 -0
- daita/core/tools.py +554 -0
- daita/core/tracing.py +770 -0
- daita/core/workflow.py +1144 -0
- daita/display/__init__.py +1 -0
- daita/display/console.py +160 -0
- daita/execution/__init__.py +58 -0
- daita/execution/client.py +856 -0
- daita/execution/exceptions.py +92 -0
- daita/execution/models.py +317 -0
- daita/llm/__init__.py +60 -0
- daita/llm/anthropic.py +291 -0
- daita/llm/base.py +530 -0
- daita/llm/factory.py +101 -0
- daita/llm/gemini.py +355 -0
- daita/llm/grok.py +219 -0
- daita/llm/mock.py +172 -0
- daita/llm/openai.py +220 -0
- daita/plugins/__init__.py +141 -0
- daita/plugins/base.py +37 -0
- daita/plugins/base_db.py +167 -0
- daita/plugins/elasticsearch.py +849 -0
- daita/plugins/mcp.py +481 -0
- daita/plugins/mongodb.py +520 -0
- daita/plugins/mysql.py +362 -0
- daita/plugins/postgresql.py +342 -0
- daita/plugins/redis_messaging.py +500 -0
- daita/plugins/rest.py +537 -0
- daita/plugins/s3.py +770 -0
- daita/plugins/slack.py +729 -0
- daita/utils/__init__.py +18 -0
- daita_agents-0.2.0.dist-info/METADATA +409 -0
- daita_agents-0.2.0.dist-info/RECORD +69 -0
- daita_agents-0.2.0.dist-info/WHEEL +5 -0
- daita_agents-0.2.0.dist-info/entry_points.txt +2 -0
- daita_agents-0.2.0.dist-info/licenses/LICENSE +56 -0
- daita_agents-0.2.0.dist-info/top_level.txt +1 -0
daita/core/workflow.py
ADDED
|
@@ -0,0 +1,1144 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Simplified Workflow System for Daita Agents.
|
|
3
|
+
|
|
4
|
+
Provides orchestration of agents as connected systems with automatic tracing.
|
|
5
|
+
All workflow communication is automatically traced through the unified tracing system.
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
```python
|
|
9
|
+
from daita.core.workflow import Workflow
|
|
10
|
+
|
|
11
|
+
# Create agents
|
|
12
|
+
fetcher = sdk.substrate_agent(name="Data Fetcher")
|
|
13
|
+
analyzer = sdk.analysis_agent(name="Analyzer")
|
|
14
|
+
|
|
15
|
+
# Create workflow
|
|
16
|
+
workflow = Workflow("Data Pipeline")
|
|
17
|
+
workflow.add_agent("fetcher", fetcher)
|
|
18
|
+
workflow.add_agent("analyzer", analyzer)
|
|
19
|
+
|
|
20
|
+
# Connect agents via relay channels
|
|
21
|
+
workflow.connect("fetcher", "raw_data", "analyzer")
|
|
22
|
+
|
|
23
|
+
# Start workflow
|
|
24
|
+
await workflow.start()
|
|
25
|
+
|
|
26
|
+
# View recent communication in unified dashboard
|
|
27
|
+
# All workflow communication is automatically traced
|
|
28
|
+
```
|
|
29
|
+
"""
|
|
30
|
+
import asyncio
|
|
31
|
+
import logging
|
|
32
|
+
import time
|
|
33
|
+
from typing import Dict, Any, Optional, List, Tuple, Set
|
|
34
|
+
from dataclasses import dataclass
|
|
35
|
+
from enum import Enum
|
|
36
|
+
from datetime import datetime
|
|
37
|
+
|
|
38
|
+
from ..core.exceptions import DaitaError, WorkflowError, BackpressureError
|
|
39
|
+
from ..core.relay import RelayManager, get_global_relay
|
|
40
|
+
from ..core.tracing import get_trace_manager, TraceType, TraceStatus
|
|
41
|
+
|
|
42
|
+
logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
class WorkflowStatus(str, Enum):
    """Lifecycle states a workflow moves through.

    Derives from ``str`` so members compare equal to their plain string
    values (e.g. ``WorkflowStatus.RUNNING == "running"``).
    """

    CREATED = "created"    # built, never started
    STARTING = "starting"  # start() in progress
    RUNNING = "running"    # agents active, relay subscriptions live
    STOPPING = "stopping"  # stop() in progress
    STOPPED = "stopped"    # shut down cleanly
    ERROR = "error"        # start()/stop() failed; see Workflow.error
|
|
52
|
+
|
|
53
|
+
@dataclass
class ReliabilityConfig:
    """Feature toggles for workflow reliability (everything on by default)."""

    # Require explicit ACK/NACK of relay messages.
    acknowledgments: bool = True
    # Enable task lifecycle tracking on agents that support it.
    task_tracking: bool = True
    # Apply per-agent backpressure (bounded concurrency and queues).
    backpressure_control: bool = True
|
|
59
|
+
|
|
60
|
+
@dataclass
class Connection:
    """A directed link: messages on ``channel`` flow from one agent to another."""

    from_agent: str  # source agent (or pool) name
    channel: str     # relay channel carrying the messages
    to_agent: str    # destination agent (or pool) name
    task: str = "relay_message"  # task invoked on the destination

    def __str__(self):
        # Render as "source -> channel -> destination" for log lines.
        hops = (self.from_agent, self.channel, self.to_agent)
        return " -> ".join(hops)
|
|
70
|
+
|
|
71
|
+
class Workflow:
|
|
72
|
+
"""
|
|
73
|
+
A workflow manages a collection of agents and their connections.
|
|
74
|
+
|
|
75
|
+
All workflow communication is automatically traced through the unified
|
|
76
|
+
tracing system without any configuration required.
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
def __init__(
|
|
80
|
+
self,
|
|
81
|
+
name: str,
|
|
82
|
+
project_id: Optional[str] = None,
|
|
83
|
+
relay_manager: Optional[RelayManager] = None
|
|
84
|
+
):
|
|
85
|
+
"""
|
|
86
|
+
Initialize a workflow.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
name: Workflow name
|
|
90
|
+
project_id: Optional project ID this workflow belongs to
|
|
91
|
+
relay_manager: Relay manager for agent communication
|
|
92
|
+
"""
|
|
93
|
+
self.name = name
|
|
94
|
+
self.project_id = project_id
|
|
95
|
+
self.relay_manager = relay_manager or get_global_relay()
|
|
96
|
+
|
|
97
|
+
# Agent storage: agent_name -> agent_instance
|
|
98
|
+
self.agents: Dict[str, Any] = {}
|
|
99
|
+
|
|
100
|
+
# Agent pools: pool_name -> AgentPool instance (for horizontal scaling)
|
|
101
|
+
self.agent_pools: Dict[str, Any] = {}
|
|
102
|
+
|
|
103
|
+
# Connections: list of Connection objects
|
|
104
|
+
self.connections: List[Connection] = []
|
|
105
|
+
|
|
106
|
+
# Relay channels used by this workflow
|
|
107
|
+
self.channels: Set[str] = set()
|
|
108
|
+
|
|
109
|
+
# Reliability configuration
|
|
110
|
+
self.reliability_config: Optional[ReliabilityConfig] = None
|
|
111
|
+
self._reliability_enabled = False
|
|
112
|
+
|
|
113
|
+
# Workflow state
|
|
114
|
+
self.status = WorkflowStatus.CREATED
|
|
115
|
+
self.created_at = time.time()
|
|
116
|
+
self.started_at: Optional[float] = None
|
|
117
|
+
self.stopped_at: Optional[float] = None
|
|
118
|
+
self.error: Optional[str] = None
|
|
119
|
+
|
|
120
|
+
# Subscription tracking for cleanup
|
|
121
|
+
self._subscriptions: List[Tuple[str, Any]] = []
|
|
122
|
+
|
|
123
|
+
# Message deduplication (only for reliable mode)
|
|
124
|
+
self._processed_messages: Set[str] = set()
|
|
125
|
+
self._dedup_cleanup_task: Optional[asyncio.Task] = None
|
|
126
|
+
self._dedup_max_size = 10000 # Prevent unbounded growth
|
|
127
|
+
|
|
128
|
+
# Get trace manager for automatic workflow communication tracing
|
|
129
|
+
self.trace_manager = get_trace_manager()
|
|
130
|
+
|
|
131
|
+
logger.debug(f"Created workflow '{name}' with automatic tracing")
|
|
132
|
+
|
|
133
|
+
def add_agent(self, name: str, agent: Any) -> "Workflow":
|
|
134
|
+
"""
|
|
135
|
+
Add an agent to the workflow.
|
|
136
|
+
|
|
137
|
+
Args:
|
|
138
|
+
name: Agent name for workflow reference
|
|
139
|
+
agent: Agent instance
|
|
140
|
+
|
|
141
|
+
Returns:
|
|
142
|
+
Self for method chaining
|
|
143
|
+
"""
|
|
144
|
+
if name in self.agents:
|
|
145
|
+
raise WorkflowError(f"Agent '{name}' already exists in workflow")
|
|
146
|
+
|
|
147
|
+
self.agents[name] = agent
|
|
148
|
+
logger.debug(f"Added agent '{name}' to workflow '{self.name}'")
|
|
149
|
+
return self
|
|
150
|
+
|
|
151
|
+
def add_agent_pool(
|
|
152
|
+
self,
|
|
153
|
+
name: str,
|
|
154
|
+
agent_factory: Any,
|
|
155
|
+
instances: int = 1
|
|
156
|
+
) -> "Workflow":
|
|
157
|
+
"""
|
|
158
|
+
Add an agent pool to the workflow for horizontal scaling.
|
|
159
|
+
|
|
160
|
+
Args:
|
|
161
|
+
name: Pool name for workflow reference
|
|
162
|
+
agent_factory: Factory function to create agent instances
|
|
163
|
+
instances: Number of agent instances in the pool
|
|
164
|
+
|
|
165
|
+
Returns:
|
|
166
|
+
Self for method chaining
|
|
167
|
+
|
|
168
|
+
Example:
|
|
169
|
+
```python
|
|
170
|
+
def create_processor():
|
|
171
|
+
return sdk.substrate_agent(name="Processor")
|
|
172
|
+
|
|
173
|
+
workflow.add_agent_pool("processors", create_processor, instances=5)
|
|
174
|
+
```
|
|
175
|
+
"""
|
|
176
|
+
if name in self.agent_pools:
|
|
177
|
+
raise WorkflowError(f"Agent pool '{name}' already exists in workflow")
|
|
178
|
+
|
|
179
|
+
if name in self.agents:
|
|
180
|
+
raise WorkflowError(f"Name '{name}' already used by an agent in workflow")
|
|
181
|
+
|
|
182
|
+
# Import AgentPool here to avoid circular imports
|
|
183
|
+
from ..core.scaling import AgentPool
|
|
184
|
+
|
|
185
|
+
# Create agent pool
|
|
186
|
+
pool = AgentPool(
|
|
187
|
+
agent_factory=agent_factory,
|
|
188
|
+
instances=instances,
|
|
189
|
+
pool_name=f"{self.name}_{name}"
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
self.agent_pools[name] = pool
|
|
193
|
+
logger.debug(f"Added agent pool '{name}' with {instances} instances to workflow '{self.name}'")
|
|
194
|
+
return self
|
|
195
|
+
|
|
196
|
+
def remove_agent(self, name: str) -> bool:
|
|
197
|
+
"""
|
|
198
|
+
Remove agent and clean up its connections.
|
|
199
|
+
|
|
200
|
+
Args:
|
|
201
|
+
name: Agent name to remove
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
True if agent was removed, False if not found
|
|
205
|
+
"""
|
|
206
|
+
if name not in self.agents:
|
|
207
|
+
return False
|
|
208
|
+
|
|
209
|
+
# Remove agent
|
|
210
|
+
del self.agents[name]
|
|
211
|
+
|
|
212
|
+
# Clean up connections involving this agent
|
|
213
|
+
self.connections = [
|
|
214
|
+
c for c in self.connections
|
|
215
|
+
if c.from_agent != name and c.to_agent != name
|
|
216
|
+
]
|
|
217
|
+
|
|
218
|
+
# Note: Subscriptions will be cleaned up in _cleanup_connections when workflow stops
|
|
219
|
+
|
|
220
|
+
logger.debug(f"Removed agent '{name}' and cleaned up connections")
|
|
221
|
+
return True
|
|
222
|
+
|
|
223
|
+
def connect(self, from_agent: str, channel: str, to_agent: str, task: str = "relay_message") -> "Workflow":
|
|
224
|
+
"""
|
|
225
|
+
Connect two agents via a relay channel.
|
|
226
|
+
|
|
227
|
+
Args:
|
|
228
|
+
from_agent: Source agent name
|
|
229
|
+
channel: Relay channel name
|
|
230
|
+
to_agent: Destination agent name
|
|
231
|
+
task: Task to execute on destination agent
|
|
232
|
+
|
|
233
|
+
Returns:
|
|
234
|
+
Self for method chaining
|
|
235
|
+
"""
|
|
236
|
+
# Validate agents exist
|
|
237
|
+
if from_agent not in self.agents:
|
|
238
|
+
raise WorkflowError(f"Source agent '{from_agent}' not found")
|
|
239
|
+
if to_agent not in self.agents:
|
|
240
|
+
raise WorkflowError(f"Destination agent '{to_agent}' not found")
|
|
241
|
+
|
|
242
|
+
# Check if connection already exists
|
|
243
|
+
existing = next(
|
|
244
|
+
(c for c in self.connections if c.from_agent == from_agent
|
|
245
|
+
and c.channel == channel and c.to_agent == to_agent),
|
|
246
|
+
None
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
if existing:
|
|
250
|
+
logger.warning(f"Connection already exists: {existing}")
|
|
251
|
+
return self
|
|
252
|
+
|
|
253
|
+
connection = Connection(from_agent, channel, to_agent, task)
|
|
254
|
+
self.connections.append(connection)
|
|
255
|
+
self.channels.add(channel)
|
|
256
|
+
|
|
257
|
+
logger.debug(f"Connected {from_agent} -> {channel} -> {to_agent}")
|
|
258
|
+
return self
|
|
259
|
+
|
|
260
|
+
def configure_reliability(
|
|
261
|
+
self,
|
|
262
|
+
preset: Optional[str] = None,
|
|
263
|
+
acknowledgments: Optional[bool] = None,
|
|
264
|
+
task_tracking: Optional[bool] = None,
|
|
265
|
+
backpressure_control: Optional[bool] = None
|
|
266
|
+
) -> "Workflow":
|
|
267
|
+
"""
|
|
268
|
+
Configure reliability features for this workflow.
|
|
269
|
+
|
|
270
|
+
Args:
|
|
271
|
+
preset: Predefined configuration preset ("basic", "production", "enterprise")
|
|
272
|
+
acknowledgments: Enable message acknowledgments
|
|
273
|
+
task_tracking: Enable task lifecycle tracking
|
|
274
|
+
backpressure_control: Enable backpressure control
|
|
275
|
+
|
|
276
|
+
Returns:
|
|
277
|
+
Self for method chaining
|
|
278
|
+
"""
|
|
279
|
+
# Handle presets
|
|
280
|
+
if preset == "basic":
|
|
281
|
+
config = ReliabilityConfig(
|
|
282
|
+
acknowledgments=True,
|
|
283
|
+
task_tracking=True,
|
|
284
|
+
backpressure_control=True
|
|
285
|
+
)
|
|
286
|
+
elif preset == "production":
|
|
287
|
+
config = ReliabilityConfig(
|
|
288
|
+
acknowledgments=True,
|
|
289
|
+
task_tracking=True,
|
|
290
|
+
backpressure_control=True
|
|
291
|
+
)
|
|
292
|
+
elif preset == "enterprise":
|
|
293
|
+
config = ReliabilityConfig(
|
|
294
|
+
acknowledgments=True,
|
|
295
|
+
task_tracking=True,
|
|
296
|
+
backpressure_control=True
|
|
297
|
+
)
|
|
298
|
+
else:
|
|
299
|
+
# Default configuration or use provided values
|
|
300
|
+
config = ReliabilityConfig(
|
|
301
|
+
acknowledgments=acknowledgments if acknowledgments is not None else True,
|
|
302
|
+
task_tracking=task_tracking if task_tracking is not None else True,
|
|
303
|
+
backpressure_control=backpressure_control if backpressure_control is not None else True
|
|
304
|
+
)
|
|
305
|
+
|
|
306
|
+
# Override individual settings if provided
|
|
307
|
+
if acknowledgments is not None:
|
|
308
|
+
config.acknowledgments = acknowledgments
|
|
309
|
+
if task_tracking is not None:
|
|
310
|
+
config.task_tracking = task_tracking
|
|
311
|
+
if backpressure_control is not None:
|
|
312
|
+
config.backpressure_control = backpressure_control
|
|
313
|
+
|
|
314
|
+
self.reliability_config = config
|
|
315
|
+
self._reliability_enabled = True
|
|
316
|
+
|
|
317
|
+
# Enable reliability in relay manager
|
|
318
|
+
self.relay_manager.enable_reliability = True
|
|
319
|
+
|
|
320
|
+
logger.info(f"Configured reliability for workflow '{self.name}': {config}")
|
|
321
|
+
return self
|
|
322
|
+
|
|
323
|
+
def validate_connections(self) -> List[str]:
|
|
324
|
+
"""
|
|
325
|
+
Validate all workflow connections.
|
|
326
|
+
|
|
327
|
+
Returns:
|
|
328
|
+
List of validation error messages (empty if all valid)
|
|
329
|
+
"""
|
|
330
|
+
errors = []
|
|
331
|
+
|
|
332
|
+
for conn in self.connections:
|
|
333
|
+
# Check from_agent exists
|
|
334
|
+
if conn.from_agent not in self.agents and conn.from_agent not in self.agent_pools:
|
|
335
|
+
errors.append(f"Source '{conn.from_agent}' not found in workflow")
|
|
336
|
+
|
|
337
|
+
# Check to_agent exists
|
|
338
|
+
if conn.to_agent not in self.agents and conn.to_agent not in self.agent_pools:
|
|
339
|
+
errors.append(f"Destination '{conn.to_agent}' not found in workflow")
|
|
340
|
+
|
|
341
|
+
# Check for circular dependencies (self-loops)
|
|
342
|
+
if conn.from_agent == conn.to_agent:
|
|
343
|
+
errors.append(f"Circular dependency: {conn.from_agent} -> {conn.to_agent}")
|
|
344
|
+
|
|
345
|
+
return errors
|
|
346
|
+
|
|
347
|
+
    async def start(self) -> None:
        """Start the workflow and all agents with automatic tracing.

        Sequence: validate connections, emit a lifecycle trace, ensure the
        relay manager is running, apply reliability configuration, start
        agents then pools, subscribe relay callbacks, start the dedup
        cleanup task (reliable mode only), then flip status to RUNNING.
        On any failure the status becomes ERROR, the message is stored in
        ``self.error``, a "workflow_error" event is traced, and the
        exception is re-raised.

        Raises:
            WorkflowError: On invalid connections, or when the relay
                manager, an agent, or an agent pool fails to start.
        """
        # Idempotence guard: starting twice is a logged no-op.
        if self.status in [WorkflowStatus.RUNNING, WorkflowStatus.STARTING]:
            logger.warning(f"Workflow '{self.name}' is already running")
            return

        try:
            self.status = WorkflowStatus.STARTING
            logger.info(f"Starting workflow '{self.name}'")

            # Validate connections before starting; fail fast on bad topology.
            validation_errors = self.validate_connections()
            if validation_errors:
                raise WorkflowError(f"Invalid connections: {'; '.join(validation_errors)}")

            # Trace workflow lifecycle start
            await self._trace_workflow_event("workflow_started", {
                "workflow_name": self.name,
                "agent_count": len(self.agents),
                "agent_pool_count": len(self.agent_pools),
                "connection_count": len(self.connections)
            })

            # Ensure relay manager is running (it may be shared/global, so
            # it could already be started by another workflow).
            if not self.relay_manager._running:
                try:
                    await self.relay_manager.start()
                except Exception as e:
                    raise WorkflowError(f"Failed to start relay manager: {str(e)}")

            # Configure agents with reliability features if enabled
            if self._reliability_enabled and self.reliability_config:
                await self._configure_agents_reliability()

            # Start all agents; any single failure aborts startup.
            for agent_name, agent in self.agents.items():
                try:
                    # Agents without a start() hook are tolerated silently.
                    if hasattr(agent, 'start'):
                        await agent.start()
                        logger.debug(f"Started agent '{agent_name}'")
                except Exception as e:
                    logger.error(f"Failed to start agent '{agent_name}': {str(e)}")
                    raise WorkflowError(f"Failed to start agent '{agent_name}': {str(e)}")

            # Start all agent pools
            for pool_name, pool in self.agent_pools.items():
                try:
                    await pool.start()
                    logger.debug(f"Started agent pool '{pool_name}' with {pool.instance_count} instances")
                except Exception as e:
                    logger.error(f"Failed to start agent pool '{pool_name}': {str(e)}")
                    raise WorkflowError(f"Failed to start agent pool '{pool_name}': {str(e)}")

            # Set up relay connections with automatic tracing. Done after
            # agents start so deliveries never hit a not-yet-started agent.
            await self._setup_connections()

            # Start dedup cleanup task if reliability enabled
            if self._reliability_enabled:
                self._dedup_cleanup_task = asyncio.create_task(self._cleanup_dedup_cache())

            # Update status
            self.status = WorkflowStatus.RUNNING
            self.started_at = time.time()

            logger.info(f"Workflow '{self.name}' started successfully")

        except Exception as e:
            # NOTE(review): agents started before the failure are not stopped
            # here; callers are expected to invoke stop() — confirm.
            self.status = WorkflowStatus.ERROR
            self.error = str(e)
            logger.error(f"Failed to start workflow '{self.name}': {str(e)}")

            # Trace workflow error
            await self._trace_workflow_event("workflow_error", {
                "workflow_name": self.name,
                "error": str(e)
            })
            raise
|
|
424
|
+
|
|
425
|
+
    async def stop(self) -> None:
        """Stop the workflow by stopping all agents and cleaning up connections.

        Sequence: cancel the dedup cleanup task, unsubscribe relay
        callbacks, stop agents then pools (failures are logged, not
        raised), mark STOPPED, and trace a "workflow_stopped" event with
        the total running time.

        Raises:
            Exception: Re-raises unexpected errors from the shutdown
                sequence itself, after setting status to ERROR.
        """
        # Idempotence guard: stopping twice is a logged no-op.
        if self.status in [WorkflowStatus.STOPPED, WorkflowStatus.STOPPING]:
            logger.warning(f"Workflow '{self.name}' is already stopped")
            return

        try:
            self.status = WorkflowStatus.STOPPING
            logger.info(f"Stopping workflow '{self.name}'")

            # Stop dedup cleanup task; await it so cancellation completes
            # before we tear anything else down.
            if self._dedup_cleanup_task:
                self._dedup_cleanup_task.cancel()
                try:
                    await self._dedup_cleanup_task
                except asyncio.CancelledError:
                    pass
                self._dedup_cleanup_task = None

            # Clean up relay subscriptions first so no new messages are
            # delivered to agents while they shut down.
            await self._cleanup_connections()

            # Stop all agents (best-effort: individual failures only warn).
            for agent_name, agent in self.agents.items():
                try:
                    if hasattr(agent, 'stop'):
                        await agent.stop()
                        logger.debug(f"Stopped agent '{agent_name}'")
                except Exception as e:
                    logger.warning(f"Error stopping agent '{agent_name}': {str(e)}")

            # Stop all agent pools
            for pool_name, pool in self.agent_pools.items():
                try:
                    await pool.stop()
                    logger.debug(f"Stopped agent pool '{pool_name}'")
                except Exception as e:
                    logger.warning(f"Error stopping agent pool '{pool_name}': {str(e)}")

            self.status = WorkflowStatus.STOPPED
            self.stopped_at = time.time()

            # Trace workflow lifecycle stop; if started_at was never set
            # (stop without start), report zero running time.
            await self._trace_workflow_event("workflow_stopped", {
                "workflow_name": self.name,
                "running_time_seconds": self.stopped_at - (self.started_at or self.stopped_at)
            })

            logger.info(f"Workflow '{self.name}' stopped")

        except Exception as e:
            self.status = WorkflowStatus.ERROR
            self.error = str(e)
            logger.error(f"Error stopping workflow '{self.name}': {str(e)}")
            raise
|
|
480
|
+
|
|
481
|
+
async def _setup_connections(self) -> None:
|
|
482
|
+
"""Set up relay connections between agents with automatic tracing."""
|
|
483
|
+
for connection in self.connections:
|
|
484
|
+
try:
|
|
485
|
+
# Create callback based on reliability configuration
|
|
486
|
+
if self._reliability_enabled and self.reliability_config:
|
|
487
|
+
callback = self._create_reliable_callback(connection)
|
|
488
|
+
else:
|
|
489
|
+
callback = self._create_traced_callback(connection)
|
|
490
|
+
|
|
491
|
+
# Subscribe to the relay channel
|
|
492
|
+
await self.relay_manager.subscribe(connection.channel, callback)
|
|
493
|
+
self._subscriptions.append((connection.channel, callback))
|
|
494
|
+
|
|
495
|
+
logger.debug(f"Set up connection: {connection} (reliability: {self._reliability_enabled})")
|
|
496
|
+
|
|
497
|
+
except Exception as e:
|
|
498
|
+
logger.error(f"Failed to set up connection {connection}: {str(e)}")
|
|
499
|
+
raise WorkflowError(f"Failed to set up connection {connection}: {str(e)}")
|
|
500
|
+
|
|
501
|
+
    def _create_traced_callback(self, connection: Connection):
        """Create a callback that automatically traces workflow communication and propagates metadata.

        Returns an async callback bound to ``connection`` via closure; the
        relay manager invokes it for each message on the connection's channel.
        Exceptions are traced and logged but never re-raised (best-effort,
        fire-and-forget delivery).
        """
        async def traced_callback(data: Any, metadata: Optional[Dict[str, Any]] = None):
            """Callback that processes relay data with automatic metadata propagation."""
            try:
                # NOTE(review): this success=True trace is emitted BEFORE
                # delivery, so it records receipt rather than successful
                # processing — confirm that is intended.
                await self._trace_workflow_communication(
                    from_agent=connection.from_agent,
                    to_agent=connection.to_agent,
                    channel=connection.channel,
                    data=data,
                    success=True
                )

                # Build enriched context with metadata propagation
                enriched_context = {
                    'source_agent': connection.from_agent,
                    'channel': connection.channel,
                    'workflow': self.name
                }

                # Add all upstream metadata to context; upstream keys can
                # overwrite the three defaults above.
                if metadata:
                    enriched_context.update(metadata)

                # Process the data with the destination agent or agent pool
                if connection.to_agent in self.agents:
                    # Single agent
                    dest_agent = self.agents[connection.to_agent]

                    # Try new API first
                    if hasattr(dest_agent, 'receive_message'):
                        await dest_agent.receive_message(
                            data=data,
                            source_agent=connection.from_agent,
                            channel=connection.channel,
                            workflow_name=self.name
                        )
                    # Fallback to old API for legacy agents
                    elif hasattr(dest_agent, '_process'):
                        await dest_agent._process(connection.task, data, enriched_context)
                    else:
                        # Neither API available: log and drop the message.
                        logger.warning(
                            f"Agent '{connection.to_agent}' has no receive_message() or _process() method"
                        )

                elif connection.to_agent in self.agent_pools:
                    # Agent pool - submit task to pool with enriched context
                    pool = self.agent_pools[connection.to_agent]
                    await pool.submit_task(connection.task, data, enriched_context)

                else:
                    # Destination vanished (e.g. removed after connect()).
                    logger.error(f"Destination '{connection.to_agent}' not found in agents or agent pools")

            except Exception as e:
                # Trace communication error, then swallow it: a failed
                # delivery must not propagate into the relay manager.
                await self._trace_workflow_communication(
                    from_agent=connection.from_agent,
                    to_agent=connection.to_agent,
                    channel=connection.channel,
                    data=data,
                    success=False,
                    error_message=str(e)
                )
                logger.error(f"Error in workflow communication {connection}: {str(e)}")

        return traced_callback
|
|
568
|
+
|
|
569
|
+
async def _configure_agents_reliability(self) -> None:
|
|
570
|
+
"""Configure agents with reliability features."""
|
|
571
|
+
if not self.reliability_config:
|
|
572
|
+
return
|
|
573
|
+
|
|
574
|
+
for agent_name, agent in self.agents.items():
|
|
575
|
+
try:
|
|
576
|
+
# Enable reliability features if agent supports it
|
|
577
|
+
if hasattr(agent, 'enable_reliability_features') and self.reliability_config.task_tracking:
|
|
578
|
+
agent.enable_reliability_features()
|
|
579
|
+
logger.debug(f"Enabled reliability features for agent '{agent_name}'")
|
|
580
|
+
|
|
581
|
+
# Configure backpressure if supported
|
|
582
|
+
if (hasattr(agent, 'backpressure_controller') and
|
|
583
|
+
self.reliability_config.backpressure_control and
|
|
584
|
+
agent.backpressure_controller is None):
|
|
585
|
+
from ..core.reliability import BackpressureController
|
|
586
|
+
agent.backpressure_controller = BackpressureController(
|
|
587
|
+
max_concurrent_tasks=10,
|
|
588
|
+
max_queue_size=100,
|
|
589
|
+
agent_id=getattr(agent, 'agent_id', agent_name)
|
|
590
|
+
)
|
|
591
|
+
logger.debug(f"Configured backpressure control for agent '{agent_name}'")
|
|
592
|
+
|
|
593
|
+
except Exception as e:
|
|
594
|
+
logger.warning(f"Failed to configure reliability for agent '{agent_name}': {e}")
|
|
595
|
+
|
|
596
|
+
    def _create_reliable_callback(self, connection: Connection):
        """Create a callback with reliability features and metadata propagation.

        Adds, on top of the traced path: message de-duplication by
        ``message_id``, ACK/NACK via the relay manager, optional
        backpressure slot acquisition, and error wrapping in WorkflowError
        (re-raised, unlike the fire-and-forget traced callback).

        NOTE(review): unlike _create_traced_callback, this path looks up
        ``self.agents`` only — a pool destination would raise KeyError.
        Confirm pools are not combined with reliability mode.
        """
        async def reliable_callback(data: Any, metadata: Optional[Dict[str, Any]] = None, message_id: Optional[str] = None):
            """Callback that processes relay data with reliability features and metadata."""
            start_time = time.time()

            # Deduplication check (only if message_id provided)
            if message_id:
                if message_id in self._processed_messages:
                    logger.debug(f"Skipping duplicate message {message_id}")
                    # Still ACK it since we processed it before (idempotency)
                    if self.reliability_config and self.reliability_config.acknowledgments:
                        await self.relay_manager.ack_message(message_id)
                    return

                # Mark as processing; removed again in the except block so
                # a failed message can be redelivered and retried.
                self._processed_messages.add(message_id)

            try:
                # NOTE(review): success=True is traced before delivery, so it
                # records receipt rather than processing success — confirm.
                await self._trace_workflow_communication(
                    from_agent=connection.from_agent,
                    to_agent=connection.to_agent,
                    channel=connection.channel,
                    data=data,
                    success=True,
                    message_id=message_id
                )

                # Build enriched context with metadata propagation
                enriched_context = {
                    'source_agent': connection.from_agent,
                    'channel': connection.channel,
                    'workflow': self.name,
                    'message_id': message_id,
                    'reliability_enabled': True
                }

                # Add all upstream metadata to context; upstream keys can
                # overwrite the defaults above.
                if metadata:
                    enriched_context.update(metadata)

                # Process the data with the destination agent
                dest_agent = self.agents[connection.to_agent]

                # Try new API first
                if hasattr(dest_agent, 'receive_message'):
                    # Handle backpressure if enabled
                    if (self.reliability_config and
                        self.reliability_config.backpressure_control and
                        hasattr(dest_agent, 'backpressure_controller') and
                        dest_agent.backpressure_controller):

                        # Check if agent can handle the task; a full queue
                        # raises so the message gets NACKed below.
                        if not await dest_agent.backpressure_controller.acquire_processing_slot():
                            raise BackpressureError(
                                f"Agent {connection.to_agent} queue is full",
                                agent_id=getattr(dest_agent, 'agent_id', connection.to_agent)
                            )

                        try:
                            await dest_agent.receive_message(
                                data=data,
                                source_agent=connection.from_agent,
                                channel=connection.channel,
                                workflow_name=self.name
                            )
                        finally:
                            # Always release the slot, even on failure.
                            dest_agent.backpressure_controller.release_processing_slot()
                    else:
                        await dest_agent.receive_message(
                            data=data,
                            source_agent=connection.from_agent,
                            channel=connection.channel,
                            workflow_name=self.name
                        )

                    # Acknowledge message if reliability is enabled and message_id provided
                    if message_id and self.reliability_config and self.reliability_config.acknowledgments:
                        await self.relay_manager.ack_message(message_id)

                # Fallback to old API for legacy agents
                elif hasattr(dest_agent, '_process'):
                    context = enriched_context

                    # Handle backpressure if enabled
                    if (self.reliability_config and
                        self.reliability_config.backpressure_control and
                        hasattr(dest_agent, 'backpressure_controller') and
                        dest_agent.backpressure_controller):

                        # Check if agent can handle the task
                        if not await dest_agent.backpressure_controller.acquire_processing_slot():
                            raise BackpressureError(
                                f"Agent {connection.to_agent} queue is full",
                                agent_id=getattr(dest_agent, 'agent_id', connection.to_agent)
                            )

                        try:
                            await dest_agent._process(connection.task, data, context)
                        finally:
                            dest_agent.backpressure_controller.release_processing_slot()
                    else:
                        await dest_agent._process(connection.task, data, context)

                    # Acknowledge message if reliability is enabled and message_id provided
                    if message_id and self.reliability_config and self.reliability_config.acknowledgments:
                        await self.relay_manager.ack_message(message_id)

                else:
                    logger.warning(f"Agent '{connection.to_agent}' has no receive_message() or _process() method")
                    # NACK the message since we can't process it
                    if message_id and self.reliability_config and self.reliability_config.acknowledgments:
                        await self.relay_manager.nack_message(message_id, "Agent has no receive_message() or _process() method")

            except Exception as e:
                # Remove from processed on error so it can be retried
                if message_id:
                    self._processed_messages.discard(message_id)

                # Trace communication error
                await self._trace_workflow_communication(
                    from_agent=connection.from_agent,
                    to_agent=connection.to_agent,
                    channel=connection.channel,
                    data=data,
                    success=False,
                    error_message=str(e),
                    message_id=message_id
                )

                # NACK the message on error
                if message_id and self.reliability_config and self.reliability_config.acknowledgments:
                    await self.relay_manager.nack_message(message_id, str(e))

                # Enhanced error propagation: attach timing and routing info
                # for the caller / error handlers.
                error_context = {
                    'connection': str(connection),
                    'processing_time': time.time() - start_time,
                    'message_id': message_id,
                    'workflow': self.name
                }

                workflow_error = WorkflowError(
                    f"Agent {connection.to_agent} failed processing {connection.task}: {str(e)}",
                    workflow_name=self.name,
                    context=error_context
                )

                logger.error(f"Error in reliable workflow communication {connection}: {str(e)}")

                raise workflow_error

        return reliable_callback
|
|
750
|
+
|
|
751
|
+
async def _cleanup_connections(self) -> None:
    """Tear down every relay subscription registered by this workflow."""
    for chan, handler in self._subscriptions:
        try:
            self.relay_manager.unsubscribe(chan, handler)
        except Exception as e:
            # Best-effort: a failed unsubscribe must not abort the rest of cleanup.
            logger.warning(f"Error cleaning up subscription for channel '{chan}': {str(e)}")

    self._subscriptions.clear()
    logger.debug("Cleaned up relay subscriptions")
|
|
761
|
+
|
|
762
|
+
async def _cleanup_dedup_cache(self) -> None:
    """Background task: periodically bound the message-dedup cache size."""
    while self.status == WorkflowStatus.RUNNING:
        try:
            # Wake up every five minutes to inspect the cache.
            await asyncio.sleep(300)

            if len(self._processed_messages) > self._dedup_max_size:
                # Crude but safe: drop the whole cache rather than track per-entry age.
                logger.warning(f"Dedup cache exceeded {self._dedup_max_size} entries, clearing")
                self._processed_messages.clear()
        except asyncio.CancelledError:
            break  # task cancelled during shutdown
        except Exception as e:
            logger.error(f"Error in dedup cleanup: {e}")
|
|
778
|
+
|
|
779
|
+
async def inject_data(self, agent_name: str, data: Any, task: str = "inject") -> None:
    """
    Push data directly into one agent so it can kick off workflow processing.

    Args:
        agent_name: Name of the agent to inject data into
        data: Data to inject
        task: Task to execute

    Raises:
        WorkflowError: If the workflow is not running, or the agent is unknown.
    """
    # Guard clauses: refuse injection into a stopped workflow or unknown agent.
    if self.status != WorkflowStatus.RUNNING:
        raise WorkflowError(f"Cannot inject data - workflow is not running (status: {self.status})")

    if agent_name not in self.agents:
        raise WorkflowError(f"Agent '{agent_name}' not found in workflow")

    target = self.agents[agent_name]

    try:
        # Record the injection in the trace stream before handing off.
        await self._trace_workflow_event("data_injected", {
            "workflow_name": self.name,
            "target_agent": agent_name,
            "task": task,
            "data_type": type(data).__name__
        })

        if hasattr(target, 'process'):
            await target.process(task, data, {'workflow': self.name, 'injection': True})
            logger.debug(f"Injected data into agent '{agent_name}' in workflow '{self.name}'")
        else:
            logger.warning(f"Agent '{agent_name}' has no process method")

    except Exception as e:
        # Injection failures are logged, not raised, so the workflow keeps running.
        logger.error(f"Error injecting data into agent '{agent_name}': {str(e)}")
|
|
815
|
+
|
|
816
|
+
# Tracing methods for workflow events
|
|
817
|
+
|
|
818
|
+
async def _trace_workflow_communication(
    self,
    from_agent: str,
    to_agent: str,
    channel: str,
    data: Any,
    success: bool,
    error_message: Optional[str] = None,
    message_id: Optional[str] = None
) -> None:
    """Record one agent-to-agent communication in the unified tracing system."""
    try:
        # Assemble the span input payload up front for readability.
        payload = {
            "from_agent": from_agent,
            "to_agent": to_agent,
            "channel": channel,
            "data_preview": str(data)[:200] if data else None
        }

        # Open a span describing the hop between the two agents.
        span_id = self.trace_manager.start_span(
            operation_name=f"workflow_communication",
            trace_type=TraceType.WORKFLOW_COMMUNICATION,
            input_data=payload,
            workflow_name=self.name,
            from_agent=from_agent,
            to_agent=to_agent,
            channel=channel,
            data_type=type(data).__name__,
            message_id=message_id,
            reliability_enabled=str(self._reliability_enabled)
        )

        # Close it immediately with the outcome of the communication.
        outcome = TraceStatus.SUCCESS if success else TraceStatus.ERROR
        self.trace_manager.end_span(
            span_id=span_id,
            status=outcome,
            output_data={"communication_processed": success},
            error_message=error_message
        )

    except Exception as e:
        # Tracing is observability only; never let it break the workflow.
        logger.warning(f"Failed to trace workflow communication: {e}")
|
|
859
|
+
|
|
860
|
+
async def _trace_workflow_event(self, event_type: str, event_data: Dict[str, Any]) -> None:
    """Record a general workflow lifecycle event via the unified tracing system."""
    try:
        meta = {
            "workflow_name": self.name,
            "event_type": event_type
        }
        span_id = self.trace_manager.start_span(
            operation_name=f"workflow_{event_type}",
            trace_type=TraceType.WORKFLOW_COMMUNICATION,
            input_data=event_data,
            metadata=meta
        )

        self.trace_manager.end_span(
            span_id=span_id,
            status=TraceStatus.SUCCESS,
            output_data={"event_recorded": True}
        )

    except Exception as e:
        # Swallow tracing failures; they must not affect workflow execution.
        logger.warning(f"Failed to trace workflow event: {e}")
|
|
881
|
+
|
|
882
|
+
# Simplified query methods using unified tracing
|
|
883
|
+
|
|
884
|
+
def get_recent_communication(self, limit: int = 20) -> List[Dict[str, Any]]:
    """
    Fetch recent workflow communication events from the unified tracing system.

    Args:
        limit: Maximum number of communication events to return

    Returns:
        List of recent workflow communication events; empty list on lookup failure.
    """
    try:
        recent = self.trace_manager.get_recent_operations(limit=limit)
    except Exception as e:
        logger.warning(f"Failed to get recent communication: {e}")
        return []
    return recent
|
|
900
|
+
|
|
901
|
+
def get_communication_log(self, count: int = 20) -> List[Dict[str, Any]]:
    """
    Compatibility alias for get_recent_communication.

    Args:
        count: Maximum number of communication events to return

    Returns:
        List of recent workflow communication events
    """
    events = self.get_recent_communication(limit=count)
    return events
|
|
912
|
+
|
|
913
|
+
def get_workflow_stats(self) -> Dict[str, Any]:
    """Return workflow-level metrics from the unified tracing system ({} on failure)."""
    try:
        return self.trace_manager.get_workflow_metrics(self.name)
    except Exception as e:
        logger.warning(f"Failed to get workflow stats: {e}")
        return {}
|
|
922
|
+
|
|
923
|
+
# Basic workflow information methods
|
|
924
|
+
|
|
925
|
+
def get_agent(self, name: str) -> Optional[Any]:
    """Look up a registered agent by name; None when no such agent exists."""
    try:
        return self.agents[name]
    except KeyError:
        return None
|
|
928
|
+
|
|
929
|
+
def list_agents(self) -> List[str]:
    """Return the names of every agent registered in this workflow."""
    return [agent_name for agent_name in self.agents]
|
|
932
|
+
|
|
933
|
+
def list_connections(self) -> List[str]:
    """Return every workflow connection rendered as its string form."""
    return list(map(str, self.connections))
|
|
936
|
+
|
|
937
|
+
def get_channel_data(self, channel: str, count: int = 1) -> List[Any]:
    """Fetch the latest data for a relay channel (delegates to the relay manager)."""
    latest = self.relay_manager.get_latest(channel, count)
    return latest
|
|
940
|
+
|
|
941
|
+
def get_stats(self) -> Dict[str, Any]:
    """Build a snapshot of workflow state, topology counts and reliability info."""
    uptime = None
    if self.started_at:
        # Use the stop time when finished, otherwise measure against "now".
        uptime = (self.stopped_at or time.time()) - self.started_at

    snapshot = {
        'name': self.name,
        'project_id': self.project_id,
        'status': self.status.value,
        'agent_count': len(self.agents),
        'connection_count': len(self.connections),
        'channel_count': len(self.channels),
        'created_at': self.created_at,
        'started_at': self.started_at,
        'stopped_at': self.stopped_at,
        'running_time': uptime,
        'error': self.error,
        'agents': list(self.agents.keys()),
        'channels': list(self.channels),
        'reliability_enabled': self._reliability_enabled
    }

    # Reliability details are only present when the feature is configured.
    if self._reliability_enabled and self.reliability_config:
        snapshot['reliability_config'] = {
            'acknowledgments': self.reliability_config.acknowledgments,
            'task_tracking': self.reliability_config.task_tracking,
            'backpressure_control': self.reliability_config.backpressure_control
        }

        # Surface relay-level reliability counters when the manager exposes them.
        relay_stats = self.relay_manager.get_stats()
        if 'pending_messages' in relay_stats:
            snapshot['pending_messages'] = relay_stats['pending_messages']
            snapshot['active_timeouts'] = relay_stats['active_timeouts']

    return snapshot
|
|
980
|
+
|
|
981
|
+
# Reliability management methods
|
|
982
|
+
|
|
983
|
+
def get_pending_messages(self) -> List[Dict[str, Any]]:
    """List messages still awaiting acknowledgment ([] when reliability is off)."""
    if self._reliability_enabled:
        return self.relay_manager.get_pending_messages()
    return []
|
|
988
|
+
|
|
989
|
+
async def get_agent_reliability_stats(self, agent_name: str) -> Dict[str, Any]:
    """Collect task and backpressure statistics for one agent ({} if unknown)."""
    if agent_name not in self.agents:
        return {}

    agent = self.agents[agent_name]
    stats: Dict[str, Any] = {'agent_name': agent_name, 'reliability_enabled': False}

    if hasattr(agent, 'enable_reliability') and agent.enable_reliability:
        stats['reliability_enabled'] = True

        # Task-tracking counters, grouped by task status.
        if hasattr(agent, 'get_agent_tasks'):
            try:
                tasks = await agent.get_agent_tasks()
                stats['total_tasks'] = len(tasks)
                stats['tasks_by_status'] = {}
                for entry in tasks:
                    key = entry.get('status', 'unknown')
                    stats['tasks_by_status'][key] = stats['tasks_by_status'].get(key, 0) + 1
            except Exception as e:
                logger.warning(f"Failed to get task stats for agent {agent_name}: {e}")

        # Backpressure counters, when the agent exposes them.
        if hasattr(agent, 'get_backpressure_stats'):
            try:
                stats['backpressure'] = agent.get_backpressure_stats()
            except Exception as e:
                logger.warning(f"Failed to get backpressure stats for agent {agent_name}: {e}")

    return stats
|
|
1021
|
+
|
|
1022
|
+
def is_reliability_enabled(self) -> bool:
    """Report whether reliability features are active for this workflow."""
    enabled = self._reliability_enabled
    return enabled
|
|
1025
|
+
|
|
1026
|
+
def get_reliability_config(self) -> Optional[ReliabilityConfig]:
    """Return the active reliability configuration, or None when not configured."""
    config = self.reliability_config
    return config
|
|
1029
|
+
|
|
1030
|
+
def get_token_usage(self) -> Optional[Dict[str, Any]]:
    """
    Aggregate token usage from all agents in the workflow.

    Agents report usage either through their own ``get_token_usage()`` or,
    failing that, through an attached ``llm_provider`` exposing the same method.

    Returns:
        Dictionary containing aggregated token usage across all agents, or
        None when no agent reported any usage. (Return annotation fixed:
        the previous ``Dict[str, Any]`` annotation did not admit the None
        return on the no-usage path.)
    """
    total_usage = {
        "total_tokens": 0,
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "llm_calls": 0,
        "models_used": [],
        "agents_with_usage": []
    }

    for agent_name, agent in self.agents.items():
        try:
            # Method 1: the agent aggregates its own usage.
            if hasattr(agent, 'get_token_usage'):
                usage = agent.get_token_usage()
                if usage and isinstance(usage, dict):
                    tokens = usage.get("total_tokens", 0)
                    if tokens > 0:
                        total_usage["total_tokens"] += tokens
                        total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
                        total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
                        total_usage["llm_calls"] += usage.get("total_calls", usage.get("llm_calls", 0))
                        total_usage["agents_with_usage"].append(agent_name)
                        logger.debug(f"Workflow: Agent '{agent_name}' used {tokens} tokens")

            # Method 2: fall back to the agent's LLM provider.
            elif hasattr(agent, 'llm_provider') and hasattr(agent.llm_provider, 'get_token_usage'):
                llm_usage = agent.llm_provider.get_token_usage()
                if llm_usage and isinstance(llm_usage, dict):
                    tokens = llm_usage.get("total_tokens", 0)
                    if tokens > 0:
                        total_usage["total_tokens"] += tokens
                        total_usage["prompt_tokens"] += llm_usage.get("prompt_tokens", 0)
                        total_usage["completion_tokens"] += llm_usage.get("completion_tokens", 0)
                        # NOTE(review): the provider path counts one call per agent,
                        # not per LLM invocation — confirm this approximation is intended.
                        total_usage["llm_calls"] += 1
                        total_usage["agents_with_usage"].append(agent_name)
                        logger.debug(f"Workflow: Agent '{agent_name}' (via llm_provider) used {tokens} tokens")

        except Exception as e:
            # Best-effort aggregation: one misbehaving agent must not hide the rest.
            logger.warning(f"Failed to get token usage from agent '{agent_name}': {e}")

    if total_usage["total_tokens"] > 0:
        logger.info(f"Workflow '{self.name}' total token usage: {total_usage['total_tokens']} tokens across {len(total_usage['agents_with_usage'])} agents")
        return total_usage
    else:
        logger.debug(f"Workflow '{self.name}' has no token usage")
        return None
|
|
1083
|
+
|
|
1084
|
+
def health_check(self) -> Dict[str, Any]:
    """
    Comprehensive workflow health check.

    Returns:
        Dictionary containing health status and any issues found
    """
    is_running = self.status == WorkflowStatus.RUNNING
    health = {
        'status': self.status.value,
        'healthy': is_running,
        'agents': {},
        'issues': []
    }

    # Per-agent checks; agents are duck-typed, so probe attributes safely.
    for agent_name, agent in self.agents.items():
        agent_health = {
            'name': agent_name,
            'has_process': hasattr(agent, 'process'),
            'running': is_running
        }

        # Agents may optionally expose richer health information.
        if hasattr(agent, 'get_health'):
            try:
                agent_health.update(agent.get_health())
            except Exception as e:
                agent_health['health_error'] = str(e)

        health['agents'][agent_name] = agent_health

        if not agent_health['has_process']:
            health['issues'].append(f"Agent '{agent_name}' has no process method")
            health['healthy'] = False

    # A runaway subscription count usually means leaked callbacks.
    subscription_count = len(self._subscriptions)
    health['subscription_count'] = subscription_count
    if subscription_count > 1000:
        health['issues'].append(f"High subscription count: {subscription_count}")
        health['healthy'] = False

    # Unacknowledged messages pile up when consumers stall.
    if self._reliability_enabled:
        pending = self.get_pending_messages()
        health['pending_message_count'] = len(pending)
        if len(pending) > 100:
            health['issues'].append(f"High pending message count: {len(pending)}")
            health['healthy'] = False

    return health
|
|
1135
|
+
|
|
1136
|
+
# Context manager support
|
|
1137
|
+
async def __aenter__(self) -> "Workflow":
    """Async context manager entry: start the workflow and return it for binding."""
    await self.start()
    return self
|
|
1141
|
+
|
|
1142
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
    """Async context manager exit: stop the workflow whether or not an exception occurred."""
    await self.stop()
|