pyagent-patterns 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. pyagent_patterns/__init__.py +20 -0
  2. pyagent_patterns/advanced/__init__.py +8 -0
  3. pyagent_patterns/advanced/human_in_the_loop.py +103 -0
  4. pyagent_patterns/advanced/react.py +132 -0
  5. pyagent_patterns/advanced/swarm.py +106 -0
  6. pyagent_patterns/advanced/talker_reasoner.py +92 -0
  7. pyagent_patterns/advisor.py +166 -0
  8. pyagent_patterns/base.py +215 -0
  9. pyagent_patterns/composite.py +105 -0
  10. pyagent_patterns/guardrails.py +165 -0
  11. pyagent_patterns/orchestration/__init__.py +9 -0
  12. pyagent_patterns/orchestration/fan_out_fan_in.py +76 -0
  13. pyagent_patterns/orchestration/hierarchical.py +110 -0
  14. pyagent_patterns/orchestration/orchestrator_workers.py +97 -0
  15. pyagent_patterns/orchestration/pipeline.py +57 -0
  16. pyagent_patterns/orchestration/supervisor.py +88 -0
  17. pyagent_patterns/py.typed +0 -0
  18. pyagent_patterns/recovery.py +175 -0
  19. pyagent_patterns/registry.py +71 -0
  20. pyagent_patterns/resolution/__init__.py +9 -0
  21. pyagent_patterns/resolution/cross_reflection.py +79 -0
  22. pyagent_patterns/resolution/debate.py +103 -0
  23. pyagent_patterns/resolution/evaluator_optimizer.py +108 -0
  24. pyagent_patterns/resolution/self_reflection.py +80 -0
  25. pyagent_patterns/resolution/voting.py +93 -0
  26. pyagent_patterns/streaming.py +119 -0
  27. pyagent_patterns/structural/__init__.py +8 -0
  28. pyagent_patterns/structural/blackboard.py +137 -0
  29. pyagent_patterns/structural/layered.py +70 -0
  30. pyagent_patterns/structural/role_based.py +61 -0
  31. pyagent_patterns/structural/topology.py +123 -0
  32. pyagent_patterns-0.1.0.dist-info/METADATA +59 -0
  33. pyagent_patterns-0.1.0.dist-info/RECORD +34 -0
  34. pyagent_patterns-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,215 @@
1
+ """Core abstractions: Pattern, Agent, Message, Context, Result."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import time
7
+ import uuid
8
+ from abc import ABC, abstractmethod
9
+ from dataclasses import dataclass, field
10
+ from enum import Enum
11
+ from typing import Any, AsyncIterator, Callable, Protocol, runtime_checkable
12
+
13
+
14
class Role(str, Enum):
    """Sender roles a message can carry.

    The ``str`` mix-in makes every member compare equal to its plain
    string value (for example ``Role.USER == "user"``).
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
21
+
22
+
23
+ @dataclass(frozen=True, slots=True)
24
+ class Message:
25
+ """A single message in an agent conversation.
26
+
27
+ Attributes:
28
+ role: The sender role (system, user, assistant, tool).
29
+ content: The text content of the message.
30
+ name: Optional agent name for multi-agent conversations.
31
+ metadata: Arbitrary key-value metadata attached to the message.
32
+ """
33
+
34
+ role: Role
35
+ content: str
36
+ name: str | None = None
37
+ metadata: dict[str, Any] = field(default_factory=dict)
38
+
39
+ @classmethod
40
+ def system(cls, content: str, **kw: Any) -> Message:
41
+ return cls(role=Role.SYSTEM, content=content, **kw)
42
+
43
+ @classmethod
44
+ def user(cls, content: str, **kw: Any) -> Message:
45
+ return cls(role=Role.USER, content=content, **kw)
46
+
47
+ @classmethod
48
+ def assistant(cls, content: str, name: str | None = None, **kw: Any) -> Message:
49
+ return cls(role=Role.ASSISTANT, content=content, name=name, **kw)
50
+
51
+
52
@runtime_checkable
class LLMCallable(Protocol):
    """Structural type for an LLM backend.

    An object satisfies this protocol when calling it with the message
    history yields an awaitable that resolves to the reply text.

    NOTE: because the protocol is runtime-checkable, ``isinstance`` only
    verifies that ``__call__`` exists; it checks neither the signature nor
    that the call is actually awaitable.
    """

    async def __call__(self, messages: list[Message]) -> str:
        ...
