commandnet 0.2.2__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {commandnet-0.2.2 → commandnet-0.4.0}/PKG-INFO +4 -3
- {commandnet-0.2.2 → commandnet-0.4.0}/README.md +3 -2
- {commandnet-0.2.2 → commandnet-0.4.0}/commandnet/__init__.py +3 -1
- {commandnet-0.2.2 → commandnet-0.4.0}/commandnet/core/graph.py +1 -0
- {commandnet-0.2.2 → commandnet-0.4.0}/commandnet/core/models.py +2 -1
- commandnet-0.4.0/commandnet/core/node.py +65 -0
- commandnet-0.4.0/commandnet/engine/runtime.py +402 -0
- {commandnet-0.2.2 → commandnet-0.4.0}/commandnet/interfaces/event_bus.py +1 -0
- commandnet-0.4.0/commandnet/interfaces/observer.py +6 -0
- commandnet-0.4.0/commandnet/interfaces/persistence.py +66 -0
- {commandnet-0.2.2 → commandnet-0.4.0}/pyproject.toml +2 -1
- commandnet-0.2.2/commandnet/core/node.py +0 -36
- commandnet-0.2.2/commandnet/engine/runtime.py +0 -211
- commandnet-0.2.2/commandnet/interfaces/observer.py +0 -5
- commandnet-0.2.2/commandnet/interfaces/persistence.py +0 -41
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: commandnet
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: A lightweight, Pydantic-powered, distributed event-driven state machine and typed node graph runtime.
|
|
5
5
|
Author: Christopher Vaz
|
|
6
6
|
Author-email: christophervaz160@gmail.com
|
|
@@ -53,7 +53,7 @@ pip install commandnet
|
|
|
53
53
|
## 🛠️ Quick Start
|
|
54
54
|
|
|
55
55
|
### 1. Define Your Context
|
|
56
|
-
The "Context" is the persistent state of your
|
|
56
|
+
The "Context" is the persistent state of your subject, defined using Pydantic.
|
|
57
57
|
|
|
58
58
|
```python
|
|
59
59
|
from pydantic import BaseModel
|
|
@@ -139,7 +139,7 @@ engine = Engine(persistence=db, event_bus=bus)
|
|
|
139
139
|
await engine.start_worker()
|
|
140
140
|
|
|
141
141
|
# Trigger an execution
|
|
142
|
-
await engine.
|
|
142
|
+
await engine.trigger_subject("subject-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
|
|
143
143
|
```
|
|
144
144
|
|
|
145
145
|
---
|
|
@@ -172,3 +172,4 @@ print(dag) # {'CheckRisk': ['ProcessPayment'], 'ProcessPayment': []}
|
|
|
172
172
|
|
|
173
173
|
MIT
|
|
174
174
|
|
|
175
|
+
|
|
@@ -30,7 +30,7 @@ pip install commandnet
|
|
|
30
30
|
## 🛠️ Quick Start
|
|
31
31
|
|
|
32
32
|
### 1. Define Your Context
|
|
33
|
-
The "Context" is the persistent state of your
|
|
33
|
+
The "Context" is the persistent state of your subject, defined using Pydantic.
|
|
34
34
|
|
|
35
35
|
```python
|
|
36
36
|
from pydantic import BaseModel
|
|
@@ -116,7 +116,7 @@ engine = Engine(persistence=db, event_bus=bus)
|
|
|
116
116
|
await engine.start_worker()
|
|
117
117
|
|
|
118
118
|
# Trigger an execution
|
|
119
|
-
await engine.
|
|
119
|
+
await engine.trigger_subject("subject-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
|
|
120
120
|
```
|
|
121
121
|
|
|
122
122
|
---
|
|
@@ -148,3 +148,4 @@ print(dag) # {'CheckRisk': ['ProcessPayment'], 'ProcessPayment': []}
|
|
|
148
148
|
## 📄 License
|
|
149
149
|
|
|
150
150
|
MIT
|
|
151
|
+
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from .core.models import Event
|
|
2
|
-
from .core.node import Node, Parallel, ParallelTask, Schedule
|
|
2
|
+
from .core.node import Node, Parallel, ParallelTask, Schedule, Wait, Call, Interrupt
|
|
3
3
|
from .core.graph import GraphAnalyzer
|
|
4
4
|
from .interfaces.persistence import Persistence
|
|
5
5
|
from .interfaces.event_bus import EventBus
|
|
@@ -9,4 +9,6 @@ from .engine.runtime import Engine
|
|
|
9
9
|
# Public API of the package. `Wait` is imported from .core.node alongside the
# other Target types, so it is exported here too (star-imports need it listed).
__all__ = [
    "Event", "Node", "Parallel", "ParallelTask", "GraphAnalyzer",
    "Persistence", "EventBus", "Observer", "Engine", "Schedule",
    "Call", "Interrupt", "Wait",
]
|
|
14
|
+
|
|
@@ -8,10 +8,11 @@ def utcnow_iso() -> str:
|
|
|
8
8
|
|
|
9
9
|
class Event(BaseModel):
|
|
10
10
|
event_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
|
11
|
-
|
|
11
|
+
subject_id: str
|
|
12
12
|
node_name: str
|
|
13
13
|
payload: Optional[Dict[str, Any]] = None
|
|
14
14
|
|
|
15
15
|
timestamp: str = Field(default_factory=utcnow_iso)
|
|
16
16
|
run_at: str = Field(default_factory=utcnow_iso)
|
|
17
17
|
idempotency_key: Optional[str] = None
|
|
18
|
+
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from typing import Generic, TypeVar, Type, Optional, Union, List, Any
|
|
2
|
+
from pydantic import BaseModel, ConfigDict
|
|
3
|
+
|
|
4
|
+
C = TypeVar('C', bound=BaseModel) # Context
P = TypeVar('P', bound=BaseModel) # Payload

