commandnet 0.2.2__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {commandnet-0.2.2 → commandnet-0.3.0}/PKG-INFO +1 -1
- {commandnet-0.2.2 → commandnet-0.3.0}/commandnet/__init__.py +1 -1
- {commandnet-0.2.2 → commandnet-0.3.0}/commandnet/core/node.py +18 -11
- commandnet-0.3.0/commandnet/engine/runtime.py +292 -0
- {commandnet-0.2.2 → commandnet-0.3.0}/commandnet/interfaces/persistence.py +9 -1
- {commandnet-0.2.2 → commandnet-0.3.0}/pyproject.toml +1 -1
- commandnet-0.2.2/commandnet/engine/runtime.py +0 -211
- {commandnet-0.2.2 → commandnet-0.3.0}/README.md +0 -0
- {commandnet-0.2.2 → commandnet-0.3.0}/commandnet/core/graph.py +0 -0
- {commandnet-0.2.2 → commandnet-0.3.0}/commandnet/core/models.py +0 -0
- {commandnet-0.2.2 → commandnet-0.3.0}/commandnet/interfaces/event_bus.py +0 -0
- {commandnet-0.2.2 → commandnet-0.3.0}/commandnet/interfaces/observer.py +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from .core.models import Event
|
|
2
|
-
from .core.node import Node, Parallel, ParallelTask, Schedule
|
|
2
|
+
from .core.node import Node, Parallel, ParallelTask, Schedule, Wait
|
|
3
3
|
from .core.graph import GraphAnalyzer
|
|
4
4
|
from .interfaces.persistence import Persistence
|
|
5
5
|
from .interfaces.event_bus import EventBus
|
|
@@ -1,30 +1,37 @@
|
|
|
1
|
-
import
|
|
2
|
-
from typing import Generic, TypeVar, Type, Optional, Union, List, Dict, Any
|
|
1
|
+
from typing import Generic, TypeVar, Type, Optional, Union, List, Any
|
|
3
2
|
from pydantic import BaseModel, ConfigDict
|
|
4
3
|
|
|
5
4
|
C = TypeVar('C', bound=BaseModel) # Context
|
|
6
|
-
P = TypeVar('P', bound=BaseModel) # Payload
|
|
5
|
+
P = TypeVar('P', bound=BaseModel) # Payload
|
|
6
|
+
|
|
7
|
+
# The Recursive Type Definition
|
|
8
|
+
Target = Union[Type['Node'], 'Parallel', 'Schedule', 'Wait', None]
|
|
7
9
|
|
|
8
10
|
class ParallelTask(BaseModel):
|
|
9
11
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
10
|
-
|
|
11
|
-
payload: Optional[Any] = None
|
|
12
|
+
action: Target
|
|
12
13
|
sub_context_path: str
|
|
14
|
+
payload: Optional[Any] = None
|
|
13
15
|
|
|
14
16
|
class Parallel(BaseModel):
|
|
15
17
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
16
|
-
branches: List[ParallelTask]
|
|
17
|
-
join_node: Type['Node']
|
|
18
|
+
branches: List[Union[ParallelTask, 'Wait']]
|
|
19
|
+
join_node: Optional[Type['Node']] = None
|
|
18
20
|
|
|
19
21
|
class Schedule(BaseModel):
|
|
20
22
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
21
|
-
|
|
23
|
+
action: Target
|
|
22
24
|
delay_seconds: int
|
|
23
25
|
payload: Optional[Any] = None
|
|
24
26
|
idempotency_key: Optional[str] = None
|
|
25
27
|
|
|
26
|
-
|
|
27
|
-
|
|
28
|
+
class Wait(BaseModel):
|
|
29
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
30
|
+
signal_id: str
|
|
31
|
+
resume_action: Target
|
|
32
|
+
sub_context_path: Optional[str] = None
|
|
33
|
+
|
|
34
|
+
TransitionResult = Target
|
|
28
35
|
|
|
29
36
|
class Node(Generic[C, P]):
|
|
30
37
|
@classmethod
|
|
@@ -32,5 +39,5 @@ class Node(Generic[C, P]):
|
|
|
32
39
|
return cls.__name__
|
|
33
40
|
|
|
34
41
|
async def run(self, ctx: C, payload: Optional[P] = None) -> TransitionResult:
|
|
35
|
-
"""Executes node logic. Returns
|
|
42
|
+
"""Executes node logic. Returns a Target (Node class, Parallel, Schedule, Wait, or None)."""
|
|
36
43
|
pass
|
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from datetime import datetime, timezone, timedelta
|
|
4
|
+
from typing import Optional, Type, Iterable, Dict, Any, Union
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
from ..core.models import Event
|
|
8
|
+
from ..core.node import Node, Parallel, Schedule, Wait, Target, ParallelTask
|
|
9
|
+
from ..core.graph import GraphAnalyzer
|
|
10
|
+
from ..interfaces.persistence import Persistence
|
|
11
|
+
from ..interfaces.event_bus import EventBus
|
|
12
|
+
from ..interfaces.observer import Observer
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Engine:
|
|
16
|
+
def __init__(
|
|
17
|
+
self,
|
|
18
|
+
persistence: Persistence,
|
|
19
|
+
event_bus: EventBus,
|
|
20
|
+
nodes: Iterable[Type[Node]],
|
|
21
|
+
observer: Optional[Observer] = None,
|
|
22
|
+
):
|
|
23
|
+
self.db = persistence
|
|
24
|
+
self.bus = event_bus
|
|
25
|
+
self.observer = observer or Observer()
|
|
26
|
+
self.logger = logging.getLogger("CommandNet")
|
|
27
|
+
self._scheduler_task: Optional[asyncio.Task] = None
|
|
28
|
+
|
|
29
|
+
self._registry: Dict[str, Type[Node]] = {n.get_node_name(): n for n in nodes}
|
|
30
|
+
|
|
31
|
+
# --- HELPER UTILITIES ---
|
|
32
|
+
|
|
33
|
+
def _dump_ctx(self, ctx: Any) -> dict:
|
|
34
|
+
"""Helper to safely dump context whether it's a model or a dict."""
|
|
35
|
+
if hasattr(ctx, "model_dump"):
|
|
36
|
+
return ctx.model_dump()
|
|
37
|
+
return ctx
|
|
38
|
+
|
|
39
|
+
def _get_path(self, obj: Any, path: str) -> Any:
|
|
40
|
+
"""Helper to get a value from an object or a dict."""
|
|
41
|
+
if isinstance(obj, dict):
|
|
42
|
+
return obj.get(path)
|
|
43
|
+
return getattr(obj, path)
|
|
44
|
+
|
|
45
|
+
# --- CORE RECURSIVE RESOLVER ---
|
|
46
|
+
|
|
47
|
+
async def _apply_target(
|
|
48
|
+
self,
|
|
49
|
+
agent_id: str,
|
|
50
|
+
context: Any,
|
|
51
|
+
target: Target,
|
|
52
|
+
duration: float = 0.0,
|
|
53
|
+
payload: Any = None,
|
|
54
|
+
):
|
|
55
|
+
# 1. Terminal State
|
|
56
|
+
if target is None:
|
|
57
|
+
await self.observer.on_transition(agent_id, "RUN", "TERMINAL", duration)
|
|
58
|
+
ctx_dict = self._dump_ctx(context)
|
|
59
|
+
if "#" in agent_id:
|
|
60
|
+
parent_id = agent_id.split("#")[0]
|
|
61
|
+
await self.db.save_sub_state(
|
|
62
|
+
agent_id, parent_id, "TERMINAL", ctx_dict, None
|
|
63
|
+
)
|
|
64
|
+
join_node_name = await self.db.register_sub_task_completion(agent_id)
|
|
65
|
+
if join_node_name:
|
|
66
|
+
await self._trigger_recompose(parent_id, join_node_name)
|
|
67
|
+
else:
|
|
68
|
+
await self.db.save_state(agent_id, "TERMINAL", ctx_dict, None)
|
|
69
|
+
return
|
|
70
|
+
|
|
71
|
+
# 2. Node Class Transition
|
|
72
|
+
if isinstance(target, type) and issubclass(target, Node):
|
|
73
|
+
node_name = target.get_node_name()
|
|
74
|
+
await self.observer.on_transition(agent_id, "RUN", node_name, duration)
|
|
75
|
+
|
|
76
|
+
p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload
|
|
77
|
+
evt = Event(agent_id=agent_id, node_name=node_name, payload=p_load)
|
|
78
|
+
ctx_dict = self._dump_ctx(context)
|
|
79
|
+
|
|
80
|
+
if "#" in agent_id:
|
|
81
|
+
await self.db.save_sub_state(
|
|
82
|
+
agent_id, agent_id.split("#")[0], node_name, ctx_dict, evt
|
|
83
|
+
)
|
|
84
|
+
else:
|
|
85
|
+
await self.db.save_state(agent_id, node_name, ctx_dict, evt)
|
|
86
|
+
|
|
87
|
+
await self.bus.publish(evt)
|
|
88
|
+
return
|
|
89
|
+
|
|
90
|
+
# 3. Wait Directive
|
|
91
|
+
if isinstance(target, Wait):
|
|
92
|
+
actual_id = agent_id
|
|
93
|
+
actual_ctx = context
|
|
94
|
+
if target.sub_context_path and "#" not in agent_id:
|
|
95
|
+
actual_id = f"{agent_id}#{target.sub_context_path}"
|
|
96
|
+
actual_ctx = self._get_path(context, target.sub_context_path)
|
|
97
|
+
|
|
98
|
+
await self.observer.on_transition(
|
|
99
|
+
actual_id, "RUN", f"WAIT:{target.signal_id}", duration
|
|
100
|
+
)
|
|
101
|
+
await self.db.park_agent(
|
|
102
|
+
actual_id,
|
|
103
|
+
target.signal_id,
|
|
104
|
+
target.resume_action,
|
|
105
|
+
self._dump_ctx(actual_ctx),
|
|
106
|
+
)
|
|
107
|
+
return
|
|
108
|
+
|
|
109
|
+
# 4. Parallel Directive
|
|
110
|
+
if isinstance(target, Parallel):
|
|
111
|
+
join_name = target.join_node.get_node_name() if target.join_node else "FORK"
|
|
112
|
+
await self.observer.on_transition(
|
|
113
|
+
agent_id, "RUN", f"PARALLEL:{join_name}", duration
|
|
114
|
+
)
|
|
115
|
+
|
|
116
|
+
if target.join_node:
|
|
117
|
+
await self.db.create_task_group(
|
|
118
|
+
agent_id, join_name, len(target.branches)
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
for branch in target.branches:
|
|
122
|
+
task = (
|
|
123
|
+
branch
|
|
124
|
+
if isinstance(branch, ParallelTask)
|
|
125
|
+
else ParallelTask(
|
|
126
|
+
action=branch,
|
|
127
|
+
sub_context_path=branch.sub_context_path, # type: ignore
|
|
128
|
+
)
|
|
129
|
+
)
|
|
130
|
+
sub_ctx = self._get_path(context, task.sub_context_path)
|
|
131
|
+
await self._apply_target(
|
|
132
|
+
f"{agent_id}#{task.sub_context_path}",
|
|
133
|
+
sub_ctx,
|
|
134
|
+
task.action,
|
|
135
|
+
duration=0,
|
|
136
|
+
payload=task.payload,
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
if target.join_node:
|
|
140
|
+
await self.db.save_state(
|
|
141
|
+
agent_id, "WAITING_FOR_JOIN", self._dump_ctx(context), None
|
|
142
|
+
)
|
|
143
|
+
else:
|
|
144
|
+
await self._apply_target(agent_id, context, None)
|
|
145
|
+
return
|
|
146
|
+
|
|
147
|
+
# 5. Schedule Directive
|
|
148
|
+
if isinstance(target, Schedule):
|
|
149
|
+
if isinstance(target.action, type) and issubclass(target.action, Node):
|
|
150
|
+
target_node_name = target.action.get_node_name()
|
|
151
|
+
else:
|
|
152
|
+
raise TypeError("Schedule.action must be a Node class.")
|
|
153
|
+
|
|
154
|
+
await self.observer.on_transition(
|
|
155
|
+
agent_id, "RUN", f"SCHEDULED:{target_node_name}", duration
|
|
156
|
+
)
|
|
157
|
+
run_at_dt = datetime.now(timezone.utc) + timedelta(
|
|
158
|
+
seconds=target.delay_seconds
|
|
159
|
+
)
|
|
160
|
+
p_load = (
|
|
161
|
+
target.payload.model_dump()
|
|
162
|
+
if hasattr(target.payload, "model_dump")
|
|
163
|
+
else target.payload
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
scheduled_evt = Event(
|
|
167
|
+
agent_id=agent_id,
|
|
168
|
+
node_name=target_node_name,
|
|
169
|
+
payload=p_load,
|
|
170
|
+
run_at=run_at_dt.isoformat(),
|
|
171
|
+
idempotency_key=target.idempotency_key,
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
await self.db.schedule_event(scheduled_evt)
|
|
175
|
+
await self.db.save_state(
|
|
176
|
+
agent_id, target_node_name, self._dump_ctx(context), None
|
|
177
|
+
)
|
|
178
|
+
return
|
|
179
|
+
|
|
180
|
+
# --- WORKER & EXTERNAL TRIGGERS ---
|
|
181
|
+
|
|
182
|
+
async def process_event(self, event: Event):
|
|
183
|
+
start_time = asyncio.get_event_loop().time()
|
|
184
|
+
agent_id = event.agent_id
|
|
185
|
+
current_node_name, ctx_dict = await self.db.load_and_lock_agent(agent_id)
|
|
186
|
+
|
|
187
|
+
if not current_node_name:
|
|
188
|
+
return
|
|
189
|
+
|
|
190
|
+
try:
|
|
191
|
+
if current_node_name != event.node_name:
|
|
192
|
+
return
|
|
193
|
+
|
|
194
|
+
node_cls = self._registry.get(current_node_name)
|
|
195
|
+
if not node_cls:
|
|
196
|
+
raise RuntimeError(f"Node '{current_node_name}' not found.")
|
|
197
|
+
|
|
198
|
+
ctx_type = GraphAnalyzer.get_context_type(node_cls)
|
|
199
|
+
payload_type = GraphAnalyzer.get_payload_type(node_cls)
|
|
200
|
+
|
|
201
|
+
ctx = (
|
|
202
|
+
ctx_type.model_validate(ctx_dict)
|
|
203
|
+
if issubclass(ctx_type, BaseModel)
|
|
204
|
+
else ctx_dict
|
|
205
|
+
)
|
|
206
|
+
payload = (
|
|
207
|
+
payload_type.model_validate(event.payload)
|
|
208
|
+
if (event.payload and issubclass(payload_type, BaseModel))
|
|
209
|
+
else event.payload
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
result = await node_cls().run(ctx, payload)
|
|
213
|
+
await self._apply_target(
|
|
214
|
+
agent_id,
|
|
215
|
+
ctx,
|
|
216
|
+
result,
|
|
217
|
+
(asyncio.get_event_loop().time() - start_time) * 1000,
|
|
218
|
+
)
|
|
219
|
+
|
|
220
|
+
except Exception as e:
|
|
221
|
+
await self.observer.on_error(agent_id, current_node_name, e)
|
|
222
|
+
raise
|
|
223
|
+
finally:
|
|
224
|
+
await self.db.unlock_agent(agent_id)
|
|
225
|
+
|
|
226
|
+
async def release_signal(self, signal_id: str, payload: Any = None):
|
|
227
|
+
waiters = await self.db.get_and_clear_waiters(signal_id)
|
|
228
|
+
for waiter in waiters:
|
|
229
|
+
agent_id = waiter["agent_id"]
|
|
230
|
+
await self.db.load_and_lock_agent(agent_id)
|
|
231
|
+
try:
|
|
232
|
+
await self._apply_target(
|
|
233
|
+
agent_id=agent_id,
|
|
234
|
+
context=waiter["context"],
|
|
235
|
+
target=waiter["next_target"],
|
|
236
|
+
payload=payload,
|
|
237
|
+
)
|
|
238
|
+
finally:
|
|
239
|
+
await self.db.unlock_agent(agent_id)
|
|
240
|
+
|
|
241
|
+
# --- LIFECYCLE METHODS ---
|
|
242
|
+
|
|
243
|
+
async def start_worker(self, poll_interval: float = 1.0):
|
|
244
|
+
await self.bus.subscribe(self.process_event)
|
|
245
|
+
self._scheduler_task = asyncio.create_task(self._scheduler_loop(poll_interval))
|
|
246
|
+
self.logger.info("Worker started.")
|
|
247
|
+
|
|
248
|
+
async def _scheduler_loop(self, poll_interval: float):
|
|
249
|
+
while True:
|
|
250
|
+
try:
|
|
251
|
+
due = await self.db.pop_due_events()
|
|
252
|
+
for evt in due:
|
|
253
|
+
await self.bus.publish(evt)
|
|
254
|
+
except asyncio.CancelledError:
|
|
255
|
+
break
|
|
256
|
+
except Exception as e:
|
|
257
|
+
self.logger.error(f"Scheduler Error: {e}")
|
|
258
|
+
await asyncio.sleep(poll_interval)
|
|
259
|
+
|
|
260
|
+
async def stop(self):
|
|
261
|
+
if self._scheduler_task:
|
|
262
|
+
self._scheduler_task.cancel()
|
|
263
|
+
try:
|
|
264
|
+
await self._scheduler_task
|
|
265
|
+
except asyncio.CancelledError:
|
|
266
|
+
pass
|
|
267
|
+
|
|
268
|
+
async def trigger_agent(
|
|
269
|
+
self,
|
|
270
|
+
agent_id: str,
|
|
271
|
+
start_node: Type[Node],
|
|
272
|
+
initial_context: BaseModel,
|
|
273
|
+
payload: Optional[BaseModel] = None,
|
|
274
|
+
):
|
|
275
|
+
node_name = start_node.get_node_name()
|
|
276
|
+
evt = Event(
|
|
277
|
+
agent_id=agent_id,
|
|
278
|
+
node_name=node_name,
|
|
279
|
+
payload=payload.model_dump() if hasattr(payload, "model_dump") else payload,
|
|
280
|
+
)
|
|
281
|
+
await self.db.save_state(agent_id, node_name, initial_context.model_dump(), evt)
|
|
282
|
+
await self.bus.publish(evt)
|
|
283
|
+
|
|
284
|
+
async def _trigger_recompose(self, parent_id: str, join_node_name: str):
|
|
285
|
+
await self.db.load_and_lock_agent(parent_id)
|
|
286
|
+
merged_ctx_dict = await self.db.recompose_parent(parent_id)
|
|
287
|
+
await self._apply_target(
|
|
288
|
+
parent_id, merged_ctx_dict, self._registry[join_node_name]
|
|
289
|
+
)
|
|
290
|
+
|
|
291
|
+
def validate_graph(self, start_node: Type[Node]):
|
|
292
|
+
return GraphAnalyzer.validate(start_node, self._registry)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import Dict, Optional, Tuple, List
|
|
2
|
+
from typing import Dict, Optional, Tuple, List, Any
|
|
3
3
|
from ..core.models import Event
|
|
4
4
|
|
|
5
5
|
class Persistence(ABC):
|
|
@@ -39,3 +39,11 @@ class Persistence(ABC):
|
|
|
39
39
|
@abstractmethod
|
|
40
40
|
async def pop_due_events(self) -> List[Event]:
|
|
41
41
|
pass
|
|
42
|
+
|
|
43
|
+
@abstractmethod
|
|
44
|
+
async def park_agent(self, agent_id: str, signal_id: str, next_target: Any, context: Dict):
|
|
45
|
+
pass
|
|
46
|
+
|
|
47
|
+
@abstractmethod
|
|
48
|
+
async def get_and_clear_waiters(self, signal_id: str) -> List[Dict]:
|
|
49
|
+
pass
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "commandnet"
|
|
3
|
-
version = "0.2.2"
|
|
3
|
+
version = "0.3.0"
|
|
4
4
|
description = "A lightweight, Pydantic-powered, distributed event-driven state machine and typed node graph runtime."
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Christopher Vaz", email = "christophervaz160@gmail.com" }
|
|
@@ -1,211 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import logging
|
|
3
|
-
from datetime import datetime, timezone, timedelta
|
|
4
|
-
from typing import Optional, Type, Iterable, Dict
|
|
5
|
-
from pydantic import BaseModel
|
|
6
|
-
|
|
7
|
-
from ..core.models import Event
|
|
8
|
-
from ..core.node import Node, Parallel, Schedule
|
|
9
|
-
from ..core.graph import GraphAnalyzer
|
|
10
|
-
from ..interfaces.persistence import Persistence
|
|
11
|
-
from ..interfaces.event_bus import EventBus
|
|
12
|
-
from ..interfaces.observer import Observer
|
|
13
|
-
|
|
14
|
-
class Engine:
|
|
15
|
-
def __init__(
|
|
16
|
-
self,
|
|
17
|
-
persistence: Persistence,
|
|
18
|
-
event_bus: EventBus,
|
|
19
|
-
nodes: Iterable[Type[Node]],
|
|
20
|
-
observer: Optional[Observer] = None
|
|
21
|
-
):
|
|
22
|
-
self.db = persistence
|
|
23
|
-
self.bus = event_bus
|
|
24
|
-
self.observer = observer or Observer()
|
|
25
|
-
self.logger = logging.getLogger("CommandNet")
|
|
26
|
-
self._scheduler_task: Optional[asyncio.Task] = None
|
|
27
|
-
|
|
28
|
-
# Build Engine-scoped registry
|
|
29
|
-
self._registry: Dict[str, Type[Node]] = {}
|
|
30
|
-
for node_cls in nodes:
|
|
31
|
-
name = node_cls.get_node_name()
|
|
32
|
-
if name in self._registry and self._registry[name] is not node_cls:
|
|
33
|
-
raise RuntimeError(f"Node name collision in Engine: '{name}'")
|
|
34
|
-
self._registry[name] = node_cls
|
|
35
|
-
|
|
36
|
-
def validate_graph(self, start_node: Type[Node]):
|
|
37
|
-
"""Helper to validate the graph using this engine's registry."""
|
|
38
|
-
return GraphAnalyzer.validate(start_node, self._registry)
|
|
39
|
-
|
|
40
|
-
async def start_worker(self, poll_interval: float = 1.0):
|
|
41
|
-
await self.bus.subscribe(self.process_event)
|
|
42
|
-
self._scheduler_task = asyncio.create_task(self._scheduler_loop(poll_interval))
|
|
43
|
-
self.logger.info("Worker and Scheduler started.")
|
|
44
|
-
|
|
45
|
-
async def _scheduler_loop(self, poll_interval: float):
|
|
46
|
-
while True:
|
|
47
|
-
try:
|
|
48
|
-
due_events = await self.db.pop_due_events()
|
|
49
|
-
for evt in due_events:
|
|
50
|
-
await self.bus.publish(evt)
|
|
51
|
-
except asyncio.CancelledError:
|
|
52
|
-
break
|
|
53
|
-
except Exception as e:
|
|
54
|
-
self.logger.error(f"Scheduler Error: {e}")
|
|
55
|
-
await asyncio.sleep(poll_interval)
|
|
56
|
-
|
|
57
|
-
async def stop(self):
|
|
58
|
-
if self._scheduler_task:
|
|
59
|
-
self._scheduler_task.cancel()
|
|
60
|
-
try:
|
|
61
|
-
await self._scheduler_task
|
|
62
|
-
except asyncio.CancelledError:
|
|
63
|
-
pass
|
|
64
|
-
|
|
65
|
-
async def trigger_agent(self, agent_id: str, start_node: Type[Node], initial_context: BaseModel, payload: Optional[BaseModel] = None):
|
|
66
|
-
node_name = start_node.get_node_name()
|
|
67
|
-
if node_name not in self._registry:
|
|
68
|
-
raise ValueError(f"Node '{node_name}' is not registered with this Engine.")
|
|
69
|
-
|
|
70
|
-
start_event = Event(
|
|
71
|
-
agent_id=agent_id,
|
|
72
|
-
node_name=node_name,
|
|
73
|
-
payload=payload.model_dump() if hasattr(payload, "model_dump") else payload
|
|
74
|
-
)
|
|
75
|
-
await self.db.save_state(agent_id, node_name, initial_context.model_dump(), start_event)
|
|
76
|
-
await self.bus.publish(start_event)
|
|
77
|
-
|
|
78
|
-
async def process_event(self, event: Event):
|
|
79
|
-
start_time = asyncio.get_event_loop().time()
|
|
80
|
-
|
|
81
|
-
current_node_name, ctx_dict = await self.db.load_and_lock_agent(event.agent_id)
|
|
82
|
-
if not current_node_name:
|
|
83
|
-
return
|
|
84
|
-
|
|
85
|
-
locked = True
|
|
86
|
-
try:
|
|
87
|
-
if current_node_name != event.node_name:
|
|
88
|
-
# Same logic as before for sub-state/state
|
|
89
|
-
if "#" in event.agent_id:
|
|
90
|
-
parent_id = event.agent_id.split("#")[0]
|
|
91
|
-
await self.db.save_sub_state(event.agent_id, parent_id, current_node_name, ctx_dict, None)
|
|
92
|
-
else:
|
|
93
|
-
await self.db.save_state(event.agent_id, current_node_name, ctx_dict, None)
|
|
94
|
-
locked = False
|
|
95
|
-
return
|
|
96
|
-
|
|
97
|
-
node_cls = self._registry.get(current_node_name)
|
|
98
|
-
if not node_cls:
|
|
99
|
-
raise RuntimeError(f"Node '{current_node_name}' not found in this Engine's registry.")
|
|
100
|
-
|
|
101
|
-
ctx_type = GraphAnalyzer.get_context_type(node_cls)
|
|
102
|
-
payload_type = GraphAnalyzer.get_payload_type(node_cls)
|
|
103
|
-
|
|
104
|
-
ctx = ctx_type.model_validate(ctx_dict) if issubclass(ctx_type, BaseModel) else ctx_dict
|
|
105
|
-
|
|
106
|
-
payload = None
|
|
107
|
-
if event.payload is not None:
|
|
108
|
-
if isinstance(payload_type, type) and issubclass(payload_type, BaseModel):
|
|
109
|
-
payload = payload_type.model_validate(event.payload)
|
|
110
|
-
else:
|
|
111
|
-
payload = event.payload
|
|
112
|
-
|
|
113
|
-
node_instance = node_cls()
|
|
114
|
-
result = await node_instance.run(ctx, payload)
|
|
115
|
-
duration = (asyncio.get_event_loop().time() - start_time) * 1000
|
|
116
|
-
|
|
117
|
-
if isinstance(result, Parallel):
|
|
118
|
-
await self._handle_parallel_start(event.agent_id, ctx, result, duration)
|
|
119
|
-
elif isinstance(result, Schedule):
|
|
120
|
-
await self._handle_schedule(event.agent_id, current_node_name, ctx, result, duration)
|
|
121
|
-
elif result:
|
|
122
|
-
await self._handle_transition(event.agent_id, current_node_name, result, ctx, duration)
|
|
123
|
-
else:
|
|
124
|
-
await self._handle_terminal(event.agent_id, current_node_name, ctx, duration)
|
|
125
|
-
|
|
126
|
-
locked = False
|
|
127
|
-
|
|
128
|
-
except Exception as e:
|
|
129
|
-
await self.observer.on_error(event.agent_id, current_node_name, e)
|
|
130
|
-
raise
|
|
131
|
-
finally:
|
|
132
|
-
if locked:
|
|
133
|
-
await self.db.unlock_agent(event.agent_id)
|
|
134
|
-
|
|
135
|
-
# Remaining _handle_* methods use self._registry and self.validate_graph
|
|
136
|
-
async def _handle_transition(self, agent_id: str, from_node: str, next_node_cls: Type[Node], ctx: BaseModel, duration: float):
|
|
137
|
-
next_name = next_node_cls.get_node_name()
|
|
138
|
-
if next_name not in self._registry:
|
|
139
|
-
raise RuntimeError(f"Transition target '{next_name}' not in registry.")
|
|
140
|
-
|
|
141
|
-
await self.observer.on_transition(agent_id, from_node, next_name, duration)
|
|
142
|
-
next_event = Event(agent_id=agent_id, node_name=next_name)
|
|
143
|
-
await self.db.save_state(agent_id, next_name, ctx.model_dump(), next_event)
|
|
144
|
-
await self.bus.publish(next_event)
|
|
145
|
-
|
|
146
|
-
async def _handle_parallel_start(self, parent_id: str, parent_ctx: BaseModel, parallel: Parallel, duration: float):
|
|
147
|
-
join_name = parallel.join_node.get_node_name()
|
|
148
|
-
await self.observer.on_transition(parent_id, "ParallelStart", join_name, duration)
|
|
149
|
-
await self.db.create_task_group(parent_id=parent_id, join_node_name=join_name, task_count=len(parallel.branches))
|
|
150
|
-
|
|
151
|
-
for task in parallel.branches:
|
|
152
|
-
if not hasattr(parent_ctx, task.sub_context_path):
|
|
153
|
-
raise RuntimeError(f"Context missing path: '{task.sub_context_path}'.")
|
|
154
|
-
|
|
155
|
-
sub_ctx = getattr(parent_ctx, task.sub_context_path)
|
|
156
|
-
sub_id = f"{parent_id}#{task.sub_context_path}"
|
|
157
|
-
node_name = task.node_cls.get_node_name()
|
|
158
|
-
|
|
159
|
-
evt = Event(
|
|
160
|
-
agent_id=sub_id,
|
|
161
|
-
node_name=node_name,
|
|
162
|
-
payload=task.payload.model_dump() if hasattr(task.payload, "model_dump") else task.payload
|
|
163
|
-
)
|
|
164
|
-
await self.db.save_sub_state(sub_id, parent_id, node_name, sub_ctx.model_dump(), evt)
|
|
165
|
-
await self.bus.publish(evt)
|
|
166
|
-
|
|
167
|
-
await self.db.save_state(parent_id, "WAITING_FOR_JOIN", parent_ctx.model_dump(), None)
|
|
168
|
-
|
|
169
|
-
async def _handle_terminal(self, agent_id: str, from_node: str, ctx: BaseModel, duration: float):
|
|
170
|
-
await self.observer.on_transition(agent_id, from_node, "TERMINAL", duration)
|
|
171
|
-
if "#" in agent_id:
|
|
172
|
-
parent_id = agent_id.split("#")[0]
|
|
173
|
-
await self.db.save_sub_state(agent_id, parent_id, "TERMINAL", ctx.model_dump(), None)
|
|
174
|
-
join_node_name = await self.db.register_sub_task_completion(agent_id)
|
|
175
|
-
if join_node_name:
|
|
176
|
-
await self._trigger_recompose(parent_id, join_node_name)
|
|
177
|
-
else:
|
|
178
|
-
await self.db.save_state(agent_id, "TERMINAL", ctx.model_dump(), None)
|
|
179
|
-
|
|
180
|
-
async def _handle_schedule(self, agent_id: str, from_node: str, ctx: BaseModel, schedule: Schedule, duration: float):
|
|
181
|
-
target_name = schedule.node_cls.get_node_name()
|
|
182
|
-
await self.observer.on_transition(agent_id, from_node, f"SCHEDULED:{target_name}", duration)
|
|
183
|
-
run_at_dt = datetime.now(timezone.utc) + timedelta(seconds=schedule.delay_seconds)
|
|
184
|
-
|
|
185
|
-
evt = Event(
|
|
186
|
-
agent_id=agent_id,
|
|
187
|
-
node_name=target_name,
|
|
188
|
-
payload=schedule.payload.model_dump() if hasattr(schedule.payload, "model_dump") else schedule.payload,
|
|
189
|
-
run_at=run_at_dt.isoformat(),
|
|
190
|
-
idempotency_key=schedule.idempotency_key
|
|
191
|
-
)
|
|
192
|
-
|
|
193
|
-
scheduled = await self.db.schedule_event(evt)
|
|
194
|
-
next_node = target_name if scheduled else "TERMINAL"
|
|
195
|
-
|
|
196
|
-
if "#" in agent_id:
|
|
197
|
-
parent_id = agent_id.split("#")[0]
|
|
198
|
-
await self.db.save_sub_state(agent_id, parent_id, next_node, ctx.model_dump(), None)
|
|
199
|
-
if not scheduled:
|
|
200
|
-
join_node_name = await self.db.register_sub_task_completion(agent_id)
|
|
201
|
-
if join_node_name:
|
|
202
|
-
await self._trigger_recompose(parent_id, join_node_name)
|
|
203
|
-
else:
|
|
204
|
-
await self.db.save_state(agent_id, next_node, ctx.model_dump(), None)
|
|
205
|
-
|
|
206
|
-
async def _trigger_recompose(self, parent_id: str, join_node_name: str):
|
|
207
|
-
await self.db.load_and_lock_agent(parent_id)
|
|
208
|
-
merged_ctx_dict = await self.db.recompose_parent(parent_id)
|
|
209
|
-
join_event = Event(agent_id=parent_id, node_name=join_node_name)
|
|
210
|
-
await self.db.save_state(parent_id, join_node_name, merged_ctx_dict, join_event)
|
|
211
|
-
await self.bus.publish(join_event)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|