57
+
58
+
59
class MockLLM:
    """Test double for an LLM backend.

    Args:
        responses: Canned replies served in order, wrapping around once the
            list is exhausted. When omitted or empty, the most recent user
            message is echoed back instead.
        delay: Seconds to sleep before answering, to simulate latency.

    Every call is counted in ``call_count`` and a copy of the messages it
    received is appended to ``call_log``.
    """

    def __init__(self, responses: list[str] | None = None, delay: float = 0.0) -> None:
        self._responses = responses if responses else []
        self._index = 0
        self._delay = delay
        self.call_count = 0
        self.call_log: list[list[Message]] = []

    async def __call__(self, messages: list[Message]) -> str:
        """Return the next canned reply, or an echo of the last user message."""
        if self._delay > 0:
            await asyncio.sleep(self._delay)

        # Record the call before deciding what to answer.
        self.call_count += 1
        self.call_log.append(list(messages))

        canned = self._responses
        if canned:
            position = self._index % len(canned)  # cycle through the list
            self._index += 1
            return canned[position]

        # No canned replies: look for the newest user message to echo.
        last_user = next(
            (entry for entry in reversed(messages) if entry.role == Role.USER),
            None,
        )
        if last_user is None:
            return "[MockLLM] No user message found"
        return f"[MockLLM] Echo: {last_user.content}"
88
+
89
+
90
@dataclass
class Context:
    """State shared by everything taking part in one pattern run.

    Attributes:
        task: The original user task/prompt.
        messages: Message history accumulated so far.
        metadata: Free-form state shared across agents.
        parent_id: ID of the context this one was derived from, if any.
    """

    task: str
    messages: list[Message] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)
    parent_id: str | None = None
    # Short identifier: the first 12 hex characters of a random UUID4.
    _id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])

    @property
    def id(self) -> str:
        """Identifier of this context."""
        return self._id

    def child(self, task: str | None = None) -> Context:
        """Derive a context for a nested pattern execution.

        The child starts with an empty message history, receives a shallow
        copy of this context's metadata, and records this context's ID as
        its ``parent_id``.

        Args:
            task: Task for the child; a falsy value (``None`` or ``""``)
                falls back to this context's task.
        """
        inherited_task = task or self.task
        metadata_copy = dict(self.metadata)
        return Context(task=inherited_task, metadata=metadata_copy, parent_id=self.id)
118
+
119
+
120
@dataclass
class Result:
    """What a pattern execution produced.

    Attributes:
        output: Final output text.
        messages: Every message generated while executing.
        metadata: Pattern-specific details (rounds, consensus, votes, etc.).
        duration_seconds: Wall-clock execution time.
        token_estimate: Rough estimate of the tokens consumed.
        cost_estimate: Rough estimate of the total cost in USD.
    """

    output: str
    messages: list[Message] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)
    duration_seconds: float = 0.0
    token_estimate: int = 0
    cost_estimate: float = 0.0
139
+
140
+
141
@dataclass
class Agent:
    """A named agent that answers through an LLM callable.

    Args:
        name: Human-readable agent name.
        llm: The LLM callable this agent sends its messages to.
        system_prompt: When non-empty, sent as a system message ahead of
            the conversation on every call.
        description: What the agent is for (for routing/selection).
    """

    name: str
    llm: LLMCallable
    system_prompt: str = ""
    description: str = ""

    async def run(self, messages: list[Message]) -> Message:
        """Call the LLM with ``messages`` and wrap its reply as an assistant message.

        The caller's list is never modified; a new list is built for the call.
        """
        prefix = [Message.system(self.system_prompt)] if self.system_prompt else []
        reply_text = await self.llm([*prefix, *messages])
        return Message.assistant(reply_text, name=self.name)
