aury-agent 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aury/__init__.py +2 -0
- aury/agents/__init__.py +55 -0
- aury/agents/a2a/__init__.py +168 -0
- aury/agents/backends/__init__.py +196 -0
- aury/agents/backends/artifact/__init__.py +9 -0
- aury/agents/backends/artifact/memory.py +130 -0
- aury/agents/backends/artifact/types.py +133 -0
- aury/agents/backends/code/__init__.py +65 -0
- aury/agents/backends/file/__init__.py +11 -0
- aury/agents/backends/file/local.py +66 -0
- aury/agents/backends/file/types.py +40 -0
- aury/agents/backends/invocation/__init__.py +8 -0
- aury/agents/backends/invocation/memory.py +81 -0
- aury/agents/backends/invocation/types.py +110 -0
- aury/agents/backends/memory/__init__.py +8 -0
- aury/agents/backends/memory/memory.py +179 -0
- aury/agents/backends/memory/types.py +136 -0
- aury/agents/backends/message/__init__.py +9 -0
- aury/agents/backends/message/memory.py +122 -0
- aury/agents/backends/message/types.py +124 -0
- aury/agents/backends/sandbox.py +275 -0
- aury/agents/backends/session/__init__.py +8 -0
- aury/agents/backends/session/memory.py +93 -0
- aury/agents/backends/session/types.py +124 -0
- aury/agents/backends/shell/__init__.py +11 -0
- aury/agents/backends/shell/local.py +110 -0
- aury/agents/backends/shell/types.py +55 -0
- aury/agents/backends/shell.py +209 -0
- aury/agents/backends/snapshot/__init__.py +19 -0
- aury/agents/backends/snapshot/git.py +95 -0
- aury/agents/backends/snapshot/hybrid.py +125 -0
- aury/agents/backends/snapshot/memory.py +86 -0
- aury/agents/backends/snapshot/types.py +59 -0
- aury/agents/backends/state/__init__.py +29 -0
- aury/agents/backends/state/composite.py +49 -0
- aury/agents/backends/state/file.py +57 -0
- aury/agents/backends/state/memory.py +52 -0
- aury/agents/backends/state/sqlite.py +262 -0
- aury/agents/backends/state/types.py +178 -0
- aury/agents/backends/subagent/__init__.py +165 -0
- aury/agents/cli/__init__.py +41 -0
- aury/agents/cli/chat.py +239 -0
- aury/agents/cli/config.py +236 -0
- aury/agents/cli/extensions.py +460 -0
- aury/agents/cli/main.py +189 -0
- aury/agents/cli/session.py +337 -0
- aury/agents/cli/workflow.py +276 -0
- aury/agents/context_providers/__init__.py +66 -0
- aury/agents/context_providers/artifact.py +299 -0
- aury/agents/context_providers/base.py +177 -0
- aury/agents/context_providers/memory.py +70 -0
- aury/agents/context_providers/message.py +130 -0
- aury/agents/context_providers/skill.py +50 -0
- aury/agents/context_providers/subagent.py +46 -0
- aury/agents/context_providers/tool.py +68 -0
- aury/agents/core/__init__.py +83 -0
- aury/agents/core/base.py +573 -0
- aury/agents/core/context.py +797 -0
- aury/agents/core/context_builder.py +303 -0
- aury/agents/core/event_bus/__init__.py +15 -0
- aury/agents/core/event_bus/bus.py +203 -0
- aury/agents/core/factory.py +169 -0
- aury/agents/core/isolator.py +97 -0
- aury/agents/core/logging.py +95 -0
- aury/agents/core/parallel.py +194 -0
- aury/agents/core/runner.py +139 -0
- aury/agents/core/services/__init__.py +5 -0
- aury/agents/core/services/file_session.py +144 -0
- aury/agents/core/services/message.py +53 -0
- aury/agents/core/services/session.py +53 -0
- aury/agents/core/signals.py +109 -0
- aury/agents/core/state.py +363 -0
- aury/agents/core/types/__init__.py +107 -0
- aury/agents/core/types/action.py +176 -0
- aury/agents/core/types/artifact.py +135 -0
- aury/agents/core/types/block.py +736 -0
- aury/agents/core/types/message.py +350 -0
- aury/agents/core/types/recall.py +144 -0
- aury/agents/core/types/session.py +257 -0
- aury/agents/core/types/subagent.py +154 -0
- aury/agents/core/types/tool.py +205 -0
- aury/agents/eval/__init__.py +331 -0
- aury/agents/hitl/__init__.py +57 -0
- aury/agents/hitl/ask_user.py +242 -0
- aury/agents/hitl/compaction.py +230 -0
- aury/agents/hitl/exceptions.py +87 -0
- aury/agents/hitl/permission.py +617 -0
- aury/agents/hitl/revert.py +216 -0
- aury/agents/llm/__init__.py +31 -0
- aury/agents/llm/adapter.py +367 -0
- aury/agents/llm/openai.py +294 -0
- aury/agents/llm/provider.py +476 -0
- aury/agents/mcp/__init__.py +153 -0
- aury/agents/memory/__init__.py +46 -0
- aury/agents/memory/compaction.py +394 -0
- aury/agents/memory/manager.py +465 -0
- aury/agents/memory/processor.py +177 -0
- aury/agents/memory/store.py +187 -0
- aury/agents/memory/types.py +137 -0
- aury/agents/messages/__init__.py +40 -0
- aury/agents/messages/config.py +47 -0
- aury/agents/messages/raw_store.py +224 -0
- aury/agents/messages/store.py +118 -0
- aury/agents/messages/types.py +88 -0
- aury/agents/middleware/__init__.py +31 -0
- aury/agents/middleware/base.py +341 -0
- aury/agents/middleware/chain.py +342 -0
- aury/agents/middleware/message.py +129 -0
- aury/agents/middleware/message_container.py +126 -0
- aury/agents/middleware/raw_message.py +153 -0
- aury/agents/middleware/truncation.py +139 -0
- aury/agents/middleware/types.py +81 -0
- aury/agents/plugin.py +162 -0
- aury/agents/react/__init__.py +4 -0
- aury/agents/react/agent.py +1923 -0
- aury/agents/sandbox/__init__.py +23 -0
- aury/agents/sandbox/local.py +239 -0
- aury/agents/sandbox/remote.py +200 -0
- aury/agents/sandbox/types.py +115 -0
- aury/agents/skill/__init__.py +16 -0
- aury/agents/skill/loader.py +180 -0
- aury/agents/skill/types.py +83 -0
- aury/agents/tool/__init__.py +39 -0
- aury/agents/tool/builtin/__init__.py +23 -0
- aury/agents/tool/builtin/ask_user.py +155 -0
- aury/agents/tool/builtin/bash.py +107 -0
- aury/agents/tool/builtin/delegate.py +726 -0
- aury/agents/tool/builtin/edit.py +121 -0
- aury/agents/tool/builtin/plan.py +277 -0
- aury/agents/tool/builtin/read.py +91 -0
- aury/agents/tool/builtin/thinking.py +111 -0
- aury/agents/tool/builtin/yield_result.py +130 -0
- aury/agents/tool/decorator.py +252 -0
- aury/agents/tool/set.py +204 -0
- aury/agents/usage/__init__.py +12 -0
- aury/agents/usage/tracker.py +236 -0
- aury/agents/workflow/__init__.py +85 -0
- aury/agents/workflow/adapter.py +268 -0
- aury/agents/workflow/dag.py +116 -0
- aury/agents/workflow/dsl.py +575 -0
- aury/agents/workflow/executor.py +659 -0
- aury/agents/workflow/expression.py +136 -0
- aury/agents/workflow/parser.py +182 -0
- aury/agents/workflow/state.py +145 -0
- aury/agents/workflow/types.py +86 -0
- aury_agent-0.0.4.dist-info/METADATA +90 -0
- aury_agent-0.0.4.dist-info/RECORD +149 -0
- aury_agent-0.0.4.dist-info/WHEEL +4 -0
- aury_agent-0.0.4.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,575 @@
|
|
|
1
|
+
"""Python DSL for workflow definition.
|
|
2
|
+
|
|
3
|
+
Provides a fluent API for building workflows:
|
|
4
|
+
|
|
5
|
+
```python
|
|
6
|
+
from aury.agents.workflow.dsl import workflow, step, parallel, condition
|
|
7
|
+
from aury.agents.middleware import BaseMiddleware, MiddlewareChain
|
|
8
|
+
|
|
9
|
+
# Simple sequence
|
|
10
|
+
wf = workflow("my_workflow") >> step("A") >> step("B") >> step("C")
|
|
11
|
+
|
|
12
|
+
# Parallel execution
|
|
13
|
+
wf = workflow("my_workflow") >> parallel(
|
|
14
|
+
step("A"),
|
|
15
|
+
step("B"),
|
|
16
|
+
) >> step("C")
|
|
17
|
+
|
|
18
|
+
# Conditional branching
|
|
19
|
+
wf = workflow("my_workflow") >> condition(
|
|
20
|
+
expr="state.mode == 'fast'",
|
|
21
|
+
then_=step("FastPath"),
|
|
22
|
+
else_=step("SlowPath"),
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# With middleware at workflow level
|
|
26
|
+
class LoggingMiddleware(BaseMiddleware):
|
|
27
|
+
async def on_agent_start(self, agent_id, input_data, context):
|
|
28
|
+
print(f"Starting {agent_id}")
|
|
29
|
+
return HookResult.proceed()
|
|
30
|
+
|
|
31
|
+
middleware = MiddlewareChain().use(LoggingMiddleware())
|
|
32
|
+
wf = workflow("my_workflow").middleware(middleware) >> step("A") >> step("B")
|
|
33
|
+
|
|
34
|
+
# With middleware at step level
|
|
35
|
+
wf = workflow("my_workflow") >> step("A").with_middleware(LoggingMiddleware()) >> step("B")
|
|
36
|
+
```
|
|
37
|
+
"""
|
|
38
|
+
from __future__ import annotations
|
|
39
|
+
|
|
40
|
+
from dataclasses import dataclass, field
|
|
41
|
+
from typing import Any, Callable, TYPE_CHECKING
|
|
42
|
+
from uuid import uuid4
|
|
43
|
+
|
|
44
|
+
from .types import NodeType, NodeSpec, EdgeSpec, WorkflowSpec, Position
|
|
45
|
+
|
|
46
|
+
if TYPE_CHECKING:
|
|
47
|
+
from ..middleware import Middleware, MiddlewareChain
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _gen_id(prefix: str = "node") -> str:
|
|
51
|
+
"""Generate unique node ID."""
|
|
52
|
+
return f"{prefix}_{uuid4().hex[:8]}"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass
class DSLNode:
    """Base DSL node for workflow building.

    Subclasses (step, parallel, condition) override :meth:`to_spec`. The
    fluent helpers ``skip_when`` and ``with_middleware`` modify the node in
    place and return it, so they can be chained.
    """

    id: str
    node_type: NodeType
    config: dict[str, Any] = field(default_factory=dict)
    children: list["DSLNode"] = field(default_factory=list)
    condition: str | None = None  # Skip condition
    node_middleware: list["Middleware"] = field(default_factory=list)  # Node-level middleware

    def __rshift__(self, other: "DSLNode | DSLSequence") -> "DSLSequence":
        """Chain operator: ``a >> b`` creates a sequence.

        When ``other`` is already a sequence its nodes are spliced in, so
        ``a >> (b >> c)`` yields the flat sequence ``[a, b, c]``. Without this,
        the result would contain a nested sequence element, and consumers that
        read ``.id`` off each sequence element would fail on it.
        """
        if isinstance(other, DSLSequence):
            return DSLSequence([self, *other.nodes])
        return DSLSequence([self, other])

    def skip_when(self, condition: str) -> "DSLNode":
        """Set skip condition for this node and return the node.

        Usage:
            step("A").skip_when("state.skip_a == True")
        """
        self.condition = condition
        return self

    def with_middleware(self, *middlewares: "Middleware") -> "DSLNode":
        """Append middleware to this specific node and return the node.

        Can be called repeatedly; middleware accumulates in call order.

        Usage:
            step("A").with_middleware(LoggingMiddleware(), MetricsMiddleware())
        """
        self.node_middleware.extend(middlewares)
        return self

    def to_spec(self, position: Position | None = None) -> NodeSpec:
        """Convert to NodeSpec.

        Args:
            position: Layout position; a default ``Position()`` is used when omitted.

        Returns:
            A NodeSpec carrying this node's id, type, config, skip condition
            (as ``when``) and node middleware (``None`` when there is none).
        """
        return NodeSpec(
            id=self.id,
            type=self.node_type,
            position=position or Position(),
            config=self.config,
            when=self.condition,
            middleware=self.node_middleware if self.node_middleware else None,
        )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class DSLSequence:
    """Ordered run of DSL nodes, produced by chaining with ``>>``.

    ``nodes`` holds the elements in execution order. Chaining never mutates
    the receiver; every ``>>`` returns a new sequence.
    """

    def __init__(self, nodes: list[DSLNode]):
        self.nodes = nodes

    def __rshift__(self, other: DSLNode | "DSLSequence") -> "DSLSequence":
        """Return a new sequence with ``other`` appended.

        A sequence operand contributes its nodes individually, which keeps the
        result flat.
        """
        tail = other.nodes if isinstance(other, DSLSequence) else [other]
        return DSLSequence(self.nodes + tail)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
