aury-agent 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/__init__.py +2 -0
- aury/agents/__init__.py +55 -0
- aury/agents/a2a/__init__.py +168 -0
- aury/agents/backends/__init__.py +196 -0
- aury/agents/backends/artifact/__init__.py +9 -0
- aury/agents/backends/artifact/memory.py +130 -0
- aury/agents/backends/artifact/types.py +133 -0
- aury/agents/backends/code/__init__.py +65 -0
- aury/agents/backends/file/__init__.py +11 -0
- aury/agents/backends/file/local.py +66 -0
- aury/agents/backends/file/types.py +40 -0
- aury/agents/backends/invocation/__init__.py +8 -0
- aury/agents/backends/invocation/memory.py +81 -0
- aury/agents/backends/invocation/types.py +110 -0
- aury/agents/backends/memory/__init__.py +8 -0
- aury/agents/backends/memory/memory.py +179 -0
- aury/agents/backends/memory/types.py +136 -0
- aury/agents/backends/message/__init__.py +9 -0
- aury/agents/backends/message/memory.py +122 -0
- aury/agents/backends/message/types.py +124 -0
- aury/agents/backends/sandbox.py +275 -0
- aury/agents/backends/session/__init__.py +8 -0
- aury/agents/backends/session/memory.py +93 -0
- aury/agents/backends/session/types.py +124 -0
- aury/agents/backends/shell/__init__.py +11 -0
- aury/agents/backends/shell/local.py +110 -0
- aury/agents/backends/shell/types.py +55 -0
- aury/agents/backends/shell.py +209 -0
- aury/agents/backends/snapshot/__init__.py +19 -0
- aury/agents/backends/snapshot/git.py +95 -0
- aury/agents/backends/snapshot/hybrid.py +125 -0
- aury/agents/backends/snapshot/memory.py +86 -0
- aury/agents/backends/snapshot/types.py +59 -0
- aury/agents/backends/state/__init__.py +29 -0
- aury/agents/backends/state/composite.py +49 -0
- aury/agents/backends/state/file.py +57 -0
- aury/agents/backends/state/memory.py +52 -0
- aury/agents/backends/state/sqlite.py +262 -0
- aury/agents/backends/state/types.py +178 -0
- aury/agents/backends/subagent/__init__.py +165 -0
- aury/agents/cli/__init__.py +41 -0
- aury/agents/cli/chat.py +239 -0
- aury/agents/cli/config.py +236 -0
- aury/agents/cli/extensions.py +460 -0
- aury/agents/cli/main.py +189 -0
- aury/agents/cli/session.py +337 -0
- aury/agents/cli/workflow.py +276 -0
- aury/agents/context_providers/__init__.py +66 -0
- aury/agents/context_providers/artifact.py +299 -0
- aury/agents/context_providers/base.py +177 -0
- aury/agents/context_providers/memory.py +70 -0
- aury/agents/context_providers/message.py +130 -0
- aury/agents/context_providers/skill.py +50 -0
- aury/agents/context_providers/subagent.py +46 -0
- aury/agents/context_providers/tool.py +68 -0
- aury/agents/core/__init__.py +83 -0
- aury/agents/core/base.py +573 -0
- aury/agents/core/context.py +797 -0
- aury/agents/core/context_builder.py +303 -0
- aury/agents/core/event_bus/__init__.py +15 -0
- aury/agents/core/event_bus/bus.py +203 -0
- aury/agents/core/factory.py +169 -0
- aury/agents/core/isolator.py +97 -0
- aury/agents/core/logging.py +95 -0
- aury/agents/core/parallel.py +194 -0
- aury/agents/core/runner.py +139 -0
- aury/agents/core/services/__init__.py +5 -0
- aury/agents/core/services/file_session.py +144 -0
- aury/agents/core/services/message.py +53 -0
- aury/agents/core/services/session.py +53 -0
- aury/agents/core/signals.py +109 -0
- aury/agents/core/state.py +363 -0
- aury/agents/core/types/__init__.py +107 -0
- aury/agents/core/types/action.py +176 -0
- aury/agents/core/types/artifact.py +135 -0
- aury/agents/core/types/block.py +736 -0
- aury/agents/core/types/message.py +350 -0
- aury/agents/core/types/recall.py +144 -0
- aury/agents/core/types/session.py +257 -0
- aury/agents/core/types/subagent.py +154 -0
- aury/agents/core/types/tool.py +205 -0
- aury/agents/eval/__init__.py +331 -0
- aury/agents/hitl/__init__.py +57 -0
- aury/agents/hitl/ask_user.py +242 -0
- aury/agents/hitl/compaction.py +230 -0
- aury/agents/hitl/exceptions.py +87 -0
- aury/agents/hitl/permission.py +617 -0
- aury/agents/hitl/revert.py +216 -0
- aury/agents/llm/__init__.py +31 -0
- aury/agents/llm/adapter.py +367 -0
- aury/agents/llm/openai.py +294 -0
- aury/agents/llm/provider.py +476 -0
- aury/agents/mcp/__init__.py +153 -0
- aury/agents/memory/__init__.py +46 -0
- aury/agents/memory/compaction.py +394 -0
- aury/agents/memory/manager.py +465 -0
- aury/agents/memory/processor.py +177 -0
- aury/agents/memory/store.py +187 -0
- aury/agents/memory/types.py +137 -0
- aury/agents/messages/__init__.py +40 -0
- aury/agents/messages/config.py +47 -0
- aury/agents/messages/raw_store.py +224 -0
- aury/agents/messages/store.py +118 -0
- aury/agents/messages/types.py +88 -0
- aury/agents/middleware/__init__.py +31 -0
- aury/agents/middleware/base.py +341 -0
- aury/agents/middleware/chain.py +342 -0
- aury/agents/middleware/message.py +129 -0
- aury/agents/middleware/message_container.py +126 -0
- aury/agents/middleware/raw_message.py +153 -0
- aury/agents/middleware/truncation.py +139 -0
- aury/agents/middleware/types.py +81 -0
- aury/agents/plugin.py +162 -0
- aury/agents/react/__init__.py +4 -0
- aury/agents/react/agent.py +1923 -0
- aury/agents/sandbox/__init__.py +23 -0
- aury/agents/sandbox/local.py +239 -0
- aury/agents/sandbox/remote.py +200 -0
- aury/agents/sandbox/types.py +115 -0
- aury/agents/skill/__init__.py +16 -0
- aury/agents/skill/loader.py +180 -0
- aury/agents/skill/types.py +83 -0
- aury/agents/tool/__init__.py +39 -0
- aury/agents/tool/builtin/__init__.py +23 -0
- aury/agents/tool/builtin/ask_user.py +155 -0
- aury/agents/tool/builtin/bash.py +107 -0
- aury/agents/tool/builtin/delegate.py +726 -0
- aury/agents/tool/builtin/edit.py +121 -0
- aury/agents/tool/builtin/plan.py +277 -0
- aury/agents/tool/builtin/read.py +91 -0
- aury/agents/tool/builtin/thinking.py +111 -0
- aury/agents/tool/builtin/yield_result.py +130 -0
- aury/agents/tool/decorator.py +252 -0
- aury/agents/tool/set.py +204 -0
- aury/agents/usage/__init__.py +12 -0
- aury/agents/usage/tracker.py +236 -0
- aury/agents/workflow/__init__.py +85 -0
- aury/agents/workflow/adapter.py +268 -0
- aury/agents/workflow/dag.py +116 -0
- aury/agents/workflow/dsl.py +575 -0
- aury/agents/workflow/executor.py +659 -0
- aury/agents/workflow/expression.py +136 -0
- aury/agents/workflow/parser.py +182 -0
- aury/agents/workflow/state.py +145 -0
- aury/agents/workflow/types.py +86 -0
- aury_agent-0.0.4.dist-info/METADATA +90 -0
- aury_agent-0.0.4.dist-info/RECORD +149 -0
- aury_agent-0.0.4.dist-info/WHEEL +4 -0
- aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""Usage tracking for LLM and other services."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import asyncio
|
|
5
|
+
from dataclasses import dataclass, field, asdict
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from decimal import Decimal
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from ..core.event_bus import EventBus, Events
|
|
12
|
+
|
|
13
|
+
# Alias for backward compatibility
|
|
14
|
+
Bus = EventBus
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class UsageType(Enum):
    """Types of usage to track.

    Values are the string form used in serialized records (see
    ``UsageEntry.to_dict``, which emits ``type.value``).
    """

    # LLM token usage, split by direction and by prompt-cache activity.
    LLM_INPUT = "llm_input"
    LLM_OUTPUT = "llm_output"
    LLM_CACHE_READ = "llm_cache_read"
    LLM_CACHE_WRITE = "llm_cache_write"
    # Non-chat services (units may not be tokens for all of these).
    EMBEDDING = "embedding"
    IMAGE_GEN = "image_gen"
    SPEECH = "speech"
    SEARCH = "search"
    EXTERNAL_API = "external_api"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class UsageEntry:
    """A single usage record.

    Captures one unit of service consumption (e.g. the input tokens of one
    LLM call) together with enough context (session / invocation / tool /
    step) to aggregate it later.
    """

    id: str
    type: UsageType
    provider: str
    model: str | None = None
    units: int = 0
    unit_type: str = "tokens"  # tokens, requests, characters, etc.
    cost: Decimal | None = None

    # Context
    session_id: str | None = None
    invocation_id: str | None = None
    tool_id: str | None = None
    step: int | None = None

    # Metadata
    created_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict.

        Returns:
            Dict with the enum flattened to its string value, ``cost`` as a
            float (or None when unset), and ``created_at`` in ISO format.
        """
        return {
            "id": self.id,
            "type": self.type.value,
            "provider": self.provider,
            "model": self.model,
            "units": self.units,
            "unit_type": self.unit_type,
            # Fix: compare against None explicitly — Decimal(0) is falsy, so
            # the previous truthiness check serialized a genuine zero cost
            # as None instead of 0.0.
            "cost": float(self.cost) if self.cost is not None else None,
            "session_id": self.session_id,
            "invocation_id": self.invocation_id,
            "tool_id": self.tool_id,
            "step": self.step,
            "created_at": self.created_at.isoformat(),
            "metadata": self.metadata,
        }
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class UsageTracker:
    """Track usage of LLM and other services.

    Provides methods for recording usage, computing costs,
    and generating summaries.
    """

    def __init__(self, bus: Bus | None = None):
        """
        Args:
            bus: Optional event bus; when set, every recorded entry is also
                published as a ``USAGE_RECORDED`` event.
        """
        self.bus = bus
        self._entries: list[UsageEntry] = []
        # Guards _entries against concurrent appends from parallel recorders.
        self._lock = asyncio.Lock()
        self._counter = 0

    def _generate_id(self) -> str:
        """Return the next sequential entry id (e.g. ``usage_000001``)."""
        self._counter += 1
        return f"usage_{self._counter:06d}"

    def _filter(
        self,
        session_id: str | None = None,
        invocation_id: str | None = None,
        type_filter: UsageType | None = None,
    ) -> list[UsageEntry]:
        """Return a NEW list of entries matching all provided filters.

        Falsy filter values (None or "") mean "no filter", preserving the
        original truthy checks.
        """
        return [
            e
            for e in self._entries
            if (not session_id or e.session_id == session_id)
            and (not invocation_id or e.invocation_id == invocation_id)
            and (not type_filter or e.type == type_filter)
        ]

    async def record(self, entry: UsageEntry) -> None:
        """Record a usage entry and publish it to the bus (if any)."""
        async with self._lock:
            self._entries.append(entry)

        # Publish outside the lock so slow subscribers don't block recording.
        if self.bus:
            await self.bus.publish(Events.USAGE_RECORDED, entry.to_dict())

    async def record_llm(
        self,
        provider: str,
        model: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_read_tokens: int = 0,
        cache_write_tokens: int = 0,
        session_id: str | None = None,
        invocation_id: str | None = None,
        step: int | None = None,
        **metadata: Any,
    ) -> None:
        """Convenience method for recording LLM usage.

        Creates one entry per non-zero token category.
        """
        parts = [
            (UsageType.LLM_INPUT, input_tokens),
            (UsageType.LLM_OUTPUT, output_tokens),
            (UsageType.LLM_CACHE_READ, cache_read_tokens),
            (UsageType.LLM_CACHE_WRITE, cache_write_tokens),
        ]
        for usage_type, units in parts:
            if units > 0:
                await self.record(UsageEntry(
                    id=self._generate_id(),
                    type=usage_type,
                    provider=provider,
                    model=model,
                    units=units,
                    unit_type="tokens",
                    session_id=session_id,
                    invocation_id=invocation_id,
                    step=step,
                    # Copy: previously all four entries shared ONE metadata
                    # dict, so a later mutation leaked into every entry.
                    metadata=dict(metadata),
                ))

    async def record_embedding(
        self,
        provider: str,
        model: str,
        tokens: int,
        session_id: str | None = None,
        **metadata: Any,
    ) -> None:
        """Record embedding usage."""
        await self.record(UsageEntry(
            id=self._generate_id(),
            type=UsageType.EMBEDDING,
            provider=provider,
            model=model,
            units=tokens,
            unit_type="tokens",
            session_id=session_id,
            metadata=metadata,
        ))

    def summarize(
        self,
        session_id: str | None = None,
        invocation_id: str | None = None,
    ) -> dict[str, Any]:
        """Generate usage summary, optionally scoped to a session/invocation.

        Returns:
            Dict with unit totals grouped by type/provider/model, the entry
            count, and the summed cost as a float.
        """
        entries = self._filter(session_id, invocation_id)

        by_type: dict[str, int] = {}
        by_provider: dict[str, int] = {}
        by_model: dict[str, int] = {}
        total_cost = Decimal(0)

        for e in entries:
            by_type[e.type.value] = by_type.get(e.type.value, 0) + e.units
            by_provider[e.provider] = by_provider.get(e.provider, 0) + e.units
            if e.model:
                by_model[e.model] = by_model.get(e.model, 0) + e.units
            if e.cost:
                total_cost += e.cost

        return {
            "total_cost": float(total_cost),
            "by_type": by_type,
            "by_provider": by_provider,
            "by_model": by_model,
            "entry_count": len(entries),
            # NOTE: sums units of EVERY type, including entries whose
            # unit_type is not "tokens" (requests, characters, ...).
            "total_tokens": sum(by_type.values()),
        }

    def get_entries(
        self,
        session_id: str | None = None,
        invocation_id: str | None = None,
        type_filter: UsageType | None = None,
    ) -> list[UsageEntry]:
        """Get filtered usage entries.

        Always returns a new list; previously the unfiltered call returned
        the internal list itself, so callers could mutate tracker state.
        """
        return self._filter(session_id, invocation_id, type_filter)

    def clear(self, session_id: str | None = None) -> int:
        """Clear usage entries, optionally only those of one session.

        Returns:
            Number of entries cleared
        """
        if session_id:
            original = len(self._entries)
            self._entries = [e for e in self._entries if e.session_id != session_id]
            return original - len(self._entries)

        count = len(self._entries)
        self._entries.clear()
        return count
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Workflow system for DAG-based orchestration."""
|
|
2
|
+
from .types import (
|
|
3
|
+
NodeType,
|
|
4
|
+
Position,
|
|
5
|
+
NodeSpec,
|
|
6
|
+
EdgeSpec,
|
|
7
|
+
WorkflowSpec,
|
|
8
|
+
Workflow,
|
|
9
|
+
)
|
|
10
|
+
from .parser import (
|
|
11
|
+
WorkflowParser,
|
|
12
|
+
WorkflowValidationError,
|
|
13
|
+
)
|
|
14
|
+
from .expression import (
|
|
15
|
+
ExpressionEvaluator,
|
|
16
|
+
ExpressionError,
|
|
17
|
+
)
|
|
18
|
+
from .state import (
|
|
19
|
+
WorkflowState,
|
|
20
|
+
MergeStrategy,
|
|
21
|
+
CollectListStrategy,
|
|
22
|
+
CollectDictStrategy,
|
|
23
|
+
FirstSuccessStrategy,
|
|
24
|
+
get_merge_strategy,
|
|
25
|
+
)
|
|
26
|
+
from .dag import DAGExecutor
|
|
27
|
+
from .executor import WorkflowExecutor
|
|
28
|
+
from ..core.factory import AgentFactory
|
|
29
|
+
from .adapter import WorkflowAgent
|
|
30
|
+
from .dsl import (
|
|
31
|
+
DSLNode,
|
|
32
|
+
DSLSequence,
|
|
33
|
+
DSLStep,
|
|
34
|
+
DSLParallel,
|
|
35
|
+
DSLCondition,
|
|
36
|
+
DSLWorkflow,
|
|
37
|
+
workflow,
|
|
38
|
+
step,
|
|
39
|
+
parallel,
|
|
40
|
+
condition,
|
|
41
|
+
skip,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
__all__ = [
|
|
45
|
+
# Types
|
|
46
|
+
"NodeType",
|
|
47
|
+
"Position",
|
|
48
|
+
"NodeSpec",
|
|
49
|
+
"EdgeSpec",
|
|
50
|
+
"WorkflowSpec",
|
|
51
|
+
"Workflow",
|
|
52
|
+
# Parser
|
|
53
|
+
"WorkflowParser",
|
|
54
|
+
"WorkflowValidationError",
|
|
55
|
+
# Expression
|
|
56
|
+
"ExpressionEvaluator",
|
|
57
|
+
"ExpressionError",
|
|
58
|
+
# State
|
|
59
|
+
"WorkflowState",
|
|
60
|
+
"MergeStrategy",
|
|
61
|
+
"CollectListStrategy",
|
|
62
|
+
"CollectDictStrategy",
|
|
63
|
+
"FirstSuccessStrategy",
|
|
64
|
+
"get_merge_strategy",
|
|
65
|
+
# DAG
|
|
66
|
+
"DAGExecutor",
|
|
67
|
+
# Executor
|
|
68
|
+
"WorkflowExecutor",
|
|
69
|
+
# Factory
|
|
70
|
+
"AgentFactory",
|
|
71
|
+
# Agent
|
|
72
|
+
"WorkflowAgent",
|
|
73
|
+
# DSL
|
|
74
|
+
"DSLNode",
|
|
75
|
+
"DSLSequence",
|
|
76
|
+
"DSLStep",
|
|
77
|
+
"DSLParallel",
|
|
78
|
+
"DSLCondition",
|
|
79
|
+
"DSLWorkflow",
|
|
80
|
+
"workflow",
|
|
81
|
+
"step",
|
|
82
|
+
"parallel",
|
|
83
|
+
"condition",
|
|
84
|
+
"skip",
|
|
85
|
+
]
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""Workflow agent - executes DAG-based workflows.
|
|
2
|
+
|
|
3
|
+
WorkflowAgent uses the unified BaseAgent constructor:
|
|
4
|
+
__init__(self, ctx: InvocationContext, config: AgentConfig | None = None)
|
|
5
|
+
|
|
6
|
+
Workflow definition and AgentFactory are passed via config or set after init.
|
|
7
|
+
|
|
8
|
+
Middleware hooks are called at:
|
|
9
|
+
- on_agent_start/end: workflow start/end
|
|
10
|
+
- on_subagent_start/end: each node execution (in executor)
|
|
11
|
+
"""
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from typing import Any, AsyncIterator, ClassVar, Literal, TYPE_CHECKING
|
|
15
|
+
|
|
16
|
+
from ..core.base import BaseAgent, AgentConfig
|
|
17
|
+
from ..core.context import InvocationContext
|
|
18
|
+
from ..core.logging import workflow_logger as logger
|
|
19
|
+
from ..core.types.block import BlockEvent, BlockKind, BlockOp
|
|
20
|
+
from ..middleware import HookAction
|
|
21
|
+
from .types import Workflow
|
|
22
|
+
from .executor import WorkflowExecutor
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from ..core.factory import AgentFactory
|
|
26
|
+
from ..core.types.session import Session
|
|
27
|
+
from ..backends import Backends
|
|
28
|
+
from ..core.event_bus import Bus
|
|
29
|
+
from ..middleware import MiddlewareChain
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class WorkflowAgent(BaseAgent):
    """Workflow agent - executes DAG-based workflows.

    Two ways to create:

    1. Simple (recommended):
       agent = WorkflowAgent.create(workflow=wf, agent_factory=factory)

    2. Advanced (for custom Session/Backends/Bus):
       ctx = InvocationContext(session=session, backends=backends, bus=bus)
       agent = WorkflowAgent(ctx, config)
       agent.set_workflow(workflow, agent_factory)
    """

    agent_type: ClassVar[Literal["react", "workflow"]] = "workflow"

    @classmethod
    def create(
        cls,
        workflow: Workflow,
        # NOTE(review): annotation names "WorkflowAgentFactory", but only
        # AgentFactory is imported under TYPE_CHECKING above — presumably an
        # alias defined elsewhere; verify the intended type.
        agent_factory: "WorkflowAgentFactory",
        config: AgentConfig | None = None,
        *,
        backends: "Backends | None" = None,
        session: "Session | None" = None,
        bus: "Bus | None" = None,
        middleware: "MiddlewareChain | None" = None,
    ) -> "WorkflowAgent":
        """Create WorkflowAgent with minimal boilerplate.

        Args:
            workflow: Workflow definition
            agent_factory: Factory for creating sub-agents
            config: Agent configuration (optional)
            backends: Backends container (auto-created if None)
            session: Session object (auto-created if None)
            bus: Event bus (auto-created if None)
            middleware: Middleware chain (optional)

        Returns:
            Configured WorkflowAgent ready to run

        Example:
            agent = WorkflowAgent.create(
                workflow=my_workflow,
                agent_factory=factory,
            )
            async for response in agent.run(inputs):
                print(response)
        """
        # Deferred imports: keeps module import light and avoids cycles.
        from ..core.event_bus import Bus
        from ..core.types.session import Session, generate_id
        from ..backends import Backends

        # Auto-create backends if not provided
        if backends is None:
            backends = Backends.create_default()

        # Auto-create missing components
        if session is None:
            session = Session(id=generate_id("sess"))
        if bus is None:
            bus = Bus()

        # Build context; agent_id falls back to the workflow's own name
        # when no config is supplied.
        ctx = InvocationContext(
            session=session,
            invocation_id=generate_id("inv"),
            agent_id=config.name if config else workflow.spec.name,
            backends=backends,
            bus=bus,
            middleware=middleware,
        )

        agent = cls(ctx, config)
        agent.set_workflow(workflow, agent_factory)
        return agent

    def __init__(
        self,
        ctx: InvocationContext,
        config: AgentConfig | None = None,
    ):
        """Initialize WorkflowAgent.

        Args:
            ctx: InvocationContext with storage, bus, etc.
            config: Agent configuration (may contain workflow in metadata)
        """
        super().__init__(ctx, config)

        # Workflow can be set via config.metadata or set_workflow()
        self.workflow: Workflow | None = self.config.metadata.get('workflow')
        self.agent_factory: "WorkflowAgentFactory | None" = self.config.metadata.get('agent_factory')
        # Created lazily per execution; holds the active run for pause().
        self._executor: WorkflowExecutor | None = None

    def set_workflow(
        self,
        workflow: Workflow,
        agent_factory: "WorkflowAgentFactory",
    ) -> "WorkflowAgent":
        """Set workflow definition and factory.

        Args:
            workflow: Workflow definition
            agent_factory: Factory for creating sub-agents

        Returns:
            Self for chaining
        """
        self.workflow = workflow
        self.agent_factory = agent_factory
        return self

    async def _execute(self, input: Any) -> None:
        """Execute workflow with middleware hooks.

        Args:
            input: Workflow inputs (dict or single value)

        Raises:
            ValueError: If workflow or agent_factory not set
        """
        # Validate workflow is set
        if self.workflow is None:
            raise ValueError("Workflow not set. Call set_workflow() first.")
        if self.agent_factory is None:
            raise ValueError("AgentFactory not set. Call set_workflow() first.")

        # Non-dict inputs are wrapped under the conventional "input" key.
        inputs = input if isinstance(input, dict) else {"input": input}

        # Build middleware context
        mw_context = {
            "session_id": self.session.id,
            "invocation_id": self.ctx.invocation_id,
            "agent_id": self.name,
            "agent_type": self.agent_type,
            "workflow_name": self.workflow.spec.name,
        }

        # === Middleware: on_agent_start ===
        if self.middleware:
            hook_result = await self.middleware.process_agent_start(
                self.name, inputs, mw_context
            )
            # STOP surfaces an error block to the caller; SKIP ends silently.
            if hook_result.action == HookAction.STOP:
                logger.info("Workflow stopped by middleware on_agent_start")
                await self.ctx.emit(BlockEvent(
                    kind=BlockKind.ERROR,
                    op=BlockOp.APPLY,
                    data={"message": hook_result.message or "Stopped by middleware"},
                ))
                return
            elif hook_result.action == HookAction.SKIP:
                logger.info("Workflow skipped by middleware on_agent_start")
                return

        try:
            # Create executor with context (pass middleware)
            self._executor = WorkflowExecutor(
                workflow=self.workflow,
                agent_factory=self.agent_factory,
                ctx=self.ctx,
                middleware=self.middleware,
            )

            result = await self._executor.execute(inputs)

            # === Middleware: on_agent_end ===
            if self.middleware:
                await self.middleware.process_agent_end(
                    self.name, result, mw_context
                )

        except Exception as e:
            # === Middleware: on_error ===
            # Middleware may suppress the error by returning None; otherwise
            # the original exception propagates unchanged.
            if self.middleware:
                processed_error = await self.middleware.process_error(e, mw_context)
                if processed_error is None:
                    logger.info("Error suppressed by middleware")
                    return
            raise

    async def pause(self) -> str:
        """Pause workflow and return invocation ID.

        NOTE(review): returns the invocation id even when no executor is
        active (nothing to pause) — callers cannot distinguish the two cases.
        """
        if self._executor:
            self._executor.pause()
        return self.ctx.invocation_id

    async def _resume_internal(self, invocation_id: str) -> None:
        """Internal resume logic: load persisted state and re-run executor."""
        if self.workflow is None or self.agent_factory is None:
            raise ValueError("Workflow not set. Call set_workflow() first.")

        # Load saved state
        state_key = f"workflow_state:{invocation_id}"
        if not self.ctx.backends or not self.ctx.backends.state:
            raise ValueError("No state backend available")
        saved_state = await self.ctx.backends.state.get("workflow", state_key)

        if not saved_state:
            raise ValueError(f"No saved state for invocation: {invocation_id}")

        inputs = saved_state.get("inputs", {})

        # Create executor and resume
        # NOTE(review): unlike _execute(), middleware is NOT passed here —
        # resumed runs bypass middleware hooks; confirm this is intentional.
        self._executor = WorkflowExecutor(
            workflow=self.workflow,
            agent_factory=self.agent_factory,
            ctx=self.ctx,
        )

        await self._executor.execute(inputs, resume_state=saved_state)

    async def resume(self, invocation_id: str) -> AsyncIterator[BlockEvent]:
        """Resume paused workflow, yielding BlockEvents as they are emitted.

        Installs a queue into the emit contextvar so events produced by the
        background execution task can be streamed to the caller.
        """
        import asyncio
        from ..core.context import _emit_queue_var

        queue: asyncio.Queue[BlockEvent] = asyncio.Queue()
        token = _emit_queue_var.set(queue)

        try:
            exec_task = asyncio.create_task(self._resume_internal(invocation_id))

            # Poll with a short timeout so we notice task completion while
            # still draining any events left in the queue afterwards.
            while not exec_task.done() or not queue.empty():
                try:
                    block = await asyncio.wait_for(queue.get(), timeout=0.05)
                    yield block
                except asyncio.TimeoutError:
                    continue

            # Re-raise any exception from the execution task.
            await exec_task

        finally:
            # Always restore the previous emit queue for this context.
            _emit_queue_var.reset(token)
|
|
268
|
+
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Generic DAG executor."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any, Callable, TypeVar
|
|
5
|
+
|
|
6
|
+
T = TypeVar("T")
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DAGExecutor:
    """Generic DAG parallel executor.

    Tracks task lifecycle (pending -> running -> completed/failed/skipped)
    for an arbitrary task type ``T``. Task identity and dependency edges are
    supplied via callables, so this class knows nothing about what a task
    actually does.
    """

    def __init__(
        self,
        tasks: list[T],
        get_task_id: Callable[[T], str],
        get_dependencies: Callable[[T], list[str]],
    ):
        """
        Args:
            tasks: All tasks in the DAG.
            get_task_id: Maps a task to its unique string id.
            get_dependencies: Maps a task to the ids of tasks it depends on.
        """
        self.tasks = {get_task_id(task): task for task in tasks}
        self.get_task_id = get_task_id
        self.get_dependencies = get_dependencies

        # Lifecycle sets; a task id should live in at most one at a time.
        self.completed: set[str] = set()
        self.failed: set[str] = set()
        self.running: set[str] = set()
        self.skipped: set[str] = set()

    def get_ready_tasks(self) -> list[T]:
        """Get tasks ready for execution.

        A task is ready when all its dependencies are completed or skipped.
        If a dependency has failed, the task will never be ready.
        """
        processed = self.completed | self.failed | self.running | self.skipped
        # Hoisted: previously this union was rebuilt for every dependency.
        satisfied = self.completed | self.skipped

        ready = []
        for task_id, task in self.tasks.items():
            if task_id in processed:
                continue
            if all(dep_id in satisfied for dep_id in self.get_dependencies(task)):
                ready.append(task)

        return ready

    def is_blocked(self) -> bool:
        """Check if DAG is blocked due to failed dependencies.

        Returns True if there are unprocessed tasks that can never run
        because their dependencies have failed.
        """
        if not self.failed:
            return False

        processed = self.completed | self.failed | self.running | self.skipped

        return any(
            any(dep_id in self.failed for dep_id in self.get_dependencies(task))
            for task_id, task in self.tasks.items()
            if task_id not in processed
        )

    def mark_running(self, task_id: str) -> None:
        """Mark a task as currently executing."""
        self.running.add(task_id)

    def mark_completed(self, task_id: str) -> None:
        """Move a task from running to completed."""
        self.running.discard(task_id)
        self.completed.add(task_id)

    def mark_failed(self, task_id: str) -> None:
        """Move a task from running to failed."""
        self.running.discard(task_id)
        self.failed.add(task_id)

    def mark_skipped(self, task_id: str) -> None:
        """Mark a task as skipped.

        Fix: also discard from running (like mark_completed/mark_failed);
        previously an id left in both sets was double-counted by
        get_status(), corrupting the "pending" figure.
        """
        self.running.discard(task_id)
        self.skipped.add(task_id)

    def is_finished(self) -> bool:
        """True when every task has reached a terminal state."""
        total_processed = len(self.completed) + len(self.failed) + len(self.skipped)
        return total_processed == len(self.tasks)

    def has_failures(self) -> bool:
        """True if any task has failed."""
        return len(self.failed) > 0

    def get_status(self) -> dict[str, Any]:
        """Return per-state counts plus the derived pending count."""
        pending = (
            len(self.tasks)
            - len(self.completed)
            - len(self.failed)
            - len(self.skipped)
            - len(self.running)
        )
        return {
            "total": len(self.tasks),
            "completed": len(self.completed),
            "failed": len(self.failed),
            "skipped": len(self.skipped),
            "running": len(self.running),
            "pending": pending,
        }

    def reset(self) -> None:
        """Reset all lifecycle tracking (tasks themselves are kept)."""
        self.completed.clear()
        self.failed.clear()
        self.running.clear()
        self.skipped.clear()