puffinflow 2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. puffinflow/__init__.py +132 -0
  2. puffinflow/core/__init__.py +110 -0
  3. puffinflow/core/agent/__init__.py +320 -0
  4. puffinflow/core/agent/base.py +1635 -0
  5. puffinflow/core/agent/checkpoint.py +50 -0
  6. puffinflow/core/agent/context.py +521 -0
  7. puffinflow/core/agent/decorators/__init__.py +90 -0
  8. puffinflow/core/agent/decorators/builder.py +454 -0
  9. puffinflow/core/agent/decorators/flexible.py +714 -0
  10. puffinflow/core/agent/decorators/inspection.py +144 -0
  11. puffinflow/core/agent/dependencies.py +57 -0
  12. puffinflow/core/agent/scheduling/__init__.py +21 -0
  13. puffinflow/core/agent/scheduling/builder.py +160 -0
  14. puffinflow/core/agent/scheduling/exceptions.py +35 -0
  15. puffinflow/core/agent/scheduling/inputs.py +137 -0
  16. puffinflow/core/agent/scheduling/parser.py +209 -0
  17. puffinflow/core/agent/scheduling/scheduler.py +413 -0
  18. puffinflow/core/agent/state.py +141 -0
  19. puffinflow/core/config.py +62 -0
  20. puffinflow/core/coordination/__init__.py +137 -0
  21. puffinflow/core/coordination/agent_group.py +359 -0
  22. puffinflow/core/coordination/agent_pool.py +629 -0
  23. puffinflow/core/coordination/agent_team.py +577 -0
  24. puffinflow/core/coordination/coordinator.py +720 -0
  25. puffinflow/core/coordination/deadlock.py +1759 -0
  26. puffinflow/core/coordination/fluent_api.py +421 -0
  27. puffinflow/core/coordination/primitives.py +478 -0
  28. puffinflow/core/coordination/rate_limiter.py +520 -0
  29. puffinflow/core/observability/__init__.py +47 -0
  30. puffinflow/core/observability/agent.py +139 -0
  31. puffinflow/core/observability/alerting.py +73 -0
  32. puffinflow/core/observability/config.py +127 -0
  33. puffinflow/core/observability/context.py +88 -0
  34. puffinflow/core/observability/core.py +147 -0
  35. puffinflow/core/observability/decorators.py +105 -0
  36. puffinflow/core/observability/events.py +71 -0
  37. puffinflow/core/observability/interfaces.py +196 -0
  38. puffinflow/core/observability/metrics.py +137 -0
  39. puffinflow/core/observability/tracing.py +209 -0
  40. puffinflow/core/reliability/__init__.py +27 -0
  41. puffinflow/core/reliability/bulkhead.py +96 -0
  42. puffinflow/core/reliability/circuit_breaker.py +149 -0
  43. puffinflow/core/reliability/leak_detector.py +122 -0
  44. puffinflow/core/resources/__init__.py +77 -0
  45. puffinflow/core/resources/allocation.py +790 -0
  46. puffinflow/core/resources/pool.py +645 -0
  47. puffinflow/core/resources/quotas.py +567 -0
  48. puffinflow/core/resources/requirements.py +217 -0
  49. puffinflow/version.py +21 -0
  50. puffinflow-2.dev0.dist-info/METADATA +334 -0
  51. puffinflow-2.dev0.dist-info/RECORD +55 -0
  52. puffinflow-2.dev0.dist-info/WHEEL +5 -0
  53. puffinflow-2.dev0.dist-info/entry_points.txt +3 -0
  54. puffinflow-2.dev0.dist-info/licenses/LICENSE +21 -0
  55. puffinflow-2.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,141 @@