# The Recursive Type Definition
# Union of everything a node may return to steer execution (see Node.run).
# Forward references are strings because those classes are defined below.
Target = Union[Type['Node'], 'Parallel', 'Schedule', 'Wait', 'Call', 'Interrupt', None]

# One branch of a Parallel fan-out.
class ParallelTask(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # What this branch executes (any Target).
    action: Target
    # Name of the attribute/key of the parent context that becomes this branch's
    # context; the runtime also appends it to the parent id ("parent#path") to
    # form the branch's sub-subject id.
    sub_context_path: str
    payload: Optional[Any] = None

# Fan-out request: run every branch, optionally continuing at join_node.
class Parallel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    branches: List[Union[ParallelTask, 'Wait']]
    # When None, the runtime does not create a task group and treats the parent
    # as terminal right after fanning out.
    join_node: Optional[Type['Node']] = None

# Delayed execution request.
class Schedule(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Typed as Target, but the runtime raises TypeError unless this is a Node class.
    action: Target
    delay_seconds: int
    payload: Optional[Any] = None
    idempotency_key: Optional[str] = None

# Park the subject until signal `signal_id` is released, then apply resume_action.
class Wait(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    signal_id: str
    resume_action: Target
    # Optional context attribute to park as a sub-subject instead of the whole subject.
    sub_context_path: Optional[str] = None

class Call(BaseModel):
    """The 'Await' type: Deduplicates execution based on an idempotency key."""
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Node run once per idempotency_key (as virtual subject "call#<key>").
    node_cls: Type['Node']
    idempotency_key: str
    payload: Optional[Any] = None
    # After the called node finishes, the caller resumes at this node
    resume_action: Target = None

class Interrupt(BaseModel):
    """Signals the engine to stop execution."""
    subject_id: str
    hard: bool = True # True = kill task, False = flag for soft exit

# Alias used as the return annotation of Node.run.
TransitionResult = Target
|
|
49
|
+
|
|
50
|
+
class Node(Generic[C, P]):
    """Base class for graph nodes.

    ``C`` is the node's context model and ``P`` its optional payload model.
    Subclasses override ``run`` (and optionally ``on_signal``) and return a
    Target describing where execution goes next.
    """

    @classmethod
    def get_node_name(cls) -> str:
        """Return the name this node is registered and routed under (its class name)."""
        name = cls.__name__
        return name

    async def run(self, ctx: C, payload: Optional[P] = None) -> TransitionResult:
        """Execute the node's logic and return the next Target.

        The result may be a Node class, ``Parallel``, ``Schedule``, ``Wait``,
        ``Call``, ``Interrupt``, or ``None``. The base implementation returns
        ``None``.
        """
        return None

    async def on_signal(self, ctx: C, signal_id: str, payload: Any) -> Target:
        """
        CALLBACK: invoked when a 'Wait' state is resolved, or when an external
        signal is addressed to this node specifically.

        Returns the Target to continue with; the default returns ``None``.
        """
        return None
|
|
65
|
+
|
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from datetime import datetime, timezone, timedelta
|
|
4
|
+
from typing import Optional, Type, Iterable, Dict, Any, Union, List
|
|
5
|
+
|
|
6
|
+
from ..core.models import Event
|
|
7
|
+
from ..core.node import (
|
|
8
|
+
Node,
|
|
9
|
+
Parallel,
|
|
10
|
+
Schedule,
|
|
11
|
+
Wait,
|
|
12
|
+
Target,
|
|
13
|
+
ParallelTask,
|
|
14
|
+
Call,
|
|
15
|
+
Interrupt,
|
|
16
|
+
)
|
|
17
|
+
from ..core.graph import GraphAnalyzer
|
|
18
|
+
from ..interfaces.persistence import Persistence
|
|
19
|
+
from ..interfaces.event_bus import EventBus
|
|
20
|
+
from ..interfaces.observer import Observer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Engine:
    """Event-driven runtime that executes Node graphs per subject.

    Events delivered by the event bus are handled by ``process_event``. Each
    event runs one node under the subject's persistence lock. The node's
    returned Target is resolved by ``_apply_target`` into persisted state and
    follow-up events.
    """

    def __init__(
        self,
        persistence: Persistence,
        event_bus: EventBus,
        nodes: Iterable[Type[Node]],
        observer: Optional[Observer] = None,
    ):
        self.db = persistence
        self.bus = event_bus
        self.observer = observer or Observer()
        self.logger = logging.getLogger("CommandNet")
        self._scheduler_task: Optional[asyncio.Task] = None
        # subject_id -> task running that subject's node on this worker
        self._active_tasks: Dict[str, asyncio.Task] = {}
        self._registry: Dict[str, Type[Node]] = {}
        for node_cls in nodes:
            name = node_cls.get_node_name()
            existing = self._registry.get(name)
            # Node names are the routing key in persisted state and events; two
            # distinct classes sharing a name would silently shadow each other.
            if existing is not None and existing is not node_cls:
                raise RuntimeError(f"Node name collision in Engine: '{name}'")
            self._registry[name] = node_cls

    def _dump_ctx(self, ctx: Any) -> dict:
        """Return ``ctx.model_dump()`` for pydantic models, else ``ctx`` unchanged."""
        return ctx.model_dump() if hasattr(ctx, "model_dump") else ctx

    def _get_path(self, obj: Any, path: str) -> Any:
        """Read ``path`` from a dict (``None`` if missing) or as an attribute."""
        if isinstance(obj, dict):
            return obj.get(path)
        return getattr(obj, path)

    # --- CORE WORKER LOGIC ---

    async def process_event(self, event: Event):
        """Bus subscriber: handle control messages or run the event's node."""
        # 1. Internal Control (Cross-worker Hard Cancel)
        if event.node_name == "__CONTROL__":
            action = (event.payload or {}).get("action")
            running = self._active_tasks.get(event.subject_id)
            if action == "HARD_CANCEL" and running is not None:
                running.cancel()
            return

        # 2. Check cancellation status
        if await self.db.is_cancelled(event.subject_id):
            return

        # 3. Task Tracking for Local Hard Cancel
        task = asyncio.create_task(self._run_node_logic(event))
        self._active_tasks[event.subject_id] = task
        try:
            await task
        except asyncio.CancelledError:
            self.logger.warning(f"Task {event.subject_id} hard-cancelled.")
        finally:
            # Only remove our own entry: a newer event for the same subject may
            # have replaced it while this task was running.
            if self._active_tasks.get(event.subject_id) is task:
                del self._active_tasks[event.subject_id]

    async def _run_node_logic(self, event: Event):
        """Lock the subject, run the event's node and apply its result.

        Re-raises any node/engine exception after reporting it to the observer.
        The subject is unlocked in all cases where the lock was obtained.
        """
        subject_id = event.subject_id
        # We allow "AWAITING_CALL" because a subject wakes up from that state
        # when a 'Call' resolves.
        node_name, ctx_dict = await self.db.lock_and_load(subject_id)
        if not node_name or (
            node_name != event.node_name and node_name != "AWAITING_CALL"
        ):
            # Stale or unknown event; release the lock only if we obtained state.
            if node_name:
                await self.db.unlock_subject(subject_id)
            return

        try:
            node_cls = self._registry.get(event.node_name)
            if not node_cls:
                raise RuntimeError(f"Node '{event.node_name}' not found.")

            ctx_type = GraphAnalyzer.get_context_type(node_cls)
            payload_type = GraphAnalyzer.get_payload_type(node_cls)

            ctx = (
                ctx_type.model_validate(ctx_dict)
                if hasattr(ctx_type, "model_validate")
                else ctx_dict
            )
            # Compare against None (not truthiness) so an empty-dict payload is
            # still validated into the node's payload model.
            payload = (
                payload_type.model_validate(event.payload)
                if event.payload is not None and hasattr(payload_type, "model_validate")
                else event.payload
            )

            # Support Soft-Cancel checking inside Node.run
            if hasattr(ctx, "is_cancelled"):
                ctx.is_cancelled = await self.db.is_cancelled(subject_id)

            loop = asyncio.get_running_loop()
            start_t = loop.time()
            result = await node_cls().run(ctx, payload)

            await self._apply_target(
                subject_id,
                ctx,
                result,
                (loop.time() - start_t) * 1000,
            )
        except Exception as e:
            await self.observer.on_error(subject_id, event.node_name, e)
            raise
        finally:
            await self.db.unlock_subject(subject_id)

    # --- RECURSIVE TARGET RESOLVER ---

    async def _apply_target(
        self,
        subject_id: str,
        context: Any,
        target: Target,
        duration: float = 0.0,
        payload: Any = None,
    ):
        """Persist and dispatch the effect of ``target`` for ``subject_id``.

        ``duration`` (ms) is reported to the observer. ``payload`` is forwarded
        to the next node for plain Node transitions.
        """
        # 1. Interrupt (Cancellation)
        if isinstance(target, Interrupt):
            await self.cancel_subject(target.subject_id, target.hard)
            return

        # 2. Call (Idempotent Await)
        if isinstance(target, Call):
            is_leader = await self.db.add_call_waiter(
                target.idempotency_key,
                subject_id,
                target.resume_action,
                self._dump_ctx(context),
            )
            await self.observer.on_transition(
                subject_id, "RUN", f"CALLING:{target.node_cls.get_node_name()}", duration
            )
            await self.db.save_state(
                subject_id, "AWAITING_CALL", self._dump_ctx(context), None
            )
            # Only the first waiter for a key starts the shared virtual subject.
            if is_leader:
                await self.trigger_subject(
                    f"call#{target.idempotency_key}",
                    target.node_cls,
                    context,
                    target.payload,
                )
            return

        # 3. Terminal State
        if target is None:
            await self.observer.on_transition(subject_id, "RUN", "TERMINAL", duration)
            ctx_dict = self._dump_ctx(context)

            # Resolve Callers if this was a shared Virtual subject
            if subject_id.startswith("call#"):
                # Everything after the first '#' is the key, even if it contains '#'.
                key = subject_id.partition("#")[2]
                waiters = await self.db.resolve_call_group(key)
                for w in waiters:
                    # Resume waiters using the final context of the shared node as payload
                    await self._apply_target(
                        w["subject_id"],
                        w["context"],
                        w["resume_target"],
                        payload=ctx_dict,
                    )

            # Parallel Sub-task Completion Logic
            if "#" in subject_id and not subject_id.startswith("call#"):
                parent_id = subject_id.split("#")[0]
                await self.db.save_sub_state(
                    subject_id, parent_id, "TERMINAL", ctx_dict, None
                )
                join_node_name = await self.db.register_sub_task_completion(subject_id)
                if join_node_name:
                    await self._trigger_recompose(parent_id, join_node_name)
            else:
                await self.db.save_state(subject_id, "TERMINAL", ctx_dict, None)
            return

        # 4. Standard Node Transition
        if isinstance(target, type) and issubclass(target, Node):
            node_name = target.get_node_name()
            await self.observer.on_transition(subject_id, "RUN", node_name, duration)

            p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload
            evt = Event(subject_id=subject_id, node_name=node_name, payload=p_load)
            ctx_dict = self._dump_ctx(context)

            if "#" in subject_id:
                await self.db.save_sub_state(
                    subject_id, subject_id.split("#")[0], node_name, ctx_dict, evt
                )
            else:
                await self.db.save_state(subject_id, node_name, ctx_dict, evt)

            await self.bus.publish(evt)
            return

        # 5. Wait / Signal Parking
        if isinstance(target, Wait):
            actual_id = subject_id
            actual_ctx = context
            if target.sub_context_path and "#" not in subject_id:
                actual_id = f"{subject_id}#{target.sub_context_path}"
                actual_ctx = self._get_path(context, target.sub_context_path)

            await self.observer.on_transition(
                actual_id, "RUN", f"WAIT:{target.signal_id}", duration
            )
            await self.db.park_subject(
                actual_id,
                target.signal_id,
                target.resume_action,
                self._dump_ctx(actual_ctx),
            )
            return

        # 6. Parallel Fan-out
        if isinstance(target, Parallel):
            join_name = target.join_node.get_node_name() if target.join_node else "FORK"
            await self.observer.on_transition(
                subject_id, "RUN", f"PARALLEL:{join_name}", duration
            )

            if target.join_node:
                await self.db.create_task_group(
                    subject_id, join_name, len(target.branches)
                )

            for branch in target.branches:
                # Normalize branch to ParallelTask without stripping Wait wrappers
                if isinstance(branch, ParallelTask):
                    task = branch
                else:
                    # If it's a Wait or a Node class, wrap it but keep the action intact
                    path = (
                        branch.sub_context_path
                        if isinstance(branch, Wait) and branch.sub_context_path
                        else "default"
                    )
                    task = ParallelTask(action=branch, sub_context_path=path)

                sub_ctx = self._get_path(context, task.sub_context_path)
                await self._apply_target(
                    f"{subject_id}#{task.sub_context_path}",
                    sub_ctx,
                    task.action,
                    payload=task.payload,
                )

            if target.join_node:
                await self.db.save_state(
                    subject_id, "WAITING_FOR_JOIN", self._dump_ctx(context), None
                )
            else:
                await self._apply_target(subject_id, context, None)
            return

        # 7. Schedule (Delayed Execution)
        if isinstance(target, Schedule):
            if not (
                isinstance(target.action, type) and issubclass(target.action, Node)
            ):
                raise TypeError("Schedule.action must be a Node class.")

            target_node_name = target.action.get_node_name()
            await self.observer.on_transition(
                subject_id, "RUN", f"SCHEDULED:{target_node_name}", duration
            )

            run_at_dt = datetime.now(timezone.utc) + timedelta(
                seconds=target.delay_seconds
            )
            p_load = (
                target.payload.model_dump()
                if hasattr(target.payload, "model_dump")
                else target.payload
            )

            scheduled_evt = Event(
                subject_id=subject_id,
                node_name=target_node_name,
                payload=p_load,
                run_at=run_at_dt.isoformat(),
                idempotency_key=target.idempotency_key,
            )

            await self.db.schedule_event(scheduled_evt)
            await self.db.save_state(
                subject_id, target_node_name, self._dump_ctx(context), None
            )
            return

    # --- EXTERNAL CONTROL & SIGNALS ---

    async def signal_node(self, subject_id: str, signal_id: str, payload: Any = None):
        """Resumes subject and triggers its specific on_signal instance method.

        Raises:
            RuntimeError: if the subject's stored node name is not registered.
        """
        node_name, ctx_dict = await self.db.lock_and_load(subject_id)
        if not node_name:
            return
        try:
            node_cls = self._registry.get(node_name)
            # Stored state may be a non-node marker (e.g. "AWAITING_CALL"); fail
            # with a clear error instead of calling None.
            if node_cls is None:
                raise RuntimeError(f"Node '{node_name}' not found.")
            node_inst = node_cls()
            ctx_type = GraphAnalyzer.get_context_type(node_cls)
            ctx = (
                ctx_type.model_validate(ctx_dict)
                if hasattr(ctx_type, "model_validate")
                else ctx_dict
            )

            result = await node_inst.on_signal(ctx, signal_id, payload)
            await self._apply_target(subject_id, ctx, result)
        finally:
            await self.db.unlock_subject(subject_id)

    async def release_signal(self, signal_id: str, payload: Any = None):
        """Standard mass-resume for parked subjects."""
        waiters = await self.db.get_and_clear_waiters(signal_id)
        for waiter in waiters:
            subject_id = waiter["subject_id"]
            await self.db.lock_and_load(subject_id)
            try:
                await self._apply_target(
                    subject_id=subject_id,
                    context=waiter["context"],
                    target=waiter["next_target"],
                    payload=payload,
                )
            finally:
                await self.db.unlock_subject(subject_id)

    async def cancel_subject(self, subject_id: str, hard: bool = True):
        """Flag ``subject_id`` as cancelled; when ``hard``, also kill running tasks.

        A hard cancel cancels the local task (if any) and broadcasts a
        ``__CONTROL__`` HARD_CANCEL event so other workers do the same.
        """
        await self.db.set_cancel_flag(subject_id, hard)
        if hard:
            if subject_id in self._active_tasks:
                self._active_tasks[subject_id].cancel()
            await self.bus.publish(
                Event(
                    subject_id=subject_id,
                    node_name="__CONTROL__",
                    payload={"action": "HARD_CANCEL"},
                )
            )

    # --- LIFECYCLE & HELPERS ---

    async def _trigger_recompose(self, parent_id: str, join_node_name: str):
        """Merge finished branch contexts into the parent and move it to the join node."""
        await self.db.lock_and_load(parent_id)
        # Release the parent lock afterwards, as signal_node/release_signal do;
        # otherwise the join node's event could never lock the parent.
        try:
            merged_ctx_dict = await self.db.recompose_parent(parent_id)
            await self._apply_target(
                parent_id, merged_ctx_dict, self._registry[join_node_name]
            )
        finally:
            await self.db.unlock_subject(parent_id)

    async def start_worker(self, poll_interval: float = 1.0):
        """Subscribe to the bus and start the scheduled-event poller."""
        await self.bus.subscribe(self.process_event)
        self._scheduler_task = asyncio.create_task(self._scheduler_loop(poll_interval))

    async def _scheduler_loop(self, poll_interval: float):
        """Publish due scheduled events every ``poll_interval`` seconds until cancelled."""
        while True:
            try:
                due = await self.db.pop_due_events()
                for evt in due:
                    await self.bus.publish(evt)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep polling, but record the traceback for diagnosis.
                self.logger.exception(f"Scheduler: {e}")
            await asyncio.sleep(poll_interval)

    async def trigger_subject(
        self,
        subject_id: str,
        start_node: Type[Node],
        initial_context: Any,
        payload: Any = None,
    ):
        """Persist the initial state of ``subject_id`` and publish its first event."""
        node_name = start_node.get_node_name()
        p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload
        evt = Event(subject_id=subject_id, node_name=node_name, payload=p_load)
        await self.db.save_state(
            subject_id, node_name, self._dump_ctx(initial_context), evt
        )
        await self.bus.publish(evt)

    async def stop(self):
        """Stop the scheduler (waiting for it to exit) and cancel running node tasks."""
        if self._scheduler_task:
            self._scheduler_task.cancel()
            try:
                await self._scheduler_task
            except asyncio.CancelledError:
                pass
            self._scheduler_task = None
        # Snapshot: finishing tasks remove themselves from the dict.
        for task in list(self._active_tasks.values()):
            task.cancel()

    def validate_graph(self, start_node: Type[Node]):
        """Validate the graph reachable from ``start_node`` against this registry."""
        return GraphAnalyzer.validate(start_node, self._registry)
|
|
402
|
+
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Dict, Optional, Tuple, List, Any
|
|
3
|
+
from ..core.models import Event
|
|
4
|
+
|
|
5
|
+
class Persistence(ABC):
    """Storage backend contract used by the Engine.

    Every method is a coroutine. The backend owns:
    - subject state and per-subject locking;
    - parallel task groups;
    - scheduled events;
    - signal waiters;
    - cancellation flags;
    - deduplicated Call groups.
    """

    @abstractmethod
    async def lock_and_load(self, subject_id: str) -> Tuple[Optional[str], Optional[Dict]]:
        """Lock ``subject_id`` and return its ``(node_name, context)``.

        The Engine treats a falsy node name as "nothing to run" (presumably a
        missing or already-locked subject; implementation-defined).
        """

    @abstractmethod
    async def unlock_subject(self, subject_id: str):
        """Release the lock taken by ``lock_and_load``."""

    @abstractmethod
    async def save_state(self, subject_id: str, node_name: str, context: Dict, event: Optional[Event]):
        """Persist the subject's current node name and context.

        ``event`` is the event the Engine is about to publish for this state,
        or ``None``.
        """

    @abstractmethod
    async def save_sub_state(self, sub_id: str, parent_id: str, node_name: str, ctx: dict, evt: Optional[Event]):
        """Persist the state of parallel sub-subject ``sub_id`` belonging to ``parent_id``."""

    @abstractmethod
    async def create_task_group(self, parent_id: str, join_node_name: str, task_count: int):
        """Register a fan-out of ``task_count`` branches that joins at ``join_node_name``."""

    @abstractmethod
    async def register_sub_task_completion(self, sub_id: str) -> Optional[str]:
        """Record that ``sub_id`` finished.

        Returns the join node name when the group is ready to join; otherwise
        returns a falsy value.
        """

    @abstractmethod
    async def recompose_parent(self, parent_id: str) -> dict:
        """Return the parent context rebuilt from its sub-subjects' contexts."""

    @abstractmethod
    async def schedule_event(self, event: Event) -> bool:
        """Store ``event`` for delivery at its ``run_at`` time.

        The meaning of the returned bool is implementation-defined; presumably
        it is False for an idempotency duplicate. TODO confirm.
        """

    @abstractmethod
    async def pop_due_events(self) -> List[Event]:
        """Return the scheduled events that are now due (the Engine publishes each one)."""

    @abstractmethod
    async def park_subject(self, subject_id: str, signal_id: str, next_target: Any, context: Dict):
        """Record that ``subject_id`` waits for ``signal_id`` and resumes at ``next_target``."""

    @abstractmethod
    async def get_and_clear_waiters(self, signal_id: str) -> List[Dict]:
        """Remove and return the waiters parked on ``signal_id``.

        The Engine reads the ``"subject_id"``, ``"context"`` and
        ``"next_target"`` keys from each dict.
        """

    @abstractmethod
    async def set_cancel_flag(self, subject_id: str, hard: bool):
        """Mark ``subject_id`` as cancelled (``hard`` distinguishes kill vs. soft exit)."""

    @abstractmethod
    async def is_cancelled(self, subject_id: str) -> bool:
        """Return whether ``subject_id`` has been flagged as cancelled."""

    @abstractmethod
    async def add_call_waiter(self, key: str, subject_id: str, resume_target: Any, context: dict) -> bool:
        """Register ``subject_id`` as waiting on the Call group ``key``.

        Returns True if this is the first waiter for this key (the 'leader').
        """

    @abstractmethod
    async def resolve_call_group(self, key: str) -> List[Dict[str, Any]]:
        """Close the Call group ``key`` and return its waiters.

        The Engine reads the ``"subject_id"``, ``"context"`` and
        ``"resume_target"`` keys from each dict.
        """
|
|
66
|
+
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "commandnet"
|
|
3
|
-
version = "0.
|
|
3
|
+
version = "0.4.0"
|
|
4
4
|
description = "A lightweight, Pydantic-powered, distributed event-driven state machine and typed node graph runtime."
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Christopher Vaz", email = "christophervaz160@gmail.com" }
|
|
@@ -41,3 +41,4 @@ build-backend = "poetry.core.masonry.api"
|
|
|
41
41
|
asyncio_mode = "strict"
|
|
42
42
|
asyncio_default_test_loop_scope = "function"
|
|
43
43
|
asyncio_default_fixture_loop_scope = "function"
|
|
44
|
+
|
|
@@ -1,36 +0,0 @@
|
|
|
1
|
-
# Legacy (0.2.2) contents of commandnet/core/node.py; this whole file is
# removed in this diff and replaced by the 0.4.0 version.
import inspect
from typing import Generic, TypeVar, Type, Optional, Union, List, Dict, Any
from pydantic import BaseModel, ConfigDict

C = TypeVar('C', bound=BaseModel) # Context
P = TypeVar('P', bound=BaseModel) # Payload (Optional)

# One fan-out branch: run node_cls on the parent context's `sub_context_path` attribute.
class ParallelTask(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    node_cls: Type['Node']
    payload: Optional[Any] = None
    sub_context_path: str

# Fan-out request; in this version a join node is mandatory.
class Parallel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    branches: List[ParallelTask]
    join_node: Type['Node']

# Run node_cls after delay_seconds.
class Schedule(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    node_cls: Type['Node']
    delay_seconds: int
    payload: Optional[Any] = None
    idempotency_key: Optional[str] = None

# Minor typing improvement for readability
TransitionResult = Union[Type['Node'], Parallel, Schedule, None]

class Node(Generic[C, P]):
    @classmethod
    def get_node_name(cls) -> str:
        # Routing name of the node: its class name.
        return cls.__name__

    async def run(self, ctx: C, payload: Optional[P] = None) -> TransitionResult:
        """Executes node logic. Returns the Next Node, a Parallel request, a Schedule request, or None."""
        pass
|
|
@@ -1,211 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import logging
|
|
3
|
-
from datetime import datetime, timezone, timedelta
|
|
4
|
-
from typing import Optional, Type, Iterable, Dict
|
|
5
|
-
from pydantic import BaseModel
|
|
6
|
-
|
|
7
|
-
from ..core.models import Event
|
|
8
|
-
from ..core.node import Node, Parallel, Schedule
|
|
9
|
-
from ..core.graph import GraphAnalyzer
|
|
10
|
-
from ..interfaces.persistence import Persistence
|
|
11
|
-
from ..interfaces.event_bus import EventBus
|
|
12
|
-
from ..interfaces.observer import Observer
|
|
13
|
-
|
|
14
|
-
class Engine:
|
|
15
|
-
def __init__(
|
|
16
|
-
self,
|
|
17
|
-
persistence: Persistence,
|
|
18
|
-
event_bus: EventBus,
|
|
19
|
-
nodes: Iterable[Type[Node]],
|
|
20
|
-
observer: Optional[Observer] = None
|
|
21
|
-
):
|
|
22
|
-
self.db = persistence
|
|
23
|
-
self.bus = event_bus
|
|
24
|
-
self.observer = observer or Observer()
|
|
25
|
-
self.logger = logging.getLogger("CommandNet")
|
|
26
|
-
self._scheduler_task: Optional[asyncio.Task] = None
|
|
27
|
-
|
|
28
|
-
# Build Engine-scoped registry
|
|
29
|
-
self._registry: Dict[str, Type[Node]] = {}
|
|
30
|
-
for node_cls in nodes:
|
|
31
|
-
name = node_cls.get_node_name()
|
|
32
|
-
if name in self._registry and self._registry[name] is not node_cls:
|
|
33
|
-
raise RuntimeError(f"Node name collision in Engine: '{name}'")
|
|
34
|
-
self._registry[name] = node_cls
|
|
35
|
-
|
|
36
|
-
def validate_graph(self, start_node: Type[Node]):
|
|
37
|
-
"""Helper to validate the graph using this engine's registry."""
|
|
38
|
-
return GraphAnalyzer.validate(start_node, self._registry)
|
|
39
|
-
|
|
40
|
-
async def start_worker(self, poll_interval: float = 1.0):
|
|
41
|
-
await self.bus.subscribe(self.process_event)
|
|
42
|
-
self._scheduler_task = asyncio.create_task(self._scheduler_loop(poll_interval))
|
|
43
|
-
self.logger.info("Worker and Scheduler started.")
|
|
44
|
-
|
|
45
|
-
async def _scheduler_loop(self, poll_interval: float):
|
|
46
|
-
while True:
|
|
47
|
-
try:
|
|
48
|
-
due_events = await self.db.pop_due_events()
|
|
49
|
-
for evt in due_events:
|
|
50
|
-
await self.bus.publish(evt)
|
|
51
|
-
except asyncio.CancelledError:
|
|
52
|
-
break
|
|
53
|
-
except Exception as e:
|
|
54
|
-
self.logger.error(f"Scheduler Error: {e}")
|
|
55
|
-
await asyncio.sleep(poll_interval)
|
|
56
|
-
|
|
57
|
-
async def stop(self):
|
|
58
|
-
if self._scheduler_task:
|
|
59
|
-
self._scheduler_task.cancel()
|
|
60
|
-
try:
|
|
61
|
-
await self._scheduler_task
|
|
62
|
-
except asyncio.CancelledError:
|
|
63
|
-
pass
|
|
64
|
-
|
|
65
|
-
async def trigger_agent(self, agent_id: str, start_node: Type[Node], initial_context: BaseModel, payload: Optional[BaseModel] = None):
|
|
66
|
-
node_name = start_node.get_node_name()
|
|
67
|
-
if node_name not in self._registry:
|
|
68
|
-
raise ValueError(f"Node '{node_name}' is not registered with this Engine.")
|
|
69
|
-
|
|
70
|
-
start_event = Event(
|
|
71
|
-
agent_id=agent_id,
|
|
72
|
-
node_name=node_name,
|
|
73
|
-
payload=payload.model_dump() if hasattr(payload, "model_dump") else payload
|
|
74
|
-
)
|
|
75
|
-
await self.db.save_state(agent_id, node_name, initial_context.model_dump(), start_event)
|
|
76
|
-
await self.bus.publish(start_event)
|
|
77
|
-
|
|
78
|
-
async def process_event(self, event: Event):
|
|
79
|
-
start_time = asyncio.get_event_loop().time()
|
|
80
|
-
|
|
81
|
-
current_node_name, ctx_dict = await self.db.load_and_lock_agent(event.agent_id)
|
|
82
|
-
if not current_node_name:
|
|
83
|
-
return
|
|
84
|
-
|
|
85
|
-
locked = True
|
|
86
|
-
try:
|
|
87
|
-
if current_node_name != event.node_name:
|
|
88
|
-
# Same logic as before for sub-state/state
|
|
89
|
-
if "#" in event.agent_id:
|
|
90
|
-
parent_id = event.agent_id.split("#")[0]
|
|
91
|
-
await self.db.save_sub_state(event.agent_id, parent_id, current_node_name, ctx_dict, None)
|
|
92
|
-
else:
|
|
93
|
-
await self.db.save_state(event.agent_id, current_node_name, ctx_dict, None)
|
|
94
|
-
locked = False
|
|
95
|
-
return
|
|
96
|
-
|
|
97
|
-
node_cls = self._registry.get(current_node_name)
|
|
98
|
-
if not node_cls:
|
|
99
|
-
raise RuntimeError(f"Node '{current_node_name}' not found in this Engine's registry.")
|
|
100
|
-
|
|
101
|
-
ctx_type = GraphAnalyzer.get_context_type(node_cls)
|
|
102
|
-
payload_type = GraphAnalyzer.get_payload_type(node_cls)
|
|
103
|
-
|
|
104
|
-
ctx = ctx_type.model_validate(ctx_dict) if issubclass(ctx_type, BaseModel) else ctx_dict
|
|
105
|
-
|
|
106
|
-
payload = None
|
|
107
|
-
if event.payload is not None:
|
|
108
|
-
if isinstance(payload_type, type) and issubclass(payload_type, BaseModel):
|
|
109
|
-
payload = payload_type.model_validate(event.payload)
|
|
110
|
-
else:
|
|
111
|
-
payload = event.payload
|
|
112
|
-
|
|
113
|
-
node_instance = node_cls()
|
|
114
|
-
result = await node_instance.run(ctx, payload)
|
|
115
|
-
duration = (asyncio.get_event_loop().time() - start_time) * 1000
|
|
116
|
-
|
|
117
|
-
if isinstance(result, Parallel):
|
|
118
|
-
await self._handle_parallel_start(event.agent_id, ctx, result, duration)
|
|
119
|
-
elif isinstance(result, Schedule):
|
|
120
|
-
await self._handle_schedule(event.agent_id, current_node_name, ctx, result, duration)
|
|
121
|
-
elif result:
|
|
122
|
-
await self._handle_transition(event.agent_id, current_node_name, result, ctx, duration)
|
|
123
|
-
else:
|
|
124
|
-
await self._handle_terminal(event.agent_id, current_node_name, ctx, duration)
|
|
125
|
-
|
|
126
|
-
locked = False
|
|
127
|
-
|
|
128
|
-
except Exception as e:
|
|
129
|
-
await self.observer.on_error(event.agent_id, current_node_name, e)
|
|
130
|
-
raise
|
|
131
|
-
finally:
|
|
132
|
-
if locked:
|
|
133
|
-
await self.db.unlock_agent(event.agent_id)
|
|
134
|
-
|
|
135
|
-
# Remaining _handle_* methods use self._registry and self.validate_graph
|
|
136
|
-
async def _handle_transition(self, agent_id: str, from_node: str, next_node_cls: Type[Node], ctx: BaseModel, duration: float):
|
|
137
|
-
next_name = next_node_cls.get_node_name()
|
|
138
|
-
if next_name not in self._registry:
|
|
139
|
-
raise RuntimeError(f"Transition target '{next_name}' not in registry.")
|
|
140
|
-
|
|
141
|
-
await self.observer.on_transition(agent_id, from_node, next_name, duration)
|
|
142
|
-
next_event = Event(agent_id=agent_id, node_name=next_name)
|
|
143
|
-
await self.db.save_state(agent_id, next_name, ctx.model_dump(), next_event)
|
|
144
|
-
await self.bus.publish(next_event)
|
|
145
|
-
|
|
146
|
-
async def _handle_parallel_start(self, parent_id: str, parent_ctx: BaseModel, parallel: Parallel, duration: float):
|
|
147
|
-
join_name = parallel.join_node.get_node_name()
|
|
148
|
-
await self.observer.on_transition(parent_id, "ParallelStart", join_name, duration)
|
|
149
|
-
await self.db.create_task_group(parent_id=parent_id, join_node_name=join_name, task_count=len(parallel.branches))
|
|
150
|
-
|
|
151
|
-
for task in parallel.branches:
|
|
152
|
-
if not hasattr(parent_ctx, task.sub_context_path):
|
|
153
|
-
raise RuntimeError(f"Context missing path: '{task.sub_context_path}'.")
|
|
154
|
-
|
|
155
|
-
sub_ctx = getattr(parent_ctx, task.sub_context_path)
|
|
156
|
-
sub_id = f"{parent_id}#{task.sub_context_path}"
|
|
157
|
-
node_name = task.node_cls.get_node_name()
|
|
158
|
-
|
|
159
|
-
evt = Event(
|
|
160
|
-
agent_id=sub_id,
|
|
161
|
-
node_name=node_name,
|
|
162
|
-
payload=task.payload.model_dump() if hasattr(task.payload, "model_dump") else task.payload
|
|
163
|
-
)
|
|
164
|
-
await self.db.save_sub_state(sub_id, parent_id, node_name, sub_ctx.model_dump(), evt)
|
|
165
|
-
await self.bus.publish(evt)
|
|
166
|
-
|
|
167
|
-
await self.db.save_state(parent_id, "WAITING_FOR_JOIN", parent_ctx.model_dump(), None)
|
|
168
|
-
|
|
169
|
-
async def _handle_terminal(self, agent_id: str, from_node: str, ctx: BaseModel, duration: float):
|
|
170
|
-
await self.observer.on_transition(agent_id, from_node, "TERMINAL", duration)
|
|
171
|
-
if "#" in agent_id:
|
|
172
|
-
parent_id = agent_id.split("#")[0]
|
|
173
|
-
await self.db.save_sub_state(agent_id, parent_id, "TERMINAL", ctx.model_dump(), None)
|
|
174
|
-
join_node_name = await self.db.register_sub_task_completion(agent_id)
|
|
175
|
-
if join_node_name:
|
|
176
|
-
await self._trigger_recompose(parent_id, join_node_name)
|
|
177
|
-
else:
|
|
178
|
-
await self.db.save_state(agent_id, "TERMINAL", ctx.model_dump(), None)
|
|
179
|
-
|
|
180
|
-
async def _handle_schedule(self, agent_id: str, from_node: str, ctx: BaseModel, schedule: Schedule, duration: float):
|
|
181
|
-
target_name = schedule.node_cls.get_node_name()
|
|
182
|
-
await self.observer.on_transition(agent_id, from_node, f"SCHEDULED:{target_name}", duration)
|
|
183
|
-
run_at_dt = datetime.now(timezone.utc) + timedelta(seconds=schedule.delay_seconds)
|
|
184
|
-
|
|
185
|
-
evt = Event(
|
|
186
|
-
agent_id=agent_id,
|
|
187
|
-
node_name=target_name,
|
|
188
|
-
payload=schedule.payload.model_dump() if hasattr(schedule.payload, "model_dump") else schedule.payload,
|
|
189
|
-
run_at=run_at_dt.isoformat(),
|
|
190
|
-
idempotency_key=schedule.idempotency_key
|
|
191
|
-
)
|
|
192
|
-
|
|
193
|
-
scheduled = await self.db.schedule_event(evt)
|
|
194
|
-
next_node = target_name if scheduled else "TERMINAL"
|
|
195
|
-
|
|
196
|
-
if "#" in agent_id:
|
|
197
|
-
parent_id = agent_id.split("#")[0]
|
|
198
|
-
await self.db.save_sub_state(agent_id, parent_id, next_node, ctx.model_dump(), None)
|
|
199
|
-
if not scheduled:
|
|
200
|
-
join_node_name = await self.db.register_sub_task_completion(agent_id)
|
|
201
|
-
if join_node_name:
|
|
202
|
-
await self._trigger_recompose(parent_id, join_node_name)
|
|
203
|
-
else:
|
|
204
|
-
await self.db.save_state(agent_id, next_node, ctx.model_dump(), None)
|
|
205
|
-
|
|
206
|
-
async def _trigger_recompose(self, parent_id: str, join_node_name: str):
|
|
207
|
-
await self.db.load_and_lock_agent(parent_id)
|
|
208
|
-
merged_ctx_dict = await self.db.recompose_parent(parent_id)
|
|
209
|
-
join_event = Event(agent_id=parent_id, node_name=join_node_name)
|
|
210
|
-
await self.db.save_state(parent_id, join_node_name, merged_ctx_dict, join_event)
|
|
211
|
-
await self.bus.publish(join_event)
|
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import Dict, Optional, Tuple, List
|
|
3
|
-
from ..core.models import Event
|
|
4
|
-
|
|
5
|
-
class Persistence(ABC):
|
|
6
|
-
@abstractmethod
|
|
7
|
-
async def load_and_lock_agent(self, agent_id: str) -> Tuple[Optional[str], Optional[Dict]]:
|
|
8
|
-
pass
|
|
9
|
-
|
|
10
|
-
@abstractmethod
|
|
11
|
-
async def unlock_agent(self, agent_id: str):
|
|
12
|
-
"""Releases the row lock. Called automatically by the engine if an exception occurs."""
|
|
13
|
-
pass
|
|
14
|
-
|
|
15
|
-
@abstractmethod
|
|
16
|
-
async def save_state(self, agent_id: str, node_name: str, context: Dict, event: Optional[Event]):
|
|
17
|
-
pass
|
|
18
|
-
|
|
19
|
-
@abstractmethod
|
|
20
|
-
async def save_sub_state(self, sub_id: str, parent_id: str, node_name: str, ctx: dict, evt: Optional[Event]):
|
|
21
|
-
pass
|
|
22
|
-
|
|
23
|
-
@abstractmethod
|
|
24
|
-
async def create_task_group(self, parent_id: str, join_node_name: str, task_count: int):
|
|
25
|
-
pass
|
|
26
|
-
|
|
27
|
-
@abstractmethod
|
|
28
|
-
async def register_sub_task_completion(self, sub_id: str) -> Optional[str]:
|
|
29
|
-
pass
|
|
30
|
-
|
|
31
|
-
@abstractmethod
|
|
32
|
-
async def recompose_parent(self, parent_id: str) -> dict:
|
|
33
|
-
pass
|
|
34
|
-
|
|
35
|
-
@abstractmethod
|
|
36
|
-
async def schedule_event(self, event: Event) -> bool:
|
|
37
|
-
pass
|
|
38
|
-
|
|
39
|
-
@abstractmethod
|
|
40
|
-
async def pop_due_events(self) -> List[Event]:
|
|
41
|
-
pass
|