164
+
165
+
166
class Pattern(ABC):
    """Common base for every multi-agent pattern.

    Concrete patterns supply `pattern_type` and `_execute`; `run` wraps
    `_execute` with context set-up, timing, and result metadata.
    """

    @property
    @abstractmethod
    def pattern_type(self) -> str:
        """Name identifying the pattern (e.g., 'supervisor', 'debate')."""
        ...

    async def run(self, task: str, context: Context | None = None) -> Result:
        """Run the pattern on ``task`` and return the annotated result.

        Args:
            task: The user task or prompt. It is appended to the context's
                message history as a user message before execution.
            context: Existing context to run in; a new one is created
                when omitted.

        Returns:
            The result from `_execute`, with ``duration_seconds``,
            ``metadata["pattern_type"]`` and ``token_estimate`` filled in.
        """
        ctx = context if context else Context(task=task)
        ctx.messages.append(Message.user(task))

        started = time.perf_counter()
        result = await self._execute(ctx)
        elapsed = time.perf_counter() - started

        result.duration_seconds = elapsed
        result.metadata["pattern_type"] = self.pattern_type

        # Rough token estimate: ~4 chars per token
        char_count = 0
        for message in result.messages:
            char_count += len(message.content)
        result.token_estimate = char_count // 4

        return result

    @abstractmethod
    async def _execute(self, ctx: Context) -> Result:
        """Pattern-specific logic; invoked by `run`."""
        ...

    async def stream(self, task: str, context: Context | None = None) -> AsyncIterator[str]:
        """Yield partial results as they become available.

        This default runs the whole pattern and yields its final output as
        a single item; subclasses may override it to stream incrementally.
        """
        outcome = await self.run(task, context)
        yield outcome.output
@@ -0,0 +1,105 @@
1
+ """Composite patterns: chain/nest multiple patterns with escalation triggers.
2
+
3
+ CompositePattern runs patterns in sequence, escalating to the next
4
+ if the current pattern's result doesn't meet a quality check.
5
+
6
+ EscalationChain is a pre-built composite: Reflection → Debate → Voting → Human.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Callable
12
+
13
+ from pyagent_patterns.base import Context, Pattern, Result
14
+
15
+
16
+ # Quality check: returns True if the result is acceptable
17
+ QualityCheckFn = Callable[[Result], bool]
18
+
19
+
20
def always_pass(result: Result) -> bool:
    """Accept every result unconditionally (the default quality check)."""
    return True
23
+
24
+
25
def min_length_check(min_chars: int = 50) -> QualityCheckFn:
    """Build a quality check that requires at least ``min_chars`` characters of output."""

    def check(result: Result) -> bool:
        output_size = len(result.output)
        return not output_size < min_chars

    return check
32
+
33
+
34
class CompositePattern(Pattern):
    """Chain multiple patterns with escalation on quality failure.

    Patterns are tried in order. The first result accepted by the quality
    check is returned; if none is accepted, the last pattern's result is
    returned with ``fully_escalated`` set in the metadata.

    Args:
        patterns: Ordered list of patterns to try.
        quality_check: Function that evaluates whether a pattern's result
            is acceptable. If it returns False, the next pattern is tried.
        combine_results: If True, the rejected output and the name of the
            pattern that produced it are stored in the context metadata
            (``previous_output`` / ``previous_pattern``) so that the next
            pattern's child context inherits them.

    Raises:
        ValueError: If ``patterns`` is empty.
    """

    def __init__(
        self,
        patterns: list[Pattern],
        quality_check: QualityCheckFn = always_pass,
        combine_results: bool = True,
    ) -> None:
        if not patterns:
            raise ValueError("CompositePattern requires at least one pattern")
        # Snapshot the list so later mutation by the caller cannot empty it.
        self._patterns = list(patterns)
        self._quality_check = quality_check
        self._combine_results = combine_results

    @property
    def pattern_type(self) -> str:
        """Composite name built from the wrapped patterns' type names."""
        types = [p.pattern_type for p in self._patterns]
        return f"composite({'+'.join(types)})"

    async def _execute(self, ctx: Context) -> Result:
        """Run the patterns in order until one passes the quality check."""
        all_messages: list = []
        escalation_log: list[dict[str, object]] = []
        last_index = len(self._patterns) - 1

        for i, pattern in enumerate(self._patterns):
            # Each pattern gets its own child context (fresh message
            # history, copy of the current shared metadata).
            child_ctx = ctx.child()

            # Go through the public ``run`` rather than ``_execute``: a child
            # context starts with no messages, and ``run`` is what seeds the
            # task as a user message (patterns that read ``ctx.messages``
            # would otherwise see an empty conversation). It also records
            # timing and ``pattern_type`` on the sub-result.
            result = await pattern.run(child_ctx.task, child_ctx)
            all_messages.extend(result.messages)
            escalation_log.append({
                "pattern": pattern.pattern_type,
                "output_length": len(result.output),
                "metadata": result.metadata,
            })

            if self._quality_check(result):
                return Result(
                    output=result.output,
                    messages=all_messages,
                    metadata={
                        "escalation_level": i,
                        "pattern_used": pattern.pattern_type,
                        "escalation_log": escalation_log,
                        "total_patterns_tried": i + 1,
                    },
                )

            # Rejected and more patterns remain: expose this attempt to the
            # next pattern through the shared metadata.
            if self._combine_results and i < last_index:
                ctx.metadata["previous_output"] = result.output
                ctx.metadata["previous_pattern"] = pattern.pattern_type

        # Every pattern was tried and none passed: return the last result.
        return Result(
            output=result.output,
            messages=all_messages,
            metadata={
                "escalation_level": last_index,
                "pattern_used": self._patterns[-1].pattern_type,
                "escalation_log": escalation_log,
                "total_patterns_tried": len(self._patterns),
                "fully_escalated": True,
            },
        )