1
+ """State management types and enums."""
2
+
3
+ import asyncio
4
+ import random
5
+ import uuid
6
+ from dataclasses import dataclass, field
7
+ from enum import Enum, IntEnum
8
+ from typing import TYPE_CHECKING, Any, Optional, Union, runtime_checkable
9
+
10
+ from typing_extensions import Protocol
11
+
12
+ if TYPE_CHECKING:
13
+ from ..resources.requirements import ResourceRequirements
14
+ from .context import Context
15
+
16
+ try:
17
+ from ..resources.requirements import ResourceRequirements
18
+ except ImportError:
19
+ ResourceRequirements = None # type: ignore
20
+
21
+ # Type definitions
22
+ from typing import TYPE_CHECKING
23
+
24
+ if TYPE_CHECKING:
25
+ from .base import Agent
26
+
27
# What a state function may return: a single name, a list whose items are
# names or (Agent, name) pairs, or None.
StateResult = Optional[Union[str, list[Union[str, tuple["Agent", str]]]]]
28
+
29
+
30
class Priority(IntEnum):
    """Priority levels for state execution.

    Members are plain ints (IntEnum), so they compare numerically:
    LOW < NORMAL < HIGH < CRITICAL.
    """

    # Consecutive integers starting at 0, in ascending order of urgency.
    LOW, NORMAL, HIGH, CRITICAL = range(4)
37
+
38
+
39
class AgentStatus(str, Enum):
    """Lifecycle status of an agent run.

    Members are also ``str`` instances; ``str()`` and f-string formatting
    yield the bare value (e.g. ``"running"``) rather than ``AgentStatus.RUNNING``.
    """

    IDLE = "idle"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

    # Same technique the stdlib uses for StrEnum: render as the raw string value.
    __str__ = str.__str__
51
+
52
+
53
class StateStatus(str, Enum):
    """Execution status of a single state.

    Members are also ``str`` instances; ``str()`` and f-string formatting
    yield the bare value (e.g. ``"blocked"``).
    """

    PENDING = "pending"
    READY = "ready"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    BLOCKED = "blocked"
    TIMEOUT = "timeout"
    RETRYING = "retrying"

    # Same technique the stdlib uses for StrEnum: render as the raw string value.
    __str__ = str.__str__
68
+
69
+
70
@runtime_checkable
class StateFunction(Protocol):
    """Structural type for state callables.

    Any object exposing an async ``__call__(context)`` satisfies it; because
    the protocol is runtime-checkable, ``isinstance`` only verifies that
    ``__call__`` exists, not its signature.
    """

    async def __call__(self, context: "Context") -> StateResult:
        """Execute the state with ``context`` and return a ``StateResult``.

        NOTE(review): the result presumably names the state(s) to run next;
        confirm against the agent scheduler.
        """
76
+
77
+
78
@dataclass
class RetryPolicy:
    """Retry settings with capped exponential backoff and optional jitter.

    Attributes:
        max_retries: Number of retries allowed.
        initial_delay: Base delay in seconds for attempt 0.
        exponential_base: Growth factor applied per attempt.
        jitter: When True, scale each delay by a random factor in [0.5, 1.0).
        dead_letter_on_max_retries: Dead-letter flag for exhausted retries.
        dead_letter_on_timeout: Dead-letter flag for timeouts.
        max_delay: Upper bound in seconds on a single delay, applied before
            jitter. Defaults to 60 seconds, the previously hard-coded cap.
    """

    max_retries: int = 3
    initial_delay: float = 1.0
    exponential_base: float = 2.0
    jitter: bool = True
    # Dead letter handling
    dead_letter_on_max_retries: bool = True
    dead_letter_on_timeout: bool = True
    # Appended last so existing positional construction keeps working.
    max_delay: float = 60.0

    def compute_delay(self, attempt: int) -> float:
        """Return the backoff delay in seconds for ``attempt``.

        The delay is ``initial_delay * exponential_base ** attempt``, capped at
        ``max_delay`` and then optionally jittered.
        """
        try:
            delay = self.initial_delay * (self.exponential_base**attempt)
        except OverflowError:
            # float ** int raises instead of returning inf once the result
            # exceeds the float range; such a delay is past the cap anyway
            # (unless the base delay is zero or negative).
            delay = self.max_delay if self.initial_delay > 0 else 0.0
        delay = min(delay, self.max_delay)
        if self.jitter:
            delay *= 0.5 + random.random() * 0.5
        return delay

    async def wait(self, attempt: int) -> None:
        """Sleep for the backoff delay computed for ``attempt``."""
        await asyncio.sleep(self.compute_delay(attempt))
