aury-agent 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149) hide show
  1. aury/__init__.py +2 -0
  2. aury/agents/__init__.py +55 -0
  3. aury/agents/a2a/__init__.py +168 -0
  4. aury/agents/backends/__init__.py +196 -0
  5. aury/agents/backends/artifact/__init__.py +9 -0
  6. aury/agents/backends/artifact/memory.py +130 -0
  7. aury/agents/backends/artifact/types.py +133 -0
  8. aury/agents/backends/code/__init__.py +65 -0
  9. aury/agents/backends/file/__init__.py +11 -0
  10. aury/agents/backends/file/local.py +66 -0
  11. aury/agents/backends/file/types.py +40 -0
  12. aury/agents/backends/invocation/__init__.py +8 -0
  13. aury/agents/backends/invocation/memory.py +81 -0
  14. aury/agents/backends/invocation/types.py +110 -0
  15. aury/agents/backends/memory/__init__.py +8 -0
  16. aury/agents/backends/memory/memory.py +179 -0
  17. aury/agents/backends/memory/types.py +136 -0
  18. aury/agents/backends/message/__init__.py +9 -0
  19. aury/agents/backends/message/memory.py +122 -0
  20. aury/agents/backends/message/types.py +124 -0
  21. aury/agents/backends/sandbox.py +275 -0
  22. aury/agents/backends/session/__init__.py +8 -0
  23. aury/agents/backends/session/memory.py +93 -0
  24. aury/agents/backends/session/types.py +124 -0
  25. aury/agents/backends/shell/__init__.py +11 -0
  26. aury/agents/backends/shell/local.py +110 -0
  27. aury/agents/backends/shell/types.py +55 -0
  28. aury/agents/backends/shell.py +209 -0
  29. aury/agents/backends/snapshot/__init__.py +19 -0
  30. aury/agents/backends/snapshot/git.py +95 -0
  31. aury/agents/backends/snapshot/hybrid.py +125 -0
  32. aury/agents/backends/snapshot/memory.py +86 -0
  33. aury/agents/backends/snapshot/types.py +59 -0
  34. aury/agents/backends/state/__init__.py +29 -0
  35. aury/agents/backends/state/composite.py +49 -0
  36. aury/agents/backends/state/file.py +57 -0
  37. aury/agents/backends/state/memory.py +52 -0
  38. aury/agents/backends/state/sqlite.py +262 -0
  39. aury/agents/backends/state/types.py +178 -0
  40. aury/agents/backends/subagent/__init__.py +165 -0
  41. aury/agents/cli/__init__.py +41 -0
  42. aury/agents/cli/chat.py +239 -0
  43. aury/agents/cli/config.py +236 -0
  44. aury/agents/cli/extensions.py +460 -0
  45. aury/agents/cli/main.py +189 -0
  46. aury/agents/cli/session.py +337 -0
  47. aury/agents/cli/workflow.py +276 -0
  48. aury/agents/context_providers/__init__.py +66 -0
  49. aury/agents/context_providers/artifact.py +299 -0
  50. aury/agents/context_providers/base.py +177 -0
  51. aury/agents/context_providers/memory.py +70 -0
  52. aury/agents/context_providers/message.py +130 -0
  53. aury/agents/context_providers/skill.py +50 -0
  54. aury/agents/context_providers/subagent.py +46 -0
  55. aury/agents/context_providers/tool.py +68 -0
  56. aury/agents/core/__init__.py +83 -0
  57. aury/agents/core/base.py +573 -0
  58. aury/agents/core/context.py +797 -0
  59. aury/agents/core/context_builder.py +303 -0
  60. aury/agents/core/event_bus/__init__.py +15 -0
  61. aury/agents/core/event_bus/bus.py +203 -0
  62. aury/agents/core/factory.py +169 -0
  63. aury/agents/core/isolator.py +97 -0
  64. aury/agents/core/logging.py +95 -0
  65. aury/agents/core/parallel.py +194 -0
  66. aury/agents/core/runner.py +139 -0
  67. aury/agents/core/services/__init__.py +5 -0
  68. aury/agents/core/services/file_session.py +144 -0
  69. aury/agents/core/services/message.py +53 -0
  70. aury/agents/core/services/session.py +53 -0
  71. aury/agents/core/signals.py +109 -0
  72. aury/agents/core/state.py +363 -0
  73. aury/agents/core/types/__init__.py +107 -0
  74. aury/agents/core/types/action.py +176 -0
  75. aury/agents/core/types/artifact.py +135 -0
  76. aury/agents/core/types/block.py +736 -0
  77. aury/agents/core/types/message.py +350 -0
  78. aury/agents/core/types/recall.py +144 -0
  79. aury/agents/core/types/session.py +257 -0
  80. aury/agents/core/types/subagent.py +154 -0
  81. aury/agents/core/types/tool.py +205 -0
  82. aury/agents/eval/__init__.py +331 -0
  83. aury/agents/hitl/__init__.py +57 -0
  84. aury/agents/hitl/ask_user.py +242 -0
  85. aury/agents/hitl/compaction.py +230 -0
  86. aury/agents/hitl/exceptions.py +87 -0
  87. aury/agents/hitl/permission.py +617 -0
  88. aury/agents/hitl/revert.py +216 -0
  89. aury/agents/llm/__init__.py +31 -0
  90. aury/agents/llm/adapter.py +367 -0
  91. aury/agents/llm/openai.py +294 -0
  92. aury/agents/llm/provider.py +476 -0
  93. aury/agents/mcp/__init__.py +153 -0
  94. aury/agents/memory/__init__.py +46 -0
  95. aury/agents/memory/compaction.py +394 -0
  96. aury/agents/memory/manager.py +465 -0
  97. aury/agents/memory/processor.py +177 -0
  98. aury/agents/memory/store.py +187 -0
  99. aury/agents/memory/types.py +137 -0
  100. aury/agents/messages/__init__.py +40 -0
  101. aury/agents/messages/config.py +47 -0
  102. aury/agents/messages/raw_store.py +224 -0
  103. aury/agents/messages/store.py +118 -0
  104. aury/agents/messages/types.py +88 -0
  105. aury/agents/middleware/__init__.py +31 -0
  106. aury/agents/middleware/base.py +341 -0
  107. aury/agents/middleware/chain.py +342 -0
  108. aury/agents/middleware/message.py +129 -0
  109. aury/agents/middleware/message_container.py +126 -0
  110. aury/agents/middleware/raw_message.py +153 -0
  111. aury/agents/middleware/truncation.py +139 -0
  112. aury/agents/middleware/types.py +81 -0
  113. aury/agents/plugin.py +162 -0
  114. aury/agents/react/__init__.py +4 -0
  115. aury/agents/react/agent.py +1923 -0
  116. aury/agents/sandbox/__init__.py +23 -0
  117. aury/agents/sandbox/local.py +239 -0
  118. aury/agents/sandbox/remote.py +200 -0
  119. aury/agents/sandbox/types.py +115 -0
  120. aury/agents/skill/__init__.py +16 -0
  121. aury/agents/skill/loader.py +180 -0
  122. aury/agents/skill/types.py +83 -0
  123. aury/agents/tool/__init__.py +39 -0
  124. aury/agents/tool/builtin/__init__.py +23 -0
  125. aury/agents/tool/builtin/ask_user.py +155 -0
  126. aury/agents/tool/builtin/bash.py +107 -0
  127. aury/agents/tool/builtin/delegate.py +726 -0
  128. aury/agents/tool/builtin/edit.py +121 -0
  129. aury/agents/tool/builtin/plan.py +277 -0
  130. aury/agents/tool/builtin/read.py +91 -0
  131. aury/agents/tool/builtin/thinking.py +111 -0
  132. aury/agents/tool/builtin/yield_result.py +130 -0
  133. aury/agents/tool/decorator.py +252 -0
  134. aury/agents/tool/set.py +204 -0
  135. aury/agents/usage/__init__.py +12 -0
  136. aury/agents/usage/tracker.py +236 -0
  137. aury/agents/workflow/__init__.py +85 -0
  138. aury/agents/workflow/adapter.py +268 -0
  139. aury/agents/workflow/dag.py +116 -0
  140. aury/agents/workflow/dsl.py +575 -0
  141. aury/agents/workflow/executor.py +659 -0
  142. aury/agents/workflow/expression.py +136 -0
  143. aury/agents/workflow/parser.py +182 -0
  144. aury/agents/workflow/state.py +145 -0
  145. aury/agents/workflow/types.py +86 -0
  146. aury_agent-0.0.4.dist-info/METADATA +90 -0
  147. aury_agent-0.0.4.dist-info/RECORD +149 -0
  148. aury_agent-0.0.4.dist-info/WHEEL +4 -0
  149. aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,236 @@