class DSLStep(DSLNode):
    """Agent step node: executes the agent named ``agent_name``."""

    agent_name: str = ""
    inputs: dict[str, Any] = field(default_factory=dict)
    output_key: str | None = None

    def __init__(
        self,
        agent: str,
        inputs: dict[str, Any] | None = None,
        output: str | None = None,
        id: str | None = None,
        middleware: list["Middleware"] | None = None,
    ):
        """Create an agent step.

        Args:
            agent: Agent name to execute.
            inputs: Input mapping (state keys or expressions); copied.
            output: State key to store output.
            id: Optional explicit node ID; generated with a ``step`` prefix otherwise.
            middleware: Optional step-level middleware list; copied.
        """
        super().__init__(
            id=id or _gen_id("step"),
            node_type=NodeType.AGENT,
        )
        self.agent_name = agent
        # Copy caller-owned containers: the node is mutated later (e.g. by
        # with_middleware()), and the same list/dict may be passed to several
        # steps, so aliasing would leak changes between them.
        self.inputs = dict(inputs) if inputs else {}
        self.output_key = output
        if middleware:
            self.node_middleware = list(middleware)

    def to_spec(self, position: Position | None = None) -> NodeSpec:
        """Convert to an AGENT NodeSpec including agent name, inputs and output key."""
        return NodeSpec(
            id=self.id,
            type=self.node_type,
            position=position or Position(),
            agent=self.agent_name,
            config=self.config,
            inputs=self.inputs,
            output=self.output_key,
            when=self.condition,
            middleware=self.node_middleware if self.node_middleware else None,
        )
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@dataclass
class DSLParallel(DSLNode):
    """Parallel execution node holding several branches and a merge strategy."""

    branches: list[DSLNode | DSLSequence] = field(default_factory=list)
    merge_strategy: str = "collect_list"

    def __init__(
        self,
        *branches: DSLNode | DSLSequence,
        merge: str = "collect_list",
        id: str | None = None,
    ):
        """Create a parallel node.

        Args:
            *branches: Branches (single nodes or sequences).
            merge: Merge strategy name; also stored in ``config["merge_strategy"]``.
            id: Optional explicit node ID; generated with a ``parallel`` prefix otherwise.
        """
        super().__init__(
            id=id or _gen_id("parallel"),
            node_type=NodeType.PARALLEL,
        )
        self.branches = list(branches)
        self.merge_strategy = merge
        self.config["merge_strategy"] = merge

    def to_spec(self, position: Position | None = None) -> NodeSpec:
        """Convert to a PARALLEL NodeSpec.

        Each branch becomes ``{"steps": [...]}`` listing the IDs of its
        elements in order (a single node yields a one-element list).
        """

        def branch_ids(item: DSLNode | DSLSequence) -> list[str]:
            # A sequence contributes the IDs of its elements; nested sequences
            # are flattened so every entry is a real node ID (a bare
            # DSLSequence has no ``id`` attribute).
            if isinstance(item, DSLSequence):
                return [node_id for child in item.nodes for node_id in branch_ids(child)]
            return [item.id]

        branch_specs = [{"steps": branch_ids(branch)} for branch in self.branches]

        return NodeSpec(
            id=self.id,
            type=self.node_type,
            position=position or Position(),
            config=self.config,
            branches=branch_specs,
            when=self.condition,
            # Forward node-level middleware like DSLNode.to_spec does, so
            # parallel(...).with_middleware(...) is not silently dropped.
            middleware=self.node_middleware if self.node_middleware else None,
        )
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@dataclass
class DSLCondition(DSLNode):
    """Conditional branching node: ``then_branch`` or ``else_branch`` by ``expression``."""

    expression: str = ""
    then_branch: DSLNode | DSLSequence | None = None
    else_branch: DSLNode | DSLSequence | None = None

    def __init__(
        self,
        expr: str,
        then_: DSLNode | DSLSequence | None = None,
        else_: DSLNode | DSLSequence | None = None,
        id: str | None = None,
    ):
        """Create a condition node.

        Args:
            expr: Condition expression string.
            then_: Branch taken when the expression is true.
            else_: Branch taken when the expression is false.
            id: Optional explicit node ID; generated with a ``condition`` prefix otherwise.
        """
        super().__init__(
            id=id or _gen_id("condition"),
            node_type=NodeType.CONDITION,
        )
        self.expression = expr
        self.then_branch = then_
        self.else_branch = else_

    def to_spec(self, position: Position | None = None) -> NodeSpec:
        """Convert to a CONDITION NodeSpec.

        ``then_node`` / ``else_node`` are the entry-node IDs of the branches,
        or ``None`` when a branch is missing or is an empty sequence.
        """

        def entry_id(branch: DSLNode | DSLSequence | None) -> str | None:
            # Descend through (possibly nested) sequences to their first node.
            while isinstance(branch, DSLSequence):
                if not branch.nodes:
                    return None
                branch = branch.nodes[0]
            return branch.id if branch is not None else None

        return NodeSpec(
            id=self.id,
            type=self.node_type,
            position=position or Position(),
            config=self.config,
            expression=self.expression,
            then_node=entry_id(self.then_branch),
            else_node=entry_id(self.else_branch),
            when=self.condition,
            # Forward node-level middleware like DSLNode.to_spec does, so
            # condition(...).with_middleware(...) is not silently dropped.
            middleware=self.node_middleware if self.node_middleware else None,
        )
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class DSLWorkflow:
    """Workflow builder using DSL.

    Nodes are attached with ``>>``. Each ``>>`` appends to the workflow's
    root, so ``workflow("wf") >> step("A") >> step("B")`` describes ``A``
    followed by ``B``. Call :meth:`build` to produce a :class:`WorkflowSpec`.
    """

    def __init__(
        self,
        name: str,
        version: str = "1.0",
        description: str | None = None,
    ):
        self.name = name
        self.version = version
        self.description = description
        # NOTE(review): not read by this class; kept for compatibility.
        self.nodes: list[DSLNode] = []
        self.state_schema: dict[str, str] = {}
        self.input_schema: dict[str, dict[str, Any]] = {}
        # Root of the DSL tree; None until the first ``>>``.
        self._root: DSLNode | DSLSequence | None = None
        # Workflow-level middleware chain; created lazily by middleware().
        self._middleware: "MiddlewareChain | None" = None

    def __rshift__(self, other: DSLNode | DSLSequence) -> "DSLWorkflow":
        """Append to the workflow: ``workflow >> step("A") >> step("B")``.

        Python groups ``>>`` left to right, so in ``wf >> a >> b`` this method
        runs once with ``a`` and then again with ``b`` on the returned
        workflow. The first call sets the root; later calls extend it into a
        flat sequence instead of replacing it (replacing would keep only the
        last node of the chain).
        """
        if self._root is None:
            self._root = other
            return self
        head = self._root.nodes if isinstance(self._root, DSLSequence) else [self._root]
        tail = other.nodes if isinstance(other, DSLSequence) else [other]
        self._root = DSLSequence([*head, *tail])
        return self

    def middleware(
        self,
        *middlewares: "Middleware | MiddlewareChain",
    ) -> "DSLWorkflow":
        """Set workflow-level middleware.

        These middleware hooks apply to all nodes in the workflow.
        Accepts individual middlewares or a MiddlewareChain.
        Can be called multiple times to add more middleware.

        Usage:
            # Pass individual middlewares
            workflow("wf").middleware(LoggingMiddleware(), MetricsMiddleware()) >> step("A")

            # Pass a MiddlewareChain
            chain = MiddlewareChain().use(LoggingMiddleware())
            workflow("wf").middleware(chain) >> step("A")

            # Multiple calls (additive)
            workflow("wf").middleware(LoggingMiddleware()).middleware(MetricsMiddleware())
        """
        # Imported lazily; the module-level import only exists for type checking.
        from ..middleware import MiddlewareChain as MWChain

        # Initialize chain if needed
        if self._middleware is None:
            self._middleware = MWChain()

        for mw in middlewares:
            if isinstance(mw, MWChain):
                # Merge chains: copy the other chain's middlewares into ours
                for m in mw.middlewares:
                    self._middleware.use(m)
            else:
                # Add individual middleware
                self._middleware.use(mw)

        return self

    def state(self, **schema: str) -> "DSLWorkflow":
        """Define (merge into) the state schema.

        Usage:
            workflow("wf").state(count="int", items="list[str]")
        """
        self.state_schema.update(schema)
        return self

    def inputs(self, **schema: dict[str, Any]) -> "DSLWorkflow":
        """Define (merge into) the input schema.

        Usage:
            workflow("wf").inputs(
                query={"type": "string", "required": True}
            )
        """
        self.input_schema.update(schema)
        return self

    def build(self) -> WorkflowSpec:
        """Build WorkflowSpec from DSL.

        Nodes are laid out in a single column (x=100, y stepping by 100) in
        collection order; edges come from :meth:`_generate_edges`.
        """
        nodes: list[NodeSpec] = [
            node.to_spec(Position(x=100, y=index * 100))
            for index, node in enumerate(self._collect_nodes(self._root))
        ]
        edges: list[EdgeSpec] = self._generate_edges(self._root)

        return WorkflowSpec(
            name=self.name,
            version=self.version,
            description=self.description,
            state=self.state_schema,
            inputs=self.input_schema,
            nodes=nodes,
            edges=edges,
            middleware=self._middleware,
        )

    def _collect_nodes(
        self,
        root: DSLNode | DSLSequence | None,
    ) -> list[DSLNode]:
        """Recursively collect all nodes (depth-first, parents before children)."""
        if root is None:
            return []

        nodes: list[DSLNode] = []

        if isinstance(root, DSLSequence):
            for node in root.nodes:
                nodes.extend(self._collect_nodes(node))
        elif isinstance(root, DSLParallel):
            nodes.append(root)
            for branch in root.branches:
                nodes.extend(self._collect_nodes(branch))
        elif isinstance(root, DSLCondition):
            nodes.append(root)
            if root.then_branch is not None:
                nodes.extend(self._collect_nodes(root.then_branch))
            if root.else_branch is not None:
                nodes.extend(self._collect_nodes(root.else_branch))
        else:
            nodes.append(root)

        return nodes

    def _generate_edges(
        self,
        root: DSLNode | DSLSequence | None,
    ) -> list[EdgeSpec]:
        """Generate edges from DSL structure."""
        if root is None:
            return []

        edges: list[EdgeSpec] = []

        if isinstance(root, DSLSequence):
            # Sequential edges: exit of each element -> entry of the next.
            for from_node, to_node in zip(root.nodes, root.nodes[1:]):
                edges.append(EdgeSpec(
                    from_node=self._get_exit_id(from_node),
                    to_node=self._get_entry_id(to_node),
                ))
                # Recurse into complex nodes
                edges.extend(self._generate_edges(from_node))

            # The loop above does not recurse into the last element.
            if root.nodes:
                edges.extend(self._generate_edges(root.nodes[-1]))

        elif isinstance(root, DSLParallel):
            # Edges into each branch
            for branch in root.branches:
                edges.append(EdgeSpec(from_node=root.id, to_node=self._get_entry_id(branch)))
                edges.extend(self._generate_edges(branch))

        elif isinstance(root, DSLCondition):
            if root.then_branch is not None:
                edges.append(EdgeSpec(
                    from_node=root.id,
                    to_node=self._get_entry_id(root.then_branch),
                    when=root.expression,
                ))
                edges.extend(self._generate_edges(root.then_branch))

            if root.else_branch is not None:
                edges.append(EdgeSpec(
                    from_node=root.id,
                    to_node=self._get_entry_id(root.else_branch),
                    when=f"not ({root.expression})",
                ))
                edges.extend(self._generate_edges(root.else_branch))

        return edges

    def _get_entry_id(self, node: DSLNode | DSLSequence) -> str:
        """Get entry node ID, descending into nested sequences ("" if empty)."""
        if isinstance(node, DSLSequence):
            return self._get_entry_id(node.nodes[0]) if node.nodes else ""
        return node.id

    def _get_exit_id(self, node: DSLNode | DSLSequence) -> str:
        """Get exit node ID, descending into nested sequences ("" if empty)."""
        if isinstance(node, DSLSequence):
            return self._get_exit_id(node.nodes[-1]) if node.nodes else ""
        return node.id
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
# ========== Convenience Functions ==========
|
|
459
|
+
|
|
460
|
+
def workflow(
    name: str,
    version: str = "1.0",
    description: str | None = None,
) -> DSLWorkflow:
    """Start building a workflow.

    Args:
        name: Workflow name recorded in the built spec.
        version: Version string recorded in the built spec.
        description: Optional human-readable description.

    Returns:
        A fresh :class:`DSLWorkflow` builder.

    Usage:
        wf = workflow("my_workflow", version="1.0")
        wf = wf >> step("agent_a") >> step("agent_b")
        spec = wf.build()
    """
    builder = DSLWorkflow(name=name, version=version, description=description)
    return builder
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def step(
    agent: str,
    inputs: dict[str, Any] | None = None,
    output: str | None = None,
    id: str | None = None,
    middleware: list["Middleware"] | None = None,
) -> DSLStep:
    """Build a single agent step node.

    Args:
        agent: Agent name to execute
        inputs: Input mapping (state keys or expressions)
        output: State key to store output
        id: Optional explicit node ID
        middleware: Optional step-level middleware list

    Usage:
        step("research_agent", inputs={"query": "state.user_query"}, output="results")

        # With step-level middleware
        step("agent", middleware=[LoggingMiddleware(), MetricsMiddleware()])
    """
    node = DSLStep(
        agent=agent,
        inputs=inputs,
        output=output,
        id=id,
        middleware=middleware,
    )
    return node
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def parallel(
    *branches: DSLNode | DSLSequence,
    merge: str = "collect_list",
    id: str | None = None,
) -> DSLParallel:
    """Build a parallel node from the given branches.

    Args:
        *branches: Parallel branches (steps or sequences)
        merge: Merge strategy ("collect_list", "collect_dict", "first_success")
        id: Optional explicit node ID

    Usage:
        parallel(
            step("agent_a"),
            step("agent_b") >> step("agent_c"),
            merge="collect_list",
        )
    """
    node = DSLParallel(*branches, id=id, merge=merge)
    return node
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def condition(
    expr: str,
    then_: DSLNode | DSLSequence | None = None,
    else_: DSLNode | DSLSequence | None = None,
    id: str | None = None,
) -> DSLCondition:
    """Build a conditional branching node.

    Args:
        expr: Condition expression (evaluated against state)
        then_: Branch to execute if condition is true
        else_: Branch to execute if condition is false
        id: Optional explicit node ID

    Usage:
        condition(
            expr="state.count > 10",
            then_=step("high_path"),
            else_=step("low_path"),
        )
    """
    node = DSLCondition(expr, then_=then_, else_=else_, id=id)
    return node
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
def skip(condition: str) -> Callable[[DSLNode], DSLNode]:
    """Decorator to set a skip condition.

    The decorator accepts either:

    * a node: its ``condition`` is set and the node is returned; or
    * a callable that returns a node (the ``@skip`` usage below): it is
      wrapped so every node it returns gets the condition. The function
      object also keeps a ``condition`` attribute, as before.

    Usage:
        @skip("state.skip_step")
        def my_step():
            return step("agent")

    Or inline:
        step("agent").skip_when("state.skip_step")
    """
    import functools

    def _apply(node):
        node.condition = condition
        return node

    def decorator(target):
        # DSL nodes define no __call__, so a callable here is a node factory.
        # Setting .condition on the function alone (old behavior) never
        # reached the nodes it returns.
        if callable(target):

            @functools.wraps(target)
            def wrapper(*args, **kwargs):
                return _apply(target(*args, **kwargs))

            wrapper.condition = condition
            return wrapper
        return _apply(target)

    return decorator
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
# Public API of the workflow DSL module.
__all__ = [
    # Node and builder types
    "DSLNode",
    "DSLSequence",
    "DSLStep",
    "DSLParallel",
    "DSLCondition",
    "DSLWorkflow",
    # Convenience constructors / decorators
    "workflow",
    "step",
    "parallel",
    "condition",
    "skip",
]
|