@@ -0,0 +1,165 @@
1
+ """Guardrail Layer: validate input, output, and inter-agent messages.
2
+
3
+ Based on: Microsoft Azure AI Agent Patterns Guide
4
+ Augment 2026 "Guardrail Layering (P11)"
5
+
6
+ Four insertion points:
7
+ 1. Input guardrail: before the pattern receives user input
8
+ 2. Inter-agent guardrail: between agent communications
9
+ 3. Tool-call guardrail: before tool execution
10
+ 4. Output guardrail: before returning final result
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import re
16
+ from abc import ABC, abstractmethod
17
+ from dataclasses import dataclass
18
+
19
+
20
@dataclass
class GuardrailResult:
    """Outcome of running one guardrail check.

    Attributes:
        passed: Whether the content was accepted.
        message: Optional human-readable explanation.
        sanitized_content: Rewritten content when the guardrail modified
            it; ``None`` when the content was left as is.
    """

    passed: bool
    message: str = ""
    sanitized_content: str | None = None
27
+
28
+
29
class Guardrail(ABC):
    """Interface every guardrail implements."""

    @abstractmethod
    def check(self, content: str) -> GuardrailResult:
        """Evaluate ``content`` against this guardrail's rules and report the outcome."""
        ...
36
+
37
+
38
class LengthGuard(Guardrail):
    """Reject (or truncate) messages exceeding a maximum length.

    Args:
        max_chars: Maximum allowed characters.
        truncate: If True, truncate instead of rejecting.
    """

    # Suffix appended to truncated content so readers can tell it was cut.
    _TRUNCATION_MARKER = "... [truncated]"

    def __init__(self, max_chars: int = 10000, truncate: bool = False) -> None:
        self._max_chars = max_chars
        self._truncate = truncate

    def check(self, content: str) -> GuardrailResult:
        """Check ``content`` against the length limit.

        Returns:
            A passing result when the content fits. When it does not fit:
            a failing result if truncation is disabled, otherwise a passing
            result whose ``sanitized_content`` is the shortened text. The
            shortened text, marker included, never exceeds ``max_chars``.
        """
        limit = self._max_chars
        size = len(content)

        if size <= limit:
            return GuardrailResult(passed=True)

        if not self._truncate:
            return GuardrailResult(
                passed=False,
                message=f"Content exceeds maximum length: {size}/{limit} chars",
            )

        marker = self._TRUNCATION_MARKER
        if limit > len(marker):
            # Reserve room for the marker so the result still honors the limit.
            shortened = content[: limit - len(marker)] + marker
        else:
            # Limit too small to fit the marker: hard cut with no suffix.
            shortened = content[: max(limit, 0)]

        return GuardrailResult(
            passed=True,
            message=f"Truncated from {size} to {len(shortened)} chars",
            sanitized_content=shortened,
        )