1
+ """Usage tracking for LLM and other services."""
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ from dataclasses import dataclass, field, asdict
6
+ from datetime import datetime
7
+ from decimal import Decimal
8
+ from enum import Enum
9
+ from typing import Any
10
+
11
+ from ..core.event_bus import EventBus, Events
12
+
13
# ``Bus`` is kept as a second name for ``EventBus`` so older imports keep working.
Bus = EventBus
15
+
16
+
17
class UsageType(Enum):
    """Kinds of metered usage a usage record can describe.

    Each member's value is the lowercase string form used when a record is
    serialized.
    """

    # LLM token accounting
    LLM_INPUT = "llm_input"
    LLM_OUTPUT = "llm_output"
    LLM_CACHE_READ = "llm_cache_read"
    LLM_CACHE_WRITE = "llm_cache_write"

    # Other metered services
    EMBEDDING = "embedding"
    IMAGE_GEN = "image_gen"
    SPEECH = "speech"
    SEARCH = "search"
    EXTERNAL_API = "external_api"
28
+
29
+
30
@dataclass
class UsageEntry:
    """A single usage record.

    Describes ``units`` of one usage type consumed from ``provider``,
    optionally tied to a session, invocation, tool and step.
    """

    id: str
    type: UsageType
    provider: str
    model: str | None = None
    units: int = 0
    unit_type: str = "tokens"  # tokens, requests, characters, etc.
    cost: Decimal | None = None

    # Context
    session_id: str | None = None
    invocation_id: str | None = None
    tool_id: str | None = None
    step: int | None = None

    # Metadata
    created_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict representation of this entry.

        ``type`` is flattened to its ``.value``, ``cost`` to a float (or
        ``None`` when no cost is known), and ``created_at`` to an ISO-8601
        string. ``metadata`` is shallow-copied so that consumers of the
        returned dict cannot mutate this entry's own metadata.
        """
        return {
            "id": self.id,
            "type": self.type.value,
            "provider": self.provider,
            "model": self.model,
            "units": self.units,
            "unit_type": self.unit_type,
            # Compare against None explicitly: Decimal(0) is falsy but is a
            # real cost and must serialize as 0.0, not None.
            "cost": float(self.cost) if self.cost is not None else None,
            "session_id": self.session_id,
            "invocation_id": self.invocation_id,
            "tool_id": self.tool_id,
            "step": self.step,
            "created_at": self.created_at.isoformat(),
            "metadata": dict(self.metadata),
        }
67
+
68
+
69
class UsageTracker:
    """Track usage of LLM and other services.

    Entries are held in memory in the order they were recorded. If an event
    bus is supplied, every recorded entry is also published as
    ``Events.USAGE_RECORDED`` with ``entry.to_dict()`` as the payload.

    Provides methods for recording usage, filtering recorded entries and
    generating summaries.
    """

    def __init__(self, bus: Bus | None = None):
        """Create an empty tracker.

        Args:
            bus: Optional event bus on which every recorded entry is published.
        """
        self.bus = bus
        self._entries: list[UsageEntry] = []
        # Guards appends to ``_entries`` in ``record``.
        self._lock = asyncio.Lock()
        self._counter = 0

    def _generate_id(self) -> str:
        """Return the next sequential entry id, e.g. ``usage_000001``."""
        self._counter += 1
        return f"usage_{self._counter:06d}"

    def _filter(
        self,
        session_id: str | None,
        invocation_id: str | None,
        type_filter: UsageType | None = None,
    ) -> list[UsageEntry]:
        """Return a new list of the entries that match every non-None filter.

        ``None`` means "do not filter on this field"; any other value,
        including an empty string, must match exactly.
        """
        return [
            e
            for e in self._entries
            if (session_id is None or e.session_id == session_id)
            and (invocation_id is None or e.invocation_id == invocation_id)
            and (type_filter is None or e.type == type_filter)
        ]

    async def record(self, entry: UsageEntry) -> None:
        """Store ``entry`` and, when a bus is configured, publish it.

        Args:
            entry: The usage record to store.
        """
        async with self._lock:
            self._entries.append(entry)

        if self.bus:
            await self.bus.publish(Events.USAGE_RECORDED, entry.to_dict())

    async def record_llm(
        self,
        provider: str,
        model: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_read_tokens: int = 0,
        cache_write_tokens: int = 0,
        session_id: str | None = None,
        invocation_id: str | None = None,
        step: int | None = None,
        **metadata: Any,
    ) -> None:
        """Convenience method for recording LLM token usage.

        One token-denominated entry is recorded per positive count, in the
        order input, output, cache read, cache write. Counts that are zero or
        negative are skipped. Each entry gets its own copy of ``metadata``,
        so mutating one entry's metadata does not affect its siblings.
        """
        counts = (
            (UsageType.LLM_INPUT, input_tokens),
            (UsageType.LLM_OUTPUT, output_tokens),
            (UsageType.LLM_CACHE_READ, cache_read_tokens),
            (UsageType.LLM_CACHE_WRITE, cache_write_tokens),
        )
        for usage_type, units in counts:
            if units <= 0:
                continue
            await self.record(UsageEntry(
                id=self._generate_id(),
                type=usage_type,
                provider=provider,
                model=model,
                units=units,
                unit_type="tokens",
                session_id=session_id,
                invocation_id=invocation_id,
                step=step,
                metadata=dict(metadata),
            ))

    async def record_embedding(
        self,
        provider: str,
        model: str,
        tokens: int,
        session_id: str | None = None,
        **metadata: Any,
    ) -> None:
        """Record embedding usage as a single token-denominated entry."""
        await self.record(UsageEntry(
            id=self._generate_id(),
            type=UsageType.EMBEDDING,
            provider=provider,
            model=model,
            units=tokens,
            unit_type="tokens",
            session_id=session_id,
            metadata=metadata,
        ))

    def summarize(
        self,
        session_id: str | None = None,
        invocation_id: str | None = None,
    ) -> dict[str, Any]:
        """Generate a usage summary, optionally scoped to a session/invocation.

        Returns:
            A dict with the following keys:

            * ``total_cost``: the sum of known costs, as a float.
            * ``by_type``, ``by_provider``, ``by_model``: unit totals per key.
              Entries without a model are left out of ``by_model``.
            * ``entry_count``: the number of matching entries.
            * ``total_tokens``: the sum of units over entries whose
              ``unit_type`` is ``"tokens"``.
        """
        entries = self._filter(session_id, invocation_id)

        by_type: dict[str, int] = {}
        by_provider: dict[str, int] = {}
        by_model: dict[str, int] = {}
        total_cost = Decimal(0)
        total_tokens = 0

        for e in entries:
            by_type[e.type.value] = by_type.get(e.type.value, 0) + e.units
            by_provider[e.provider] = by_provider.get(e.provider, 0) + e.units
            if e.model:
                by_model[e.model] = by_model.get(e.model, 0) + e.units
            if e.cost is not None:
                total_cost += e.cost
            # Requests, characters, etc. are not tokens; don't mix them in.
            if e.unit_type == "tokens":
                total_tokens += e.units

        return {
            "total_cost": float(total_cost),
            "by_type": by_type,
            "by_provider": by_provider,
            "by_model": by_model,
            "entry_count": len(entries),
            "total_tokens": total_tokens,
        }

    def get_entries(
        self,
        session_id: str | None = None,
        invocation_id: str | None = None,
        type_filter: UsageType | None = None,
    ) -> list[UsageEntry]:
        """Get filtered usage entries.

        Always returns a new list, so callers cannot mutate the tracker's
        internal storage through it.
        """
        return self._filter(session_id, invocation_id, type_filter)

    def clear(self, session_id: str | None = None) -> int:
        """Clear usage entries.

        Args:
            session_id: If given, remove only this session's entries;
                otherwise remove everything.

        Returns:
            Number of entries cleared
        """
        if session_id is not None:
            kept = [e for e in self._entries if e.session_id != session_id]
            removed = len(self._entries) - len(kept)
            self._entries = kept
            return removed

        count = len(self._entries)
        self._entries.clear()
        return count
@@ -0,0 +1,85 @@
1
+ """Workflow system for DAG-based orchestration."""
2
+ from .types import (
3
+ NodeType,
4
+ Position,
5
+ NodeSpec,
6
+ EdgeSpec,
7
+ WorkflowSpec,
8
+ Workflow,
9
+ )
10
+ from .parser import (
11
+ WorkflowParser,
12
+ WorkflowValidationError,
13
+ )
14
+ from .expression import (
15
+ ExpressionEvaluator,
16
+ ExpressionError,
17
+ )
18
+ from .state import (
19
+ WorkflowState,
20
+ MergeStrategy,
21
+ CollectListStrategy,
22
+ CollectDictStrategy,
23
+ FirstSuccessStrategy,
24
+ get_merge_strategy,
25
+ )
26
+ from .dag import DAGExecutor
27
+ from .executor import WorkflowExecutor
28
+ from ..core.factory import AgentFactory
29
+ from .adapter import WorkflowAgent
30
+ from .dsl import (
31
+ DSLNode,
32
+ DSLSequence,
33
+ DSLStep,
34
+ DSLParallel,
35
+ DSLCondition,
36
+ DSLWorkflow,
37
+ workflow,
38
+ step,
39
+ parallel,
40
+ condition,
41
+ skip,
42
+ )
43
+
44
__all__ = [
    # Workflow graph types
    "NodeType", "Position", "NodeSpec", "EdgeSpec", "WorkflowSpec", "Workflow",
    # Parsing / validation
    "WorkflowParser", "WorkflowValidationError",
    # Expression evaluation
    "ExpressionEvaluator", "ExpressionError",
    # State and merge strategies
    "WorkflowState",
    "MergeStrategy",
    "CollectListStrategy",
    "CollectDictStrategy",
    "FirstSuccessStrategy",
    "get_merge_strategy",
    # Executors
    "DAGExecutor", "WorkflowExecutor",
    # Agent factory (re-exported from core.factory) and workflow agent
    "AgentFactory", "WorkflowAgent",
    # Python DSL
    "DSLNode",
    "DSLSequence",
    "DSLStep",
    "DSLParallel",
    "DSLCondition",
    "DSLWorkflow",
    "workflow",
    "step",
    "parallel",
    "condition",
    "skip",
]
@@ -0,0 +1,268 @@
1
+ """Workflow agent - executes DAG-based workflows.
2
+
3
+ WorkflowAgent uses the unified BaseAgent constructor:
4
+ __init__(self, ctx: InvocationContext, config: AgentConfig | None = None)
5
+
6
+ Workflow definition and AgentFactory are passed via config or set after init.
7
+
8
+ Middleware hooks are called at:
9
+ - on_agent_start/end: workflow start/end
10
+ - on_subagent_start/end: each node execution (in executor)
11
+ """
12
+ from __future__ import annotations
13
+
14
+ from typing import Any, AsyncIterator, ClassVar, Literal, TYPE_CHECKING
15
+
16
+ from ..core.base import BaseAgent, AgentConfig
17
+ from ..core.context import InvocationContext
18
+ from ..core.logging import workflow_logger as logger
19
+ from ..core.types.block import BlockEvent, BlockKind, BlockOp
20
+ from ..middleware import HookAction
21
+ from .types import Workflow
22
+ from .executor import WorkflowExecutor
23
+
24
+ if TYPE_CHECKING:
25
+ from ..core.factory import AgentFactory
26
+ from ..core.types.session import Session
27
+ from ..backends import Backends
28
+ from ..core.event_bus import Bus
29
+ from ..middleware import MiddlewareChain
30
+
31
+
32
class WorkflowAgent(BaseAgent):
    """Workflow agent - executes DAG-based workflows.

    Two ways to create:

    1. Simple (recommended):
        agent = WorkflowAgent.create(workflow=wf, agent_factory=factory)

    2. Advanced (for custom Session/Backends/Bus):
        ctx = InvocationContext(session=session, backends=backends, bus=bus)
        agent = WorkflowAgent(ctx, config)
        agent.set_workflow(workflow, agent_factory)
    """

    agent_type: ClassVar[Literal["react", "workflow"]] = "workflow"

    @classmethod
    def create(
        cls,
        workflow: Workflow,
        agent_factory: "AgentFactory",
        config: AgentConfig | None = None,
        *,
        backends: "Backends | None" = None,
        session: "Session | None" = None,
        bus: "Bus | None" = None,
        middleware: "MiddlewareChain | None" = None,
    ) -> "WorkflowAgent":
        """Create a WorkflowAgent with minimal boilerplate.

        Any of ``backends``, ``session`` and ``bus`` that is not supplied is
        created with defaults, and a fresh invocation id is generated.

        Args:
            workflow: Workflow definition.
            agent_factory: Factory for creating sub-agents.
            config: Agent configuration (optional). When given, its ``name``
                is used as the agent id; otherwise the workflow spec name is.
            backends: Backends container (``Backends.create_default()`` if None).
            session: Session object (new session with a generated id if None).
            bus: Event bus (new ``Bus()`` if None).
            middleware: Middleware chain (optional).

        Returns:
            A WorkflowAgent whose workflow and factory are already set.

        Example:
            agent = WorkflowAgent.create(
                workflow=my_workflow,
                agent_factory=factory,
            )
            async for response in agent.run(inputs):
                print(response)
        """
        # Runtime imports kept local (presumably to avoid import cycles).
        from ..core.event_bus import Bus
        from ..core.types.session import Session, generate_id
        from ..backends import Backends

        # Auto-create missing components
        if backends is None:
            backends = Backends.create_default()
        if session is None:
            session = Session(id=generate_id("sess"))
        if bus is None:
            bus = Bus()

        ctx = InvocationContext(
            session=session,
            invocation_id=generate_id("inv"),
            agent_id=config.name if config else workflow.spec.name,
            backends=backends,
            bus=bus,
            middleware=middleware,
        )

        agent = cls(ctx, config)
        agent.set_workflow(workflow, agent_factory)
        return agent

    def __init__(
        self,
        ctx: InvocationContext,
        config: AgentConfig | None = None,
    ):
        """Initialize WorkflowAgent.

        Args:
            ctx: InvocationContext with storage, bus, etc.
            config: Agent configuration. If ``config.metadata`` has
                ``'workflow'`` / ``'agent_factory'`` entries, they pre-set
                the workflow and factory.
        """
        super().__init__(ctx, config)

        # Workflow can be set via config.metadata or set_workflow()
        self.workflow: Workflow | None = self.config.metadata.get("workflow")
        self.agent_factory: "AgentFactory | None" = self.config.metadata.get("agent_factory")
        self._executor: WorkflowExecutor | None = None

    def set_workflow(
        self,
        workflow: Workflow,
        agent_factory: "AgentFactory",
    ) -> "WorkflowAgent":
        """Set workflow definition and factory.

        Args:
            workflow: Workflow definition
            agent_factory: Factory for creating sub-agents

        Returns:
            Self for chaining
        """
        self.workflow = workflow
        self.agent_factory = agent_factory
        return self

    def _build_executor(self) -> WorkflowExecutor:
        """Create an executor for the current workflow/factory/context.

        This is shared by fresh runs and resumes so that both pass the same
        middleware chain to the executor.
        """
        return WorkflowExecutor(
            workflow=self.workflow,
            agent_factory=self.agent_factory,
            ctx=self.ctx,
            middleware=self.middleware,
        )

    async def _execute(self, input: Any) -> None:
        """Execute the workflow, wrapped in middleware hooks.

        A non-dict ``input`` is wrapped as ``{"input": input}``. When
        middleware is configured, it behaves as follows:

        * ``process_agent_start`` may STOP the run, in which case an ERROR
          block is emitted, or SKIP it.
        * ``process_agent_end`` receives the executor's result.
        * ``process_error`` may suppress an exception by returning None.

        Args:
            input: Workflow inputs (dict or single value)

        Raises:
            ValueError: If workflow or agent_factory not set
        """
        if self.workflow is None:
            raise ValueError("Workflow not set. Call set_workflow() first.")
        if self.agent_factory is None:
            raise ValueError("AgentFactory not set. Call set_workflow() first.")

        inputs = input if isinstance(input, dict) else {"input": input}

        mw_context = {
            "session_id": self.session.id,
            "invocation_id": self.ctx.invocation_id,
            "agent_id": self.name,
            "agent_type": self.agent_type,
            "workflow_name": self.workflow.spec.name,
        }

        # === Middleware: on_agent_start ===
        if self.middleware:
            hook_result = await self.middleware.process_agent_start(
                self.name, inputs, mw_context
            )
            if hook_result.action == HookAction.STOP:
                logger.info("Workflow stopped by middleware on_agent_start")
                await self.ctx.emit(BlockEvent(
                    kind=BlockKind.ERROR,
                    op=BlockOp.APPLY,
                    data={"message": hook_result.message or "Stopped by middleware"},
                ))
                return
            elif hook_result.action == HookAction.SKIP:
                logger.info("Workflow skipped by middleware on_agent_start")
                return

        try:
            self._executor = self._build_executor()

            result = await self._executor.execute(inputs)

            # === Middleware: on_agent_end ===
            if self.middleware:
                await self.middleware.process_agent_end(
                    self.name, result, mw_context
                )

        except Exception as e:
            # === Middleware: on_error ===
            # A None result from process_error means "swallow this error".
            if self.middleware:
                processed_error = await self.middleware.process_error(e, mw_context)
                if processed_error is None:
                    logger.info("Error suppressed by middleware")
                    return
            raise

    async def pause(self) -> str:
        """Pause the running executor (if any) and return the invocation ID."""
        if self._executor:
            self._executor.pause()
        return self.ctx.invocation_id

    async def _resume_internal(self, invocation_id: str) -> None:
        """Load the saved state for ``invocation_id`` and continue execution.

        State is read from the ``"workflow"`` namespace of the state backend
        under the key ``workflow_state:<invocation_id>``.

        Raises:
            ValueError: If workflow/factory is not set, no state backend is
                available, or no saved state exists for the invocation.
        """
        if self.workflow is None:
            raise ValueError("Workflow not set. Call set_workflow() first.")
        if self.agent_factory is None:
            raise ValueError("AgentFactory not set. Call set_workflow() first.")

        state_key = f"workflow_state:{invocation_id}"
        if not self.ctx.backends or not self.ctx.backends.state:
            raise ValueError("No state backend available")
        saved_state = await self.ctx.backends.state.get("workflow", state_key)

        if not saved_state:
            raise ValueError(f"No saved state for invocation: {invocation_id}")

        inputs = saved_state.get("inputs", {})

        # Use the same executor construction as a fresh run so that
        # middleware hooks also fire for resumed nodes.
        self._executor = self._build_executor()

        await self._executor.execute(inputs, resume_state=saved_state)

    async def resume(self, invocation_id: str) -> AsyncIterator[BlockEvent]:
        """Resume a paused workflow, yielding the blocks it emits.

        The resume runs in a background task. Blocks it emits through the
        context-local emit queue are yielded as they arrive. An exception
        raised by the resumed run propagates from this generator after the
        queue has been drained. If iteration stops early, the background
        task is cancelled.

        Args:
            invocation_id: Invocation whose saved state should be resumed.

        Yields:
            BlockEvent objects emitted while the workflow runs.
        """
        import asyncio
        from ..core.context import _emit_queue_var

        queue: asyncio.Queue[BlockEvent] = asyncio.Queue()
        # Set before create_task: the task copies the current context, so it
        # emits into this queue.
        token = _emit_queue_var.set(queue)
        exec_task: asyncio.Task[None] | None = None

        try:
            exec_task = asyncio.create_task(self._resume_internal(invocation_id))

            # Keep draining until the task is done AND the queue is empty.
            while not exec_task.done() or not queue.empty():
                try:
                    block = await asyncio.wait_for(queue.get(), timeout=0.05)
                except asyncio.TimeoutError:
                    continue
                yield block

            # Surface any exception raised by the resumed execution.
            await exec_task

        finally:
            # Consumer stopped early (or we were cancelled): don't leave the
            # resumed workflow running unobserved in the background.
            if exec_task is not None and not exec_task.done():
                exec_task.cancel()
            _emit_queue_var.reset(token)
268
+
@@ -0,0 +1,116 @@
1
+ """Generic DAG executor."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any, Callable, TypeVar
5
+
6
T = TypeVar("T")