96
+
97
+
98
+ # Dead letter data structure
99
@dataclass
class DeadLetter:
    """Record describing a state execution that was given up on.

    Every field except ``timeout_occurred`` and ``context_snapshot`` is
    required; ``context_snapshot`` gets a fresh empty dict per instance.
    """

    state_name: str
    agent_name: str
    # Error description: message text and the exception type's name.
    error_message: str
    error_type: str
    attempts: int
    # NOTE(review): presumably an epoch timestamp in seconds; confirm with the producer.
    failed_at: float
    timeout_occurred: bool = False
    context_snapshot: dict[str, Any] = field(default_factory=dict)
109
+
110
+
111
@dataclass
class StateMetadata:
    """Bookkeeping attached to one state while it is executed.

    Only ``status`` is required. Mutable fields use per-instance factories,
    and ``state_id`` defaults to a fresh UUID4 string.
    """

    status: StateStatus
    attempts: int = 0
    max_retries: int = 3
    resources: Optional["ResourceRequirements"] = None
    dependencies: dict[str, Any] = field(default_factory=dict)
    satisfied_dependencies: set = field(default_factory=set)
    last_execution: Optional[float] = None
    last_success: Optional[float] = None
    state_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    retry_policy: Optional[RetryPolicy] = None
    priority: Priority = Priority.NORMAL
    coordination_primitives: list[Any] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Default ``resources`` to ``ResourceRequirements()`` when omitted.

        Nothing happens if resources were supplied or if the resources
        module failed to import (``ResourceRequirements`` is then ``None``).
        """
        if self.resources is not None or ResourceRequirements is None:
            return
        self.resources = ResourceRequirements()
132
+
133
+
134
@dataclass(order=True)
class PrioritizedState:
    """Queue entry pairing a state with its scheduling key.

    Ordering compares ``(priority, timestamp)`` only, ascending; the name and
    metadata are excluded from comparisons via ``compare=False``.
    NOTE(review): in a min-heap a lower ``priority`` value pops first, so
    confirm producers negate ``Priority`` if higher should win.
    """

    # Sort keys, compared in declaration order.
    priority: int
    timestamp: float
    # Payload, ignored by ==, <, etc.
    state_name: str = field(compare=False)
    metadata: StateMetadata = field(compare=False)
@@ -0,0 +1,62 @@
1
+ # core/config.py
2
from functools import lru_cache
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
7
+
8
+
9
class Settings(BaseSettings):
    """PuffinFlow configuration loaded by pydantic-settings.

    Each field is read from the environment variable named by its alias
    (matched case-insensitively) or from a ``.env`` file in the working
    directory; otherwise the declared default applies.
    """

    # Pydantic v2 replaces the deprecated class-based ``Config`` with
    # ``model_config``. BaseSettings defaults to extra="forbid", which makes
    # Settings() raise when a shared .env contains keys for other tools, so
    # unknown keys are ignored instead.
    model_config = SettingsConfigDict(
        env_file=".env",
        case_sensitive=False,
        extra="ignore",
    )

    app_name: str = "PuffinFlow"
    environment: str = Field(default="development", alias="ENVIRONMENT")
    debug: bool = Field(default=False, alias="DEBUG")

    # Resource limits
    max_cpu_units: float = Field(default=4.0, alias="MAX_CPU_UNITS")
    max_memory_mb: float = Field(default=4096.0, alias="MAX_MEMORY_MB")
    max_io_weight: float = Field(default=100.0, alias="MAX_IO_WEIGHT")
    max_network_weight: float = Field(default=100.0, alias="MAX_NETWORK_WEIGHT")
    max_gpu_units: float = Field(default=0.0, alias="MAX_GPU_UNITS")

    # Worker configuration
    worker_concurrency: int = Field(default=10, alias="WORKER_CONCURRENCY")
    worker_timeout: float = Field(default=300.0, alias="WORKER_TIMEOUT")

    # Observability
    enable_metrics: bool = Field(default=True, alias="ENABLE_METRICS")
    metrics_port: int = Field(default=9090, alias="METRICS_PORT")
    otlp_endpoint: Optional[str] = Field(default=None, alias="OTLP_ENDPOINT")

    # Core features that are implemented
    enable_scheduling: bool = Field(default=True, alias="ENABLE_SCHEDULING")

    # Storage and checkpointing
    storage_backend: str = Field(default="sqlite", alias="STORAGE_BACKEND")
    checkpoint_interval: int = Field(default=60, alias="CHECKPOINT_INTERVAL")
40
+
41
+
42
class Features:
    """Read-only feature switches backed by a ``Settings`` instance.

    The settings object is held by reference, so each property reflects its
    current value.
    """

    def __init__(self, settings: Settings):
        self._settings = settings

    @property
    def scheduling(self) -> bool:
        """Value of ``enable_scheduling`` on the wrapped settings."""
        cfg = self._settings
        return cfg.enable_scheduling

    @property
    def metrics(self) -> bool:
        """Value of ``enable_metrics`` on the wrapped settings."""
        cfg = self._settings
        return cfg.enable_metrics
53
+
54
+
55
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide ``Settings``, constructed on first call and cached."""
    settings = Settings()
    return settings
