puffinflow 2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- puffinflow/__init__.py +132 -0
- puffinflow/core/__init__.py +110 -0
- puffinflow/core/agent/__init__.py +320 -0
- puffinflow/core/agent/base.py +1635 -0
- puffinflow/core/agent/checkpoint.py +50 -0
- puffinflow/core/agent/context.py +521 -0
- puffinflow/core/agent/decorators/__init__.py +90 -0
- puffinflow/core/agent/decorators/builder.py +454 -0
- puffinflow/core/agent/decorators/flexible.py +714 -0
- puffinflow/core/agent/decorators/inspection.py +144 -0
- puffinflow/core/agent/dependencies.py +57 -0
- puffinflow/core/agent/scheduling/__init__.py +21 -0
- puffinflow/core/agent/scheduling/builder.py +160 -0
- puffinflow/core/agent/scheduling/exceptions.py +35 -0
- puffinflow/core/agent/scheduling/inputs.py +137 -0
- puffinflow/core/agent/scheduling/parser.py +209 -0
- puffinflow/core/agent/scheduling/scheduler.py +413 -0
- puffinflow/core/agent/state.py +141 -0
- puffinflow/core/config.py +62 -0
- puffinflow/core/coordination/__init__.py +137 -0
- puffinflow/core/coordination/agent_group.py +359 -0
- puffinflow/core/coordination/agent_pool.py +629 -0
- puffinflow/core/coordination/agent_team.py +577 -0
- puffinflow/core/coordination/coordinator.py +720 -0
- puffinflow/core/coordination/deadlock.py +1759 -0
- puffinflow/core/coordination/fluent_api.py +421 -0
- puffinflow/core/coordination/primitives.py +478 -0
- puffinflow/core/coordination/rate_limiter.py +520 -0
- puffinflow/core/observability/__init__.py +47 -0
- puffinflow/core/observability/agent.py +139 -0
- puffinflow/core/observability/alerting.py +73 -0
- puffinflow/core/observability/config.py +127 -0
- puffinflow/core/observability/context.py +88 -0
- puffinflow/core/observability/core.py +147 -0
- puffinflow/core/observability/decorators.py +105 -0
- puffinflow/core/observability/events.py +71 -0
- puffinflow/core/observability/interfaces.py +196 -0
- puffinflow/core/observability/metrics.py +137 -0
- puffinflow/core/observability/tracing.py +209 -0
- puffinflow/core/reliability/__init__.py +27 -0
- puffinflow/core/reliability/bulkhead.py +96 -0
- puffinflow/core/reliability/circuit_breaker.py +149 -0
- puffinflow/core/reliability/leak_detector.py +122 -0
- puffinflow/core/resources/__init__.py +77 -0
- puffinflow/core/resources/allocation.py +790 -0
- puffinflow/core/resources/pool.py +645 -0
- puffinflow/core/resources/quotas.py +567 -0
- puffinflow/core/resources/requirements.py +217 -0
- puffinflow/version.py +21 -0
- puffinflow-2.dev0.dist-info/METADATA +334 -0
- puffinflow-2.dev0.dist-info/RECORD +55 -0
- puffinflow-2.dev0.dist-info/WHEEL +5 -0
- puffinflow-2.dev0.dist-info/entry_points.txt +3 -0
- puffinflow-2.dev0.dist-info/licenses/LICENSE +21 -0
- puffinflow-2.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""State management types and enums."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import random
|
|
5
|
+
import uuid
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from enum import Enum, IntEnum
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Optional, Union, runtime_checkable
|
|
9
|
+
|
|
10
|
+
from typing_extensions import Protocol
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from ..resources.requirements import ResourceRequirements
|
|
14
|
+
from .context import Context
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
from ..resources.requirements import ResourceRequirements
|
|
18
|
+
except ImportError:
|
|
19
|
+
ResourceRequirements = None # type: ignore
|
|
20
|
+
|
|
21
|
+
# Type definitions
|
|
22
|
+
from typing import TYPE_CHECKING
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from .base import Agent
|
|
26
|
+
|
|
27
|
+
# Return type of a state function (see StateFunction.__call__): a single
# string, a list whose items are strings or (Agent, str) pairs, or None.
# NOTE(review): presumably these name the state(s) to run next, optionally on
# another agent — confirm against the scheduler in base.py.
StateResult = Union[str, list[Union[str, tuple["Agent", str]]], None]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Priority(IntEnum):
    """Integer priority of a state; larger values mean a higher priority.

    Being an ``IntEnum``, members compare and sort like plain ints.
    """

    LOW = 0  # lowest
    NORMAL = 1  # default used by StateMetadata
    HIGH = 2
    CRITICAL = 3  # highest
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class AgentStatus(str, Enum):
    """Lifecycle status of an agent run.

    Members are ``str`` instances, so they compare equal to their raw string
    values, and ``str()`` yields the bare value (e.g. ``"running"``).
    """

    IDLE = "idle"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

    def __str__(self) -> str:
        """Return the raw status value rather than ``AgentStatus.NAME``."""
        raw: str = self.value
        return raw
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class StateStatus(str, Enum):
    """Execution status of a single state.

    Members are ``str`` instances; ``str()`` yields the bare value.
    """

    PENDING = "pending"
    READY = "ready"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    BLOCKED = "blocked"
    TIMEOUT = "timeout"
    RETRYING = "retrying"

    def __str__(self) -> str:
        """Return the raw status value rather than ``StateStatus.NAME``."""
        raw: str = self.value
        return raw
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@runtime_checkable
class StateFunction(Protocol):
    """Structural type for a state callable.

    Anything providing an async ``__call__(context)`` matches. Note that a
    ``runtime_checkable`` ``isinstance`` check only verifies that a
    ``__call__`` attribute exists; it does not check the signature or that
    the call is a coroutine.
    """

    async def __call__(self, context: "Context") -> StateResult:
        """Run the state against ``context`` and return its result."""
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@dataclass
class RetryPolicy:
    """Exponential-backoff retry policy.

    The delay before retry ``attempt`` is
    ``initial_delay * exponential_base ** attempt``, capped at ``max_delay``
    seconds. When ``jitter`` is enabled, the capped delay is additionally
    scaled by a random factor in ``[0.5, 1.0)``.
    """

    max_retries: int = 3
    initial_delay: float = 1.0
    exponential_base: float = 2.0
    jitter: bool = True
    # Dead letter handling
    dead_letter_on_max_retries: bool = True
    dead_letter_on_timeout: bool = True
    # Upper bound (seconds) on a single backoff delay. Declared last so that
    # positional construction of the existing fields keeps working.
    max_delay: float = 60.0

    def _capped_delay(self, attempt: int) -> float:
        """Return the un-jittered delay for ``attempt``, capped at ``max_delay``."""
        try:
            raw = self.initial_delay * (self.exponential_base**attempt)
        except OverflowError:
            # base**attempt no longer fits in a float, so the uncapped delay
            # is unbounded: the cap applies (or 0 for a non-positive start).
            return self.max_delay if self.initial_delay > 0 else 0.0
        return min(raw, self.max_delay)

    async def wait(self, attempt: int) -> None:
        """Sleep for the backoff delay associated with retry ``attempt``."""
        delay = self._capped_delay(attempt)
        if self.jitter:
            delay *= 0.5 + random.random() * 0.5
        await asyncio.sleep(delay)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# Dead letter data structure