63
+
64
+
65
class PIIGuard(Guardrail):
    """Detect and optionally redact personally identifiable information.

    Detects: email addresses, phone numbers, SSNs, credit card numbers.
    Detection is purely regex-based.

    Args:
        redact: If True, redact PII and pass. If False, reject on detection.
    """

    # Inside a character class "|" is a literal pipe, not alternation, so the
    # top-level-domain part of the email pattern is written as [A-Za-z].
    _PATTERNS = {
        "email": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
        "phone": r"\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b",
        "ssn": r"\b\d{3}[-]?\d{2}[-]?\d{4}\b",
        "credit_card": r"\b(?:\d{4}[-\s]?){3}\d{4}\b",
    }

    def __init__(self, redact: bool = True) -> None:
        self._redact = redact

    def check(self, content: str) -> GuardrailResult:
        """Scan ``content`` for PII.

        Returns:
            A passing result with no sanitized content when nothing is
            found. Otherwise, in redact mode, a passing result whose
            ``sanitized_content`` has each match replaced by
            ``[REDACTED-<TYPE>]``; in reject mode, a failing result. In both
            cases the message lists the number of matches per PII type.
        """
        detections: list[str] = []
        sanitized = content

        for pii_type, pattern in self._PATTERNS.items():
            # Matches are counted on the original text; replacement is
            # applied cumulatively to the already-sanitized text.
            matches = re.findall(pattern, content)
            if not matches:
                continue
            detections.append(f"{pii_type}: {len(matches)} found")
            if self._redact:
                sanitized = re.sub(pattern, f"[REDACTED-{pii_type.upper()}]", sanitized)

        if not detections:
            return GuardrailResult(passed=True)

        summary = ", ".join(detections)
        if self._redact:
            return GuardrailResult(
                passed=True,
                message=f"PII redacted: {summary}",
                sanitized_content=sanitized,
            )

        return GuardrailResult(
            passed=False,
            message=f"PII detected: {summary}",
        )
109
+
110
+
111
class ContentGuard(Guardrail):
    """Reject content that hits a configurable deny list.

    Args:
        deny_patterns: Regular expressions to block; they are compiled
            case-insensitively and searched anywhere in the content.
        deny_words: Words/phrases to block; each one is matched
            case-insensitively as a substring of the content.
    """

    def __init__(
        self,
        deny_patterns: list[str] | None = None,
        deny_words: list[str] | None = None,
    ) -> None:
        self._patterns = [re.compile(expr, re.IGNORECASE) for expr in deny_patterns or []]
        self._words = [term.lower() for term in deny_words or []]

    def check(self, content: str) -> GuardrailResult:
        """Fail on the first deny word, then the first deny pattern, found in ``content``."""
        lowered = content.lower()

        # Deny words are checked first, in the order they were configured.
        hit_word = next((term for term in self._words if term in lowered), None)
        if hit_word is not None:
            return GuardrailResult(passed=False, message=f"Blocked word detected: '{hit_word}'")

        hit_pattern = next((rx for rx in self._patterns if rx.search(content)), None)
        if hit_pattern is not None:
            return GuardrailResult(
                passed=False, message=f"Blocked pattern matched: {hit_pattern.pattern}"
            )

        return GuardrailResult(passed=True)
141
+
142
+
143
class GuardrailChain:
    """Chain multiple guardrails together. All must pass.

    Args:
        guardrails: List of guardrails to apply in order.
    """

    def __init__(self, guardrails: list[Guardrail]) -> None:
        self._guardrails = guardrails

    def check(self, content: str) -> GuardrailResult:
        """Run every guardrail in order, feeding sanitized content forward.

        Returns:
            The first failing guardrail's result, if any. Otherwise a
            passing result; when at least one guardrail changed the content,
            it carries the final ``sanitized_content`` and a message joining
            the messages of the guardrails that sanitized.
        """
        current_content = content
        notes: list[str] = []

        for guard in self._guardrails:
            result = guard.check(current_content)
            if not result.passed:
                return result
            if result.sanitized_content is not None:
                current_content = result.sanitized_content
                # Keep the explanation so the caller can see what was
                # redacted or truncated along the way.
                if result.message:
                    notes.append(result.message)

        if current_content != content:
            return GuardrailResult(
                passed=True,
                message="; ".join(notes),
                sanitized_content=current_content,
            )
        return GuardrailResult(passed=True)