58
+
59
+
60
@lru_cache()
def get_features() -> Features:
    """Return the cached ``Features`` view over ``get_settings()``."""
    return Features(settings=get_settings())
@@ -0,0 +1,137 @@
1
+ """Enhanced coordination module for multi-agent workflows."""
2
+
3
+ from .agent_group import (
4
+ AgentGroup,
5
+ AgentOrchestrator,
6
+ ExecutionStrategy,
7
+ GroupResult,
8
+ OrchestrationExecution,
9
+ OrchestrationResult,
10
+ ParallelAgentGroup,
11
+ StageConfig,
12
+ )
13
+ from .agent_pool import (
14
+ AgentPool,
15
+ CompletedWork,
16
+ DynamicProcessingPool,
17
+ PoolContext,
18
+ ScalingPolicy,
19
+ WorkItem,
20
+ WorkProcessor,
21
+ WorkQueue,
22
+ )
23
+ from .agent_team import (
24
+ AgentTeam,
25
+ Event,
26
+ EventBus,
27
+ Message,
28
+ TeamResult,
29
+ create_team,
30
+ run_agents_parallel,
31
+ run_agents_sequential,
32
+ )
33
+ from .fluent_api import (
34
+ Agents,
35
+ ConditionalAgents,
36
+ FluentResult,
37
+ PipelineAgents,
38
+ collect_agent_outputs,
39
+ create_agent_team,
40
+ create_pipeline,
41
+ get_best_agent,
42
+ run_parallel_agents,
43
+ run_sequential_agents,
44
+ )
45
+
46
+ # Import existing coordination components
47
+ try:
48
+ from .coordinator import AgentCoordinator, CoordinationConfig, enhance_agent
49
+ from .deadlock import DeadlockDetector, DeadlockResolutionStrategy
50
+ from .primitives import (
51
+ Barrier,
52
+ CoordinationPrimitive,
53
+ Lease,
54
+ Lock,
55
+ Mutex,
56
+ PrimitiveType,
57
+ Quota,
58
+ Semaphore,
59
+ create_primitive,
60
+ )
61
+ from .rate_limiter import (
62
+ AdaptiveRateLimiter,
63
+ CompositeRateLimiter,
64
+ FixedWindow,
65
+ LeakyBucket,
66
+ RateLimiter,
67
+ RateLimitStrategy,
68
+ SlidingWindow,
69
+ TokenBucket,
70
+ )
71
+ except ImportError:
72
+ # Some coordination components may not be available
73
+ pass
74
+
75
__all__ = [
    "AdaptiveRateLimiter",
    # Existing coordination (if available)
    "AgentCoordinator",
    # Group coordination
    "AgentGroup",
    "AgentOrchestrator",
    # Agent pools
    "AgentPool",
    # Team coordination
    "AgentTeam",
    # Fluent APIs
    "Agents",
    "Barrier",
    "CompletedWork",
    "CompositeRateLimiter",
    "ConditionalAgents",
    "CoordinationConfig",
    "CoordinationPrimitive",
    "DeadlockDetector",
    "DeadlockResolutionStrategy",
    "DynamicProcessingPool",
    "Event",
    "EventBus",
    "ExecutionStrategy",
    "FixedWindow",
    "FluentResult",
    "GroupResult",
    "LeakyBucket",
    "Lease",
    "Lock",
    "Message",
    "Mutex",
    "OrchestrationExecution",
    "OrchestrationResult",
    "ParallelAgentGroup",
    "PipelineAgents",
    "PoolContext",
    "PrimitiveType",
    "Quota",
    "RateLimitStrategy",
    "RateLimiter",
    "ScalingPolicy",
    "Semaphore",
    "SlidingWindow",
    "StageConfig",
    "TeamResult",
    "TokenBucket",
    "WorkItem",
    "WorkProcessor",
    "WorkQueue",
    "collect_agent_outputs",
    "create_agent_team",
    "create_pipeline",
    "create_primitive",
    "create_team",
    "enhance_agent",
    "get_best_agent",
    "run_agents_parallel",
    "run_agents_sequential",
    "run_parallel_agents",
    "run_sequential_agents",
]