class DAGExecutor:
    """Generic DAG parallel executor (state bookkeeping).

    This class tracks which tasks of a dependency graph are pending,
    running, completed, failed or skipped, and reports which tasks are
    ready to start. It does not run anything itself: the caller starts the
    ready tasks and reports outcomes through the ``mark_*`` methods.

    Tasks are arbitrary objects. ``get_task_id`` and ``get_dependencies``
    extract a task's id and the ids it depends on. If two tasks share an id,
    the later one replaces the earlier one.
    """

    def __init__(
        self,
        tasks: list[T],
        get_task_id: Callable[[T], str],
        get_dependencies: Callable[[T], list[str]],
    ):
        self.tasks: dict[str, T] = {get_task_id(task): task for task in tasks}
        self.get_task_id = get_task_id
        self.get_dependencies = get_dependencies

        self.completed: set[str] = set()
        self.failed: set[str] = set()
        self.running: set[str] = set()
        self.skipped: set[str] = set()

    def _processed(self) -> set[str]:
        """Return the ids in any non-pending state (running or terminal)."""
        return self.completed | self.failed | self.running | self.skipped

    def get_ready_tasks(self) -> list[T]:
        """Get tasks ready for execution, in insertion order.

        A task is ready when it is still pending and all its dependencies are
        completed or skipped. A dependency that failed, or that is never
        marked completed/skipped, keeps its dependents from ever becoming
        ready.
        """
        processed = self._processed()
        # Built once instead of once per dependency check.
        satisfied = self.completed | self.skipped
        return [
            task
            for task_id, task in self.tasks.items()
            if task_id not in processed
            and all(dep_id in satisfied for dep_id in self.get_dependencies(task))
        ]

    def is_blocked(self) -> bool:
        """Check if the DAG is blocked due to failed dependencies.

        Returns True if some pending task has a direct dependency that has
        failed. Such a task can never become ready.
        """
        if not self.failed:
            return False

        processed = self._processed()
        return any(
            dep_id in self.failed
            for task_id, task in self.tasks.items()
            if task_id not in processed
            for dep_id in self.get_dependencies(task)
        )

    def mark_running(self, task_id: str) -> None:
        """Record that ``task_id`` has started."""
        self.running.add(task_id)

    def mark_completed(self, task_id: str) -> None:
        """Record that ``task_id`` finished successfully."""
        self.running.discard(task_id)
        self.completed.add(task_id)

    def mark_failed(self, task_id: str) -> None:
        """Record that ``task_id`` failed."""
        self.running.discard(task_id)
        self.failed.add(task_id)

    def mark_skipped(self, task_id: str) -> None:
        """Record that ``task_id`` was skipped.

        Like the other terminal transitions, this also clears the running
        state so a task is never counted as both running and skipped.
        """
        self.running.discard(task_id)
        self.skipped.add(task_id)

    def is_finished(self) -> bool:
        """Return True once every task is completed, failed or skipped."""
        done = self.completed | self.failed | self.skipped
        # Set inclusion (not length arithmetic) so that ids marked in more
        # than one set, or ids that are not tasks, can't skew the result.
        return self.tasks.keys() <= done

    def has_failures(self) -> bool:
        """Return True if any task has failed."""
        return len(self.failed) > 0

    def get_status(self) -> dict[str, Any]:
        """Return counts per state. ``pending`` counts tasks in no state."""
        processed = self._processed()
        pending = sum(1 for task_id in self.tasks if task_id not in processed)
        return {
            "total": len(self.tasks),
            "completed": len(self.completed),
            "failed": len(self.failed),
            "skipped": len(self.skipped),
            "running": len(self.running),
            "pending": pending,
        }

    def reset(self) -> None:
        """Return every task to the pending state."""
        self.completed.clear()
        self.failed.clear()
        self.running.clear()
        self.skipped.clear()