@@ -0,0 +1,9 @@
1
+ """Tier 1: Orchestration patterns — Supervisor, Pipeline, FanOut, Hierarchical, OrchestratorWorkers."""
2
+
3
+ from pyagent_patterns.orchestration.fan_out_fan_in import FanOutFanIn
4
+ from pyagent_patterns.orchestration.hierarchical import Hierarchical
5
+ from pyagent_patterns.orchestration.orchestrator_workers import OrchestratorWorkers
6
+ from pyagent_patterns.orchestration.pipeline import Pipeline
7
+ from pyagent_patterns.orchestration.supervisor import Supervisor
8
+
9
+ __all__ = ["Supervisor", "Pipeline", "FanOutFanIn", "Hierarchical", "OrchestratorWorkers"]
@@ -0,0 +1,76 @@
1
+ """Fan-Out / Fan-In pattern: broadcast task to N agents in parallel, aggregate results.
2
+
3
+ All agents receive the same task and run concurrently. An aggregator
4
+ combines their outputs into a unified response.
5
+
6
+ LLM calls: N agents (parallel) + 1 aggregator = N+1 total
7
+ Wall-clock latency: max(agent latencies) + aggregator
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import asyncio
13
+ from typing import AsyncIterator
14
+
15
+ from pyagent_patterns.base import Agent, Context, Message, Pattern, Result
16
+
17
+
18
class FanOutFanIn(Pattern):
    """Parallel execution with result aggregation.

    Args:
        agents: List of agents to run in parallel on the same task.
        aggregator: Agent that combines all parallel outputs into one.

    Raises:
        ValueError: If ``agents`` is empty.
    """

    def __init__(self, agents: list[Agent], aggregator: Agent) -> None:
        if not agents:
            raise ValueError("FanOutFanIn requires at least one agent")
        self._agents = agents
        self._aggregator = aggregator

    @property
    def pattern_type(self) -> str:
        """Pattern type name."""
        return "fan_out_fan_in"

    async def _execute(self, ctx: Context) -> Result:
        """Run all agents concurrently on the context's messages, then aggregate."""
        messages: list[Message] = []

        # Fan-out: run all agents in parallel. gather() returns results in
        # the order the agents were given, so they can be zipped with agents.
        parallel_results = await asyncio.gather(
            *(agent.run(ctx.messages) for agent in self._agents)
        )
        messages.extend(parallel_results)

        # Fan-in: aggregate results
        combined = "\n\n".join(
            f"--- {agent.name} ---\n{reply.content}"
            for agent, reply in zip(self._agents, parallel_results)
        )
        agg_prompt = Message.user(
            f"Combine the following analyses into a unified response:\n\n{combined}"
        )
        aggregated = await self._aggregator.run([agg_prompt])
        messages.append(aggregated)

        return Result(
            output=aggregated.content,
            messages=messages,
            metadata={
                "parallel_agents": len(self._agents),
                "agent_names": [a.name for a in self._agents],
            },
        )

    async def stream(self, task: str, context: Context | None = None) -> AsyncIterator[str]:
        """Stream individual agent results as they complete.

        Yields one ``"[<agent name>] <content>"`` string per agent, in
        completion order. The aggregator is not invoked by this method.
        If the consumer stops iterating early, or an agent raises, the
        agent calls still in flight are cancelled.
        """
        ctx = context or Context(task=task)
        ctx.messages.append(Message.user(task))

        async def _run_one(agent: Agent) -> tuple[str, str]:
            result = await agent.run(ctx.messages)
            return agent.name, result.content

        # Create the tasks explicitly so we hold references that can be
        # cancelled during cleanup.
        pending = [asyncio.ensure_future(_run_one(agent)) for agent in self._agents]
        try:
            for next_done in asyncio.as_completed(pending):
                name, content = await next_done
                yield f"[{name}] {content}"
        finally:
            # cancel() is a no-op on tasks that have already finished.
            for in_flight in pending:
                in_flight.cancel()