# The coordinator/deadlock/primitives/rate_limiter imports are optional
# (their ImportError is swallowed), so keep only names that were actually
# bound. Otherwise ``from puffinflow.core.coordination import *`` raises
# AttributeError whenever one of those modules is unavailable.
__all__ = [_name for _name in __all__ if _name in globals()]
@@ -0,0 +1,359 @@
1
+ """Agent groups for advanced coordination patterns."""
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Any, Callable, Optional
6
+
7
+ from ..agent.base import Agent, AgentResult
8
+ from .agent_team import AgentTeam, TeamResult
9
+
10
+ if TYPE_CHECKING:
11
+ import asyncio
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class ExecutionStrategy:
    """String constants naming how a stage's agents are run.

    These are plain ``str`` class attributes (not an Enum), so stage configs
    can compare against or store them directly.
    """

    (PARALLEL, SEQUENTIAL, PIPELINE, FAN_OUT, FAN_IN, CONDITIONAL) = (
        "parallel",
        "sequential",
        "pipeline",
        "fan_out",
        "fan_in",
        "conditional",
    )
25
+
26
+
27
@dataclass
class StageConfig:
    """Declarative description of one orchestration stage.

    ``agents`` holds agent *names*; ``depends_on`` holds names of stages that
    must finish first and is normalised from ``None`` to an empty list.
    """

    name: str
    agents: list[str]
    strategy: str = ExecutionStrategy.PARALLEL
    depends_on: Optional[list[str]] = None
    # Zero-argument predicate; the orchestrator skips the stage when it is falsy.
    condition: Optional[Callable] = None
    timeout: Optional[float] = None

    def __post_init__(self) -> None:
        self.depends_on = [] if self.depends_on is None else self.depends_on
41
+
42
+
43
class AgentGroup:
    """Simple agent group for basic parallel execution.

    Agents are keyed by ``agent.name``; if two share a name the later one
    wins. Execution is delegated to a temporary ``AgentTeam``.
    """

    def __init__(self, agents: list[Agent]):
        self.agents: dict[str, Agent] = {}
        for member in agents:
            self.agents[member.name] = member
        self._results: dict[str, AgentResult] = {}

    async def run_parallel(self, timeout: Optional[float] = None) -> "GroupResult":
        """Run every agent concurrently via an AgentTeam and wrap the result."""
        team = AgentTeam("group_execution")
        members = list(self.agents.values())
        team.add_agents(members)

        team_result = await team.run_parallel(timeout)
        return GroupResult(team_result)

    async def collect_all(self) -> "GroupResult":
        """Alias for ``run_parallel()`` with no timeout."""
        return await self.run_parallel()
61
+
62
+
63
class GroupResult:
    """Wrapper around a ``TeamResult`` adding convenience accessors.

    Attributes not defined on this class are looked up on the wrapped
    team result.
    """

    def __init__(self, team_result: "TeamResult"):
        self._team_result = team_result

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attributes to the wrapped team result.

        Raises:
            AttributeError: if the wrapped result lacks ``name``, or if this
                instance has no wrapped result yet.
        """
        # __getattr__ also fires for ``_team_result`` itself when it is not
        # set yet (e.g. copy/pickle create the instance via __new__ and probe
        # for __setstate__); forwarding then would recurse until RecursionError.
        if name == "_team_result":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._team_result, name)

    @property
    def agents(self) -> "list[AgentResult]":
        """Per-agent results, in the wrapped result's ``agent_results`` order."""
        return list(self._team_result.agent_results.values())