@dataclass
class DeadLetter:
    """Record describing a state execution that was given up on."""

    state_name: str
    agent_name: str
    # Failure details.
    error_message: str
    error_type: str
    attempts: int
    failed_at: float
    timeout_occurred: bool = False
    # Each instance gets its own fresh dict.
    context_snapshot: dict[str, Any] = field(default_factory=dict)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@dataclass
class StateMetadata:
    """Per-state execution bookkeeping (status, retries, resources, ordering)."""

    status: StateStatus
    attempts: int = 0
    max_retries: int = 3
    resources: Optional["ResourceRequirements"] = None
    dependencies: dict[str, Any] = field(default_factory=dict)
    satisfied_dependencies: set = field(default_factory=set)
    # Timestamps (presumably epoch seconds — confirm with callers).
    last_execution: Optional[float] = None
    last_success: Optional[float] = None
    state_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    retry_policy: Optional[RetryPolicy] = None
    priority: Priority = Priority.NORMAL
    coordination_primitives: list[Any] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Give the state default ResourceRequirements when none were passed."""
        if self.resources is not None:
            return
        if ResourceRequirements is None:
            # The resources module failed to import; leave resources unset.
            return
        self.resources = ResourceRequirements()
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@dataclass(order=True)
class PrioritizedState:
    """Queue entry pairing a state with its scheduling priority.

    Instances order by ``(priority, timestamp)`` only; ``state_name`` and
    ``metadata`` are excluded from comparison.
    NOTE(review): in a min-heap such as ``heapq``, smaller ``priority`` values
    pop first — confirm callers negate Priority values if higher should win.
    """

    # Compared fields, in this order.
    priority: int
    timestamp: float
    # Payload fields, ignored by ordering and equality.
    state_name: str = field(compare=False)
    metadata: StateMetadata = field(compare=False)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# core/config.py
|
|
2
|
+
from functools import lru_cache
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import Field
|
|
6
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Settings(BaseSettings):
    """PuffinFlow runtime configuration.

    Each aliased field is read from the environment variable named by its
    ``alias`` (matched case-insensitively) or from a ``.env`` file, falling
    back to the declared default.
    """

    # pydantic v2 / pydantic-settings configuration. This replaces the
    # class-based ``Config``, which pydantic v2 deprecates.
    model_config = SettingsConfigDict(env_file=".env", case_sensitive=False)

    app_name: str = "PuffinFlow"
    environment: str = Field(default="development", alias="ENVIRONMENT")
    debug: bool = Field(default=False, alias="DEBUG")

    # Resource limits
    max_cpu_units: float = Field(default=4.0, alias="MAX_CPU_UNITS")
    max_memory_mb: float = Field(default=4096.0, alias="MAX_MEMORY_MB")
    max_io_weight: float = Field(default=100.0, alias="MAX_IO_WEIGHT")
    max_network_weight: float = Field(default=100.0, alias="MAX_NETWORK_WEIGHT")
    max_gpu_units: float = Field(default=0.0, alias="MAX_GPU_UNITS")

    # Worker configuration
    worker_concurrency: int = Field(default=10, alias="WORKER_CONCURRENCY")
    worker_timeout: float = Field(default=300.0, alias="WORKER_TIMEOUT")

    # Observability
    enable_metrics: bool = Field(default=True, alias="ENABLE_METRICS")
    metrics_port: int = Field(default=9090, alias="METRICS_PORT")
    otlp_endpoint: Optional[str] = Field(default=None, alias="OTLP_ENDPOINT")

    # Core features that are implemented
    enable_scheduling: bool = Field(default=True, alias="ENABLE_SCHEDULING")

    # Storage and checkpointing
    storage_backend: str = Field(default="sqlite", alias="STORAGE_BACKEND")
    checkpoint_interval: int = Field(default=60, alias="CHECKPOINT_INTERVAL")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Features:
    """Read-only feature-flag view over a Settings object."""

    def __init__(self, settings: Settings):
        self._settings = settings

    @property
    def scheduling(self) -> bool:
        """Whether scheduling is enabled (``settings.enable_scheduling``)."""
        cfg = self._settings
        return cfg.enable_scheduling

    @property
    def metrics(self) -> bool:
        """Whether metrics are enabled (``settings.enable_metrics``)."""
        cfg = self._settings
        return cfg.enable_metrics
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@lru_cache
def get_settings() -> Settings:
    """Build Settings on the first call; later calls return the cached object."""
    settings = Settings()
    return settings
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@lru_cache
def get_features() -> Features:
    """Return a cached Features view built from ``get_settings()``."""
    current = get_settings()
    return Features(current)
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""Enhanced coordination module for multi-agent workflows."""
|
|
2
|
+
|
|
3
|
+
from .agent_group import (
|
|
4
|
+
AgentGroup,
|
|
5
|
+
AgentOrchestrator,
|
|
6
|
+
ExecutionStrategy,
|
|
7
|
+
GroupResult,
|
|
8
|
+
OrchestrationExecution,
|
|
9
|
+
OrchestrationResult,
|
|
10
|
+
ParallelAgentGroup,
|
|
11
|
+
StageConfig,
|
|
12
|
+
)
|
|
13
|
+
from .agent_pool import (
|
|
14
|
+
AgentPool,
|
|
15
|
+
CompletedWork,
|
|
16
|
+
DynamicProcessingPool,
|
|
17
|
+
PoolContext,
|
|
18
|
+
ScalingPolicy,
|
|
19
|
+
WorkItem,
|
|
20
|
+
WorkProcessor,
|
|
21
|
+
WorkQueue,
|
|
22
|
+
)
|
|
23
|
+
from .agent_team import (
|
|
24
|
+
AgentTeam,
|
|
25
|
+
Event,
|
|
26
|
+
EventBus,
|
|
27
|
+
Message,
|
|
28
|
+
TeamResult,
|
|
29
|
+
create_team,
|
|
30
|
+
run_agents_parallel,
|
|
31
|
+
run_agents_sequential,
|
|
32
|
+
)
|
|
33
|
+
from .fluent_api import (
|
|
34
|
+
Agents,
|
|
35
|
+
ConditionalAgents,
|
|
36
|
+
FluentResult,
|
|
37
|
+
PipelineAgents,
|
|
38
|
+
collect_agent_outputs,
|
|
39
|
+
create_agent_team,
|
|
40
|
+
create_pipeline,
|
|
41
|
+
get_best_agent,
|
|
42
|
+
run_parallel_agents,
|
|
43
|
+
run_sequential_agents,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
# Import existing coordination components
|
|
47
|
+
try:
|
|
48
|
+
from .coordinator import AgentCoordinator, CoordinationConfig, enhance_agent
|
|
49
|
+
from .deadlock import DeadlockDetector, DeadlockResolutionStrategy
|
|
50
|
+
from .primitives import (
|
|
51
|
+
Barrier,
|
|
52
|
+
CoordinationPrimitive,
|
|
53
|
+
Lease,
|
|
54
|
+
Lock,
|
|
55
|
+
Mutex,
|
|
56
|
+
PrimitiveType,
|
|
57
|
+
Quota,
|
|
58
|
+
Semaphore,
|
|
59
|
+
create_primitive,
|
|
60
|
+
)
|
|
61
|
+
from .rate_limiter import (
|
|
62
|
+
AdaptiveRateLimiter,
|
|
63
|
+
CompositeRateLimiter,
|
|
64
|
+
FixedWindow,
|
|
65
|
+
LeakyBucket,
|
|
66
|
+
RateLimiter,
|
|
67
|
+
RateLimitStrategy,
|
|
68
|
+
SlidingWindow,
|
|
69
|
+
TokenBucket,
|
|
70
|
+
)
|
|
71
|
+
except ImportError:
|
|
72
|
+
# Some coordination components may not be available
|
|
73
|
+
pass
|
|
74
|
+
|
|
75
|
+
# Public API of the coordination package. The coordinator, deadlock,
# primitives and rate_limiter names are imported inside a
# ``try/except ImportError`` above and may therefore be missing; the list is
# filtered afterwards so ``from ... import *`` never names an unbound symbol.
__all__ = [
    "AdaptiveRateLimiter",
    "AgentCoordinator",
    "AgentGroup",
    "AgentOrchestrator",
    "AgentPool",
    "AgentTeam",
    "Agents",
    "Barrier",
    "CompletedWork",
    "CompositeRateLimiter",
    "ConditionalAgents",
    "CoordinationConfig",
    "CoordinationPrimitive",
    "DeadlockDetector",
    "DeadlockResolutionStrategy",
    "DynamicProcessingPool",
    "Event",
    "EventBus",
    "ExecutionStrategy",
    "FixedWindow",
    "FluentResult",
    "GroupResult",
    "LeakyBucket",
    "Lease",
    "Lock",
    "Message",
    "Mutex",
    "OrchestrationExecution",
    "OrchestrationResult",
    "ParallelAgentGroup",
    "PipelineAgents",
    "PoolContext",
    "PrimitiveType",
    "Quota",
    "RateLimitStrategy",
    "RateLimiter",
    "ScalingPolicy",
    "Semaphore",
    "SlidingWindow",
    "StageConfig",
    "TeamResult",
    "TokenBucket",
    "WorkItem",
    "WorkProcessor",
    "WorkQueue",
    "collect_agent_outputs",
    "create_agent_team",
    "create_pipeline",
    "create_primitive",
    "create_team",
    "enhance_agent",
    "get_best_agent",
    "run_agents_parallel",
    "run_agents_sequential",
    "run_parallel_agents",
    "run_sequential_agents",
]
# Drop names whose optional import failed (order is preserved).
__all__ = [_name for _name in __all__ if _name in globals()]
|
|
@@ -0,0 +1,359 @@
|
|
|
1
|
+
"""Agent groups for advanced coordination patterns."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional
|
|
6
|
+
|
|
7
|
+
from ..agent.base import Agent, AgentResult
|
|
8
|
+
from .agent_team import AgentTeam, TeamResult
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import asyncio
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ExecutionStrategy:
    """Names of stage execution strategies.

    These are plain string class attributes (not an Enum), so they compare
    equal to raw strings such as ``"parallel"``.
    """

    # Basic strategies.
    PARALLEL = "parallel"
    SEQUENTIAL = "sequential"
    # Flow-shaped strategies.
    PIPELINE = "pipeline"
    FAN_OUT = "fan_out"
    FAN_IN = "fan_in"
    CONDITIONAL = "conditional"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class StageConfig:
    """Declarative description of one orchestration stage."""

    name: str
    # Names (not objects) of the agents that run in this stage.
    agents: list[str]
    strategy: str = ExecutionStrategy.PARALLEL
    # Names of stages that must complete first; normalised to a list below.
    depends_on: Optional[list[str]] = None
    # Zero-argument predicate; the stage is skipped when it returns falsy.
    condition: Optional[Callable] = None
    timeout: Optional[float] = None

    def __post_init__(self) -> None:
        """Replace a missing ``depends_on`` with a fresh empty list."""
        self.depends_on = [] if self.depends_on is None else self.depends_on
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class AgentGroup:
    """Named collection of agents that are run together in parallel."""

    def __init__(self, agents: list[Agent]):
        # Index by name; a later agent with a duplicate name replaces the
        # earlier one.
        self.agents = {}
        for member in agents:
            self.agents[member.name] = member
        self._results: dict[str, AgentResult] = {}

    async def run_parallel(self, timeout: Optional[float] = None) -> "GroupResult":
        """Run every agent in the group concurrently and wrap the team result."""
        members = list(self.agents.values())
        team = AgentTeam("group_execution")
        team.add_agents(members)
        team_result = await team.run_parallel(timeout)
        return GroupResult(team_result)

    async def collect_all(self) -> "GroupResult":
        """Run all agents (no timeout) and return the collected results."""
        collected = await self.run_parallel()
        return collected
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class GroupResult:
    """Wrapper around a team result that adds convenience accessors.

    Attributes not defined on this class are forwarded to the wrapped result.
    """

    def __init__(self, team_result: "TeamResult"):
        self._team_result = team_result

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attributes to the wrapped team result.

        Python only calls this after normal lookup fails. The wrapped result
        is read from ``__dict__`` directly: on an instance created without
        ``__init__`` (as ``copy``/``pickle`` do before restoring state),
        reading ``self._team_result`` would re-enter ``__getattr__`` and
        recurse forever. In that case, raise ``AttributeError`` instead.
        """
        try:
            team_result = self.__dict__["_team_result"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(team_result, name)

    @property
    def agents(self) -> "list[AgentResult]":
        """Return the per-agent results as a list, in ``agent_results`` order."""
        return list(self._team_result.agent_results.values())
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class ParallelAgentGroup(AgentGroup):
    """Incrementally built agent group whose members run in parallel."""

    def __init__(self, name: str):
        # Initialise the base class so that inherited methods (run_parallel,
        # collect_all) find ``self.agents`` and ``self._results``; previously
        # they raised AttributeError on this subclass.
        super().__init__([])
        self.name = name
        self._agents: list[Agent] = []

    def add_agent(self, agent: Agent) -> "ParallelAgentGroup":
        """Add ``agent`` to the group and return ``self`` for chaining."""
        self._agents.append(agent)
        # Keep the base-class name index in sync, with the same
        # "last agent with a given name wins" rule AgentGroup uses.
        self.agents[agent.name] = agent
        return self

    async def run_and_collect(self, timeout: Optional[float] = None) -> GroupResult:
        """Run all added agents in parallel and return their collected results."""
        return await AgentGroup(self._agents).run_parallel(timeout)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class AgentOrchestrator:
    """Runs registered agents through named, dependency-ordered stages.

    Agents are registered by name; stages refer to them by name and are
    executed through an OrchestrationExecution.
    """

    def __init__(self, name: str):
        self.name = name
        self._agents: dict[str, Agent] = {}
        self._stages: list[StageConfig] = []
        self._global_variables: dict[str, Any] = {}
        self._stage_results: dict[str, TeamResult] = {}
        self._execution_context: dict[str, Any] = {}

    def add_agent(self, agent: Agent) -> "AgentOrchestrator":
        """Register ``agent`` under its name (replacing any previous one)."""
        self._agents[agent.name] = agent
        return self

    def add_agents(self, agents: list[Agent]) -> "AgentOrchestrator":
        """Register each agent in ``agents``; returns ``self`` for chaining."""
        for member in agents:
            self.add_agent(member)
        return self

    def add_stage(
        self,
        name: str,
        agents: list[Agent],
        strategy: str = ExecutionStrategy.PARALLEL,
        depends_on: Optional[list[str]] = None,
        condition: Optional[Callable] = None,
        timeout: Optional[float] = None,
    ) -> "AgentOrchestrator":
        """Append a stage running ``agents``; unregistered agents are registered."""
        for member in agents:
            if member.name in self._agents:
                # Already registered: the existing agent object is kept.
                continue
            self.add_agent(member)

        self._stages.append(
            StageConfig(
                name=name,
                agents=[member.name for member in agents],
                strategy=strategy,
                depends_on=depends_on or [],
                condition=condition,
                timeout=timeout,
            )
        )
        return self

    def set_global_variable(self, key: str, value: Any) -> None:
        """Store ``key`` globally and push it to every registered agent."""
        self._global_variables[key] = value
        for registered in self._agents.values():
            registered.set_shared_variable(key, value)

    def get_global_variable(self, key: str, default: Any = None) -> Any:
        """Return the global variable ``key``, or ``default`` when unset."""
        return self._global_variables.get(key, default)

    def run_with_monitoring(self) -> "OrchestrationExecution":
        """Return an async context manager that drives this orchestration."""
        return OrchestrationExecution(self)

    async def run(self) -> "OrchestrationResult":
        """Execute every stage and return the aggregated orchestration result."""
        async with self.run_with_monitoring() as execution:
            return await execution.wait_for_completion()
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class OrchestrationExecution:
    """Async context manager that runs an orchestrator's stages on demand.

    Stages run lazily. Waiting for a stage first waits for (and runs) each of
    its ``depends_on`` stages, then evaluates its optional ``condition``, and,
    unless skipped, runs the stage's agents as an ``AgentTeam``. Results and
    ``(start, end)`` wall-clock timings are recorded per stage.
    """

    def __init__(self, orchestrator: "AgentOrchestrator"):
        self.orchestrator = orchestrator
        self.start_time: Optional[float] = None
        self.end_time: Optional[float] = None
        self._stage_results: dict[str, TeamResult] = {}
        self._completed_stages: set[str] = set()
        # NOTE(review): nothing in this class registers tasks here yet;
        # __aexit__ cancels whatever is present.
        self._running_stages: dict[str, asyncio.Task] = {}
        self._stage_timings: dict[str, tuple] = {}

    async def __aenter__(self) -> "OrchestrationExecution":
        """Record the start time and return ``self``."""
        import time

        self.start_time = time.time()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Record the end time and cancel any still-running stage tasks."""
        import time

        self.end_time = time.time()

        # Cancel any running stages
        for task in self._running_stages.values():
            if not task.done():
                task.cancel()

    def _find_stage(self, stage_name: str) -> Optional["StageConfig"]:
        """Return the first configured stage named ``stage_name``, or None."""
        for stage in self.orchestrator._stages:
            if stage.name == stage_name:
                return stage
        return None

    async def wait_for_stage(self, stage_name: str) -> "TeamResult":
        """Return the result of ``stage_name``, running it and its dependencies if needed.

        Raises:
            ValueError: if the stage is unknown, or its dependencies form a
                cycle (previously this recursed until RecursionError).
        """
        return await self._wait_for_stage(stage_name, ())

    async def _wait_for_stage(
        self, stage_name: str, chain: tuple[str, ...]
    ) -> "TeamResult":
        """Resolve ``stage_name``; ``chain`` is the dependency path leading here.

        The path is carried per call (not stored on ``self``) so that
        independent concurrent waits never see each other as a cycle.
        """
        if stage_name in self._completed_stages:
            return self._stage_results[stage_name]

        stage_config = self._find_stage(stage_name)
        if stage_config is None:
            raise ValueError(f"Stage {stage_name} not found")

        if stage_name in chain:
            cycle = " -> ".join((*chain, stage_name))
            raise ValueError(f"Circular stage dependency detected: {cycle}")

        # Check dependencies
        path = (*chain, stage_name)
        for dep in stage_config.depends_on or []:
            if dep not in self._completed_stages:
                await self._wait_for_stage(dep, path)

        # Check condition
        if stage_config.condition and not stage_config.condition():
            # Create empty result for skipped stage
            result = TeamResult(
                team_name=f"{self.orchestrator.name}_{stage_name}", status="skipped"
            )
            self._stage_results[stage_name] = result
            self._completed_stages.add(stage_name)
            return result

        # Execute stage
        return await self._execute_stage(stage_config)

    async def _execute_stage(self, stage_config: "StageConfig") -> "TeamResult":
        """Run one stage's agents as a team and record its result and timing."""
        import time

        stage_start = time.time()

        # Agents named by the stage but not registered are silently skipped.
        stage_agents = [
            self.orchestrator._agents[name]
            for name in stage_config.agents
            if name in self.orchestrator._agents
        ]

        if not stage_agents:
            raise ValueError(f"No valid agents found for stage {stage_config.name}")

        # Create team for stage
        team = AgentTeam(f"{self.orchestrator.name}_{stage_config.name}")
        team.add_agents(stage_agents)

        # Set global variables
        for key, value in self.orchestrator._global_variables.items():
            team.set_global_variable(key, value)

        # Only SEQUENTIAL is special-cased; every other strategy runs in
        # parallel. Note the stage timeout is only passed to the parallel run.
        if stage_config.strategy == ExecutionStrategy.SEQUENTIAL:
            result = await team.run_sequential()
        else:
            result = await team.run_parallel(stage_config.timeout)

        stage_end = time.time()

        # Store results
        self._stage_results[stage_config.name] = result
        self._completed_stages.add(stage_config.name)
        self._stage_timings[stage_config.name] = (stage_start, stage_end)

        return result

    def set_stage_input(self, stage_name: str, key: str, value: Any) -> None:
        """Set variable ``key`` on every registered agent of ``stage_name``.

        Unknown stages and unregistered agent names are ignored.
        """
        stage_config = self._find_stage(stage_name)
        if stage_config is None:
            return
        for agent_name in stage_config.agents:
            agent = self.orchestrator._agents.get(agent_name)
            if agent:
                agent.set_variable(key, value)

    def is_stage_complete(self, stage_name: str) -> bool:
        """Return True once ``stage_name`` has finished (or been skipped)."""
        return stage_name in self._completed_stages

    def get_stage_results(self, stage_name: str) -> "Optional[TeamResult]":
        """Return the recorded result for ``stage_name``, or None."""
        return self._stage_results.get(stage_name)

    async def get_final_results(self) -> "OrchestrationResult":
        """Run any outstanding stages, then return the aggregated result."""
        import time

        # Wait for all stages
        for stage in self.orchestrator._stages:
            if stage.name not in self._completed_stages:
                await self.wait_for_stage(stage.name)

        duration = self.total_duration
        if duration is None and self.start_time is not None:
            # Still inside ``async with`` (as in AgentOrchestrator.run()), so
            # end_time is not set yet: report the elapsed time instead of None.
            duration = time.time() - self.start_time

        return OrchestrationResult(
            orchestrator_name=self.orchestrator.name,
            stage_results=self._stage_results,
            stage_timings=self._stage_timings,
            total_duration=duration,
        )

    async def wait_for_completion(self) -> "OrchestrationResult":
        """Wait for every stage and return the final orchestration result."""
        return await self.get_final_results()

    @property
    def total_duration(self) -> Optional[float]:
        """Seconds between entering and leaving the context, or None if not exited."""
        if self.start_time is not None and self.end_time is not None:
            return self.end_time - self.start_time
        return None

    def get_stage_timings(self) -> dict[str, float]:
        """Return each completed stage's duration in seconds."""
        return {
            stage: end - start for stage, (start, end) in self._stage_timings.items()
        }
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
@dataclass
class OrchestrationResult:
    """Aggregated outcome of a full orchestration run."""

    orchestrator_name: str
    stage_results: dict[str, TeamResult]
    # stage name -> (start, end) wall-clock timestamps
    stage_timings: dict[str, tuple]
    total_duration: Optional[float]

    def get_stage_result(self, stage_name: str) -> Optional[TeamResult]:
        """Return the team result for ``stage_name``, or None if absent."""
        return self.stage_results.get(stage_name)

    def get_final_result(self) -> dict[str, Any]:
        """Flatten all stages into one summary dictionary.

        Per-agent outputs and variables are keyed ``"<stage>_<agent>"``.
        """
        outputs: dict[str, Any] = {}
        variables: dict[str, Any] = {}

        for stage_name, stage_result in self.stage_results.items():
            for agent_name, agent_result in stage_result.agent_results.items():
                slot = f"{stage_name}_{agent_name}"
                outputs[slot] = agent_result.outputs
                variables[slot] = agent_result.variables

        durations = {
            label: finish - begin
            for label, (begin, finish) in self.stage_timings.items()
        }
        return {
            "outputs": outputs,
            "variables": variables,
            "stage_count": len(self.stage_results),
            "total_duration": self.total_duration,
            "stage_durations": durations,
        }
|