@@ -0,0 +1,110 @@
1
+ """Hierarchical Coordination pattern: Manager → Teams → Workers.
2
+
3
+ A tiered approach where a manager delegates to team leads,
4
+ who further delegate to individual workers. Results flow back
5
+ up through the hierarchy.
6
+
7
+ LLM calls: 1 manager + T team leads + W workers
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import asyncio
13
+ from dataclasses import dataclass
14
+
15
+ from pyagent_patterns.base import Agent, Context, Message, Pattern, Result
16
+
17
+
18
@dataclass
class Team:
    """A named group made of one lead agent and its worker agents.

    Args:
        name: Team name (e.g., "Research", "Risk").
        lead: Agent that synthesizes the workers' outputs.
        workers: The individual worker agents of the team.
    """

    name: str
    lead: Agent
    workers: list[Agent]
31
+
32
+
33
class Hierarchical(Pattern):
    """Three-tier coordination: manager, team leads, workers.

    Args:
        manager: Top-level agent that decomposes the task and later
            synthesizes the teams' outputs.
        teams: The teams, each with a lead and workers.
    """

    def __init__(self, manager: Agent, teams: list[Team]) -> None:
        self._manager = manager
        self._teams = teams

    @property
    def pattern_type(self) -> str:
        """Pattern type name."""
        return "hierarchical"

    async def _run_team(self, team: Team, plan: str) -> tuple[str, list[Message]]:
        """Run one team: all workers concurrently, then the lead's synthesis.

        Returns:
            The lead's synthesized text and the team's messages (worker
            replies in worker order, followed by the lead's reply).
        """
        # Every worker receives the same prompt containing the full plan.
        worker_replies = await asyncio.gather(
            *(
                worker.run([Message.user(f"Team {team.name} task: {plan}\nDo your part.")])
                for worker in team.workers
            )
        )
        collected: list[Message] = list(worker_replies)

        # The lead sees one bullet per worker, in worker order.
        bullet_lines = [
            f"- {worker.name}: {reply.content}"
            for worker, reply in zip(team.workers, worker_replies)
        ]
        worker_summary = "\n".join(bullet_lines)
        lead_reply = await team.lead.run(
            [Message.user(f"Synthesize your team's work:\n{worker_summary}")]
        )
        collected.append(lead_reply)
        return lead_reply.content, collected

    async def _execute(self, ctx: Context) -> Result:
        """Decompose the task, run the teams in parallel, and synthesize."""
        messages: list[Message] = []

        # Step 1: the manager decomposes the task for the named teams.
        team_list = ", ".join(t.name for t in self._teams)
        decompose_prompt = Message.user(
            f"Decompose this task into subtasks for these teams: "
            f"{team_list}.\n"
            f"For each team, provide a clear subtask description.\n\n"
            f"Task: {ctx.task}"
        )
        manager_plan = await self._manager.run([decompose_prompt])
        messages.append(manager_plan)

        # Step 2: all teams work concurrently from the same manager plan.
        team_outputs = await asyncio.gather(
            *(self._run_team(team, manager_plan.content) for team in self._teams)
        )
        for _, team_messages in team_outputs:
            messages.extend(team_messages)

        # Step 3: the manager synthesizes the leads' outputs.
        sections = [
            f"--- {team.name} Team ---\n{output}"
            for team, (output, _) in zip(self._teams, team_outputs)
        ]
        team_summary = "\n\n".join(sections)
        synthesis_prompt = Message.user(
            f"Synthesize these team outputs into a final response:\n\n{team_summary}"
        )
        final = await self._manager.run([synthesis_prompt])
        messages.append(final)

        return Result(
            output=final.content,
            messages=messages,
            metadata={
                "teams": len(self._teams),
                "total_workers": sum(len(t.workers) for t in self._teams),
                "team_names": [t.name for t in self._teams],
            },
        )