77
+
78
+
79
class ParallelAgentGroup(AgentGroup):
    """Named, fluently-built variant of ``AgentGroup``."""

    def __init__(self, name: str):
        # Initialise AgentGroup state so inherited run_parallel()/collect_all()
        # work; previously ``self.agents`` was never set and they raised
        # AttributeError.
        super().__init__([])
        self.name = name
        self._agents: list[Agent] = []

    def add_agent(self, agent: Agent) -> "ParallelAgentGroup":
        """Add ``agent`` to the group and return ``self`` for chaining."""
        self._agents.append(agent)
        # Keep the inherited name-keyed mapping in sync (later names win, as
        # in AgentGroup.__init__).
        self.agents[agent.name] = agent
        return self

    async def run_and_collect(self, timeout: Optional[float] = None) -> GroupResult:
        """Run all added agents in parallel and return their combined result."""
        return await AgentGroup(self._agents).run_parallel(timeout)
94
+
95
+
96
class AgentOrchestrator:
    """Builds a multi-stage workflow over a set of named agents.

    Agents are registered by ``agent.name``; stages reference them by name
    and may depend on other stages. Running is delegated to
    ``OrchestrationExecution``.
    """

    def __init__(self, name: str):
        self.name = name
        self._agents: dict[str, Agent] = {}
        self._stages: list[StageConfig] = []
        self._global_variables: dict[str, Any] = {}
        self._stage_results: dict[str, TeamResult] = {}
        self._execution_context: dict[str, Any] = {}

    def add_agent(self, agent: Agent) -> "AgentOrchestrator":
        """Register ``agent`` (replacing any agent with the same name)."""
        self._agents[agent.name] = agent
        return self

    def add_agents(self, agents: list[Agent]) -> "AgentOrchestrator":
        """Register each agent in ``agents``."""
        for agent in agents:
            self.add_agent(agent)
        return self

    def add_stage(
        self,
        name: str,
        agents: list[Agent],
        strategy: str = ExecutionStrategy.PARALLEL,
        depends_on: Optional[list[str]] = None,
        condition: Optional[Callable] = None,
        timeout: Optional[float] = None,
    ) -> "AgentOrchestrator":
        """Append an execution stage running ``agents``.

        Agents whose name is not yet registered are added; for names already
        registered, the existing agent is used.

        Raises:
            ValueError: if a stage called ``name`` already exists. Stage
                lookup is by name, so a duplicate would never be executed.
        """
        if any(stage.name == name for stage in self._stages):
            raise ValueError(f"Stage {name} already exists")

        # Add agents if not already added
        for agent in agents:
            if agent.name not in self._agents:
                self.add_agent(agent)

        stage = StageConfig(
            name=name,
            agents=[agent.name for agent in agents],
            strategy=strategy,
            depends_on=depends_on or [],
            condition=condition,
            timeout=timeout,
        )

        self._stages.append(stage)
        return self

    def set_global_variable(self, key: str, value: Any) -> None:
        """Store a global variable and push it to agents registered so far.

        Agents added later do not receive it here. Every stage's AgentTeam is,
        however, given all globals when the stage executes.
        """
        self._global_variables[key] = value
        for agent in self._agents.values():
            agent.set_shared_variable(key, value)

    def get_global_variable(self, key: str, default: Any = None) -> Any:
        """Return the global variable ``key``, or ``default`` if unset."""
        return self._global_variables.get(key, default)

    def run_with_monitoring(self) -> "OrchestrationExecution":
        """Return an async context manager for step-by-step execution."""
        return OrchestrationExecution(self)

    async def run(self) -> "OrchestrationResult":
        """Run every stage (honouring dependencies) and return the result."""
        async with self.run_with_monitoring() as execution:
            result = await execution.wait_for_completion()
        return result
164
+
165
+
166
class OrchestrationExecution:
    """Async context manager that runs an orchestrator's stages on demand.

    Stages execute lazily. ``wait_for_stage`` first runs the stage's
    dependencies and then the stage itself, and caches the result. Use it
    as ``async with orchestrator.run_with_monitoring() as execution: ...``.
    """

    def __init__(self, orchestrator: AgentOrchestrator):
        self.orchestrator = orchestrator
        # Wall-clock timestamps from time.time().
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None
        self._stage_results: dict[str, TeamResult] = {}
        self._completed_stages: set[str] = set()
        # __aexit__ cancels any unfinished task stored here.
        self._running_stages: dict[str, asyncio.Task] = {}
        # stage name -> (start, end) timestamps from time.time().
        self._stage_timings: dict[str, tuple] = {}

    async def __aenter__(self) -> "OrchestrationExecution":
        """Record the start time and return ``self``."""
        import time

        self.start_time = time.time()
        self.end_time = None
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Record the end time and cancel unfinished tasks in ``_running_stages``."""
        import time

        self.end_time = time.time()

        # Cancel any running stages
        for task in self._running_stages.values():
            if not task.done():
                task.cancel()

    def _find_stage(self, stage_name: str) -> Optional[StageConfig]:
        """Return the first stage config named ``stage_name``, or None."""
        for stage in self.orchestrator._stages:
            if stage.name == stage_name:
                return stage
        return None

    async def wait_for_stage(self, stage_name: str) -> TeamResult:
        """Run ``stage_name`` (after its dependencies) unless already done.

        A stage whose ``condition`` returns falsy is recorded as a TeamResult
        with status ``"skipped"``.

        Raises:
            ValueError: if the stage (or a dependency) is unknown, has no
                registered agents, or the dependencies form a cycle.
        """
        return await self._wait_for_stage(stage_name, ())

    async def _wait_for_stage(
        self, stage_name: str, chain: tuple[str, ...]
    ) -> TeamResult:
        """Resolve ``stage_name``; ``chain`` holds the stages currently being resolved."""
        if stage_name in self._completed_stages:
            return self._stage_results[stage_name]

        if stage_name in chain:
            # Without this check a dependency cycle recursed until RecursionError.
            cycle = " -> ".join((*chain, stage_name))
            raise ValueError(f"Circular stage dependency: {cycle}")

        stage_config = self._find_stage(stage_name)
        if stage_config is None:
            raise ValueError(f"Stage {stage_name} not found")

        # Check dependencies
        chain = (*chain, stage_name)
        for dep in stage_config.depends_on or []:
            if dep not in self._completed_stages:
                await self._wait_for_stage(dep, chain)

        # Check condition
        if stage_config.condition and not stage_config.condition():
            # Create empty result for skipped stage
            result = TeamResult(
                team_name=f"{self.orchestrator.name}_{stage_name}", status="skipped"
            )
            self._stage_results[stage_name] = result
            self._completed_stages.add(stage_name)
            return result

        # Execute stage
        return await self._execute_stage(stage_config)

    async def _execute_stage(self, stage_config: StageConfig) -> TeamResult:
        """Run one stage's agents as an AgentTeam; record its result and timing."""
        import time

        stage_start = time.time()

        # Agent names that are not registered on the orchestrator are skipped.
        stage_agents = [
            self.orchestrator._agents[name]
            for name in stage_config.agents
            if name in self.orchestrator._agents
        ]

        if not stage_agents:
            raise ValueError(f"No valid agents found for stage {stage_config.name}")

        team = AgentTeam(f"{self.orchestrator.name}_{stage_config.name}")
        team.add_agents(stage_agents)

        for key, value in self.orchestrator._global_variables.items():
            team.set_global_variable(key, value)

        if stage_config.strategy == ExecutionStrategy.SEQUENTIAL:
            # NOTE(review): stage_config.timeout is not applied to sequential
            # runs; confirm whether AgentTeam.run_sequential accepts one.
            result = await team.run_sequential()
        else:
            if stage_config.strategy != ExecutionStrategy.PARALLEL:
                # Only PARALLEL and SEQUENTIAL have dedicated handling; make
                # the fallback visible instead of silent.
                logger.warning(
                    "Stage %r uses strategy %r, which has no dedicated "
                    "implementation; running its agents in parallel",
                    stage_config.name,
                    stage_config.strategy,
                )
            result = await team.run_parallel(stage_config.timeout)

        stage_end = time.time()

        self._stage_results[stage_config.name] = result
        self._completed_stages.add(stage_config.name)
        self._stage_timings[stage_config.name] = (stage_start, stage_end)

        return result

    def set_stage_input(self, stage_name: str, key: str, value: Any) -> None:
        """Set variable ``key`` on every registered agent of ``stage_name``.

        Unknown stages and unregistered agent names are ignored.
        """
        stage_config = self._find_stage(stage_name)
        if stage_config is None:
            return
        for agent_name in stage_config.agents:
            agent = self.orchestrator._agents.get(agent_name)
            if agent:
                agent.set_variable(key, value)

    def is_stage_complete(self, stage_name: str) -> bool:
        """Return True if the stage has finished (or was skipped)."""
        return stage_name in self._completed_stages

    def get_stage_results(self, stage_name: str) -> Optional[TeamResult]:
        """Return the stage's result, or None if it has not completed."""
        return self._stage_results.get(stage_name)

    async def get_final_results(self) -> "OrchestrationResult":
        """Run all remaining stages and return the aggregated result."""
        import time

        for stage in self.orchestrator._stages:
            if stage.name not in self._completed_stages:
                await self.wait_for_stage(stage.name)

        # __aexit__ only runs after this result is built (see
        # AgentOrchestrator.run), so without this end_time is still None and
        # total_duration would always be reported as None.
        if self.end_time is None:
            self.end_time = time.time()

        # Copy so later activity on this execution cannot mutate the result.
        return OrchestrationResult(
            orchestrator_name=self.orchestrator.name,
            stage_results=dict(self._stage_results),
            stage_timings=dict(self._stage_timings),
            total_duration=self.total_duration,
        )

    async def wait_for_completion(self) -> "OrchestrationResult":
        """Alias for ``get_final_results``."""
        return await self.get_final_results()

    @property
    def total_duration(self) -> Optional[float]:
        """Seconds between start and end, or None if either is unset."""
        if self.start_time is not None and self.end_time is not None:
            return self.end_time - self.start_time
        return None

    def get_stage_timings(self) -> dict[str, float]:
        """Return each executed stage's duration in seconds."""
        return {
            stage: end - start for stage, (start, end) in self._stage_timings.items()
        }
324
+
325
+
326
@dataclass
class OrchestrationResult:
    """Outcome of an orchestration run.

    Attributes:
        orchestrator_name: Name of the orchestrator that produced it.
        stage_results: Stage name -> TeamResult.
        stage_timings: Stage name -> (start, end) timestamps.
        total_duration: Seconds for the whole run, or None if unknown.
    """

    orchestrator_name: str
    stage_results: dict[str, TeamResult]
    stage_timings: dict[str, tuple]
    total_duration: Optional[float]

    def get_stage_result(self, stage_name: str) -> Optional[TeamResult]:
        """Return the named stage's result, or None if absent."""
        return self.stage_results.get(stage_name, None)

    def get_final_result(self) -> dict[str, Any]:
        """Flatten every stage into one summary dict.

        Per-agent outputs and variables are keyed ``"<stage>_<agent>"``.
        NOTE(review): such keys can collide when names contain underscores.
        """
        outputs: dict[str, Any] = {}
        variables: dict[str, Any] = {}

        for stage_name, stage_result in self.stage_results.items():
            for agent_name, agent_result in stage_result.agent_results.items():
                key = f"{stage_name}_{agent_name}"
                outputs[key] = agent_result.outputs
                variables[key] = agent_result.variables

        durations = {
            name: finish - begin
            for name, (begin, finish) in self.stage_timings.items()
        }

        return {
            "outputs": outputs,
            "variables": variables,
            "stage_count": len(self.stage_results),
            "total_duration": self.total_duration,
            "stage_durations": durations,
        }