commandnet 0.3.0__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {commandnet-0.3.0 → commandnet-0.4.1}/PKG-INFO +4 -3
- {commandnet-0.3.0 → commandnet-0.4.1}/README.md +3 -2
- {commandnet-0.3.0 → commandnet-0.4.1}/commandnet/__init__.py +3 -1
- {commandnet-0.3.0 → commandnet-0.4.1}/commandnet/core/graph.py +1 -0
- {commandnet-0.3.0 → commandnet-0.4.1}/commandnet/core/models.py +3 -1
- {commandnet-0.3.0 → commandnet-0.4.1}/commandnet/core/node.py +23 -1
- commandnet-0.4.1/commandnet/engine/runtime.py +407 -0
- {commandnet-0.3.0 → commandnet-0.4.1}/commandnet/interfaces/event_bus.py +1 -0
- commandnet-0.4.1/commandnet/interfaces/observer.py +6 -0
- {commandnet-0.3.0 → commandnet-0.4.1}/commandnet/interfaces/persistence.py +22 -5
- {commandnet-0.3.0 → commandnet-0.4.1}/pyproject.toml +2 -1
- commandnet-0.3.0/commandnet/engine/runtime.py +0 -292
- commandnet-0.3.0/commandnet/interfaces/observer.py +0 -5
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: commandnet
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.1
|
|
4
4
|
Summary: A lightweight, Pydantic-powered, distributed event-driven state machine and typed node graph runtime.
|
|
5
5
|
Author: Christopher Vaz
|
|
6
6
|
Author-email: christophervaz160@gmail.com
|
|
@@ -53,7 +53,7 @@ pip install commandnet
|
|
|
53
53
|
## 🛠️ Quick Start
|
|
54
54
|
|
|
55
55
|
### 1. Define Your Context
|
|
56
|
-
The "Context" is the persistent state of your
|
|
56
|
+
The "Context" is the persistent state of your subject, defined using Pydantic.
|
|
57
57
|
|
|
58
58
|
```python
|
|
59
59
|
from pydantic import BaseModel
|
|
@@ -139,7 +139,7 @@ engine = Engine(persistence=db, event_bus=bus)
|
|
|
139
139
|
await engine.start_worker()
|
|
140
140
|
|
|
141
141
|
# Trigger an execution
|
|
142
|
-
await engine.
|
|
142
|
+
await engine.trigger_subject("subject-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
|
|
143
143
|
```
|
|
144
144
|
|
|
145
145
|
---
|
|
@@ -172,3 +172,4 @@ print(dag) # {'CheckRisk': ['ProcessPayment'], 'ProcessPayment': []}
|
|
|
172
172
|
|
|
173
173
|
MIT
|
|
174
174
|
|
|
175
|
+
|
|
@@ -30,7 +30,7 @@ pip install commandnet
|
|
|
30
30
|
## 🛠️ Quick Start
|
|
31
31
|
|
|
32
32
|
### 1. Define Your Context
|
|
33
|
-
The "Context" is the persistent state of your
|
|
33
|
+
The "Context" is the persistent state of your subject, defined using Pydantic.
|
|
34
34
|
|
|
35
35
|
```python
|
|
36
36
|
from pydantic import BaseModel
|
|
@@ -116,7 +116,7 @@ engine = Engine(persistence=db, event_bus=bus)
|
|
|
116
116
|
await engine.start_worker()
|
|
117
117
|
|
|
118
118
|
# Trigger an execution
|
|
119
|
-
await engine.
|
|
119
|
+
await engine.trigger_subject("subject-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
|
|
120
120
|
```
|
|
121
121
|
|
|
122
122
|
---
|
|
@@ -148,3 +148,4 @@ print(dag) # {'CheckRisk': ['ProcessPayment'], 'ProcessPayment': []}
|
|
|
148
148
|
## 📄 License
|
|
149
149
|
|
|
150
150
|
MIT
|
|
151
|
+
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from .core.models import Event
|
|
2
|
-
from .core.node import Node, Parallel, ParallelTask, Schedule, Wait
|
|
2
|
+
from .core.node import Node, Parallel, ParallelTask, Schedule, Wait, Call, Interrupt
|
|
3
3
|
from .core.graph import GraphAnalyzer
|
|
4
4
|
from .interfaces.persistence import Persistence
|
|
5
5
|
from .interfaces.event_bus import EventBus
|
|
@@ -9,4 +9,6 @@ from .engine.runtime import Engine
|
|
|
9
9
|
# Public names exported by ``from commandnet import *``.  Every directive
# type imported above (including ``Wait``) is listed so star-imports expose it.
__all__ = [
    "Event", "Node", "Parallel", "ParallelTask", "GraphAnalyzer",
    "Persistence", "EventBus", "Observer", "Engine", "Schedule",
    "Wait", "Call", "Interrupt",
]
|
|
14
|
+
|
|
@@ -8,10 +8,12 @@ def utcnow_iso() -> str:
|
|
|
8
8
|
|
|
9
9
|
class Event(BaseModel):
    """Message published on the event bus to make a subject run a node."""
    # Unique id of this event; a fresh UUID4 string per instance by default.
    event_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Subject (workflow instance) the event targets; may be a "parent#path"
    # sub-task id or a "call#key" virtual subject.
    subject_id: str
    # Registered name of the node the subject should run.
    node_name: str
    # Serialized payload handed to the node (validated against its payload type).
    payload: Optional[Dict[str, Any]] = None
    # String headers carried with the event; the runtime copies a context's
    # ``trace_headers`` here when present.
    headers: Dict[str, str] = Field(default_factory=dict)

    # Creation time from ``utcnow_iso`` (presumably an ISO-8601 UTC string — see helper).
    timestamp: str = Field(default_factory=utcnow_iso)
    # Earliest time the event should run; defaults to "now", the runtime sets a
    # future value for ``Schedule`` targets.
    run_at: str = Field(default_factory=utcnow_iso)
    # Optional key, set from ``Schedule.idempotency_key``; presumably used by
    # persistence to de-duplicate scheduled events — TODO confirm.
    idempotency_key: Optional[str] = None
|
|
19
|
+
|
|
@@ -5,7 +5,7 @@ C = TypeVar('C', bound=BaseModel) # Context
|
|
|
5
5
|
# Type variable for a node's payload model (second type argument of Node).
P = TypeVar('P', bound=BaseModel) # Payload
|
|
6
6
|
|
|
7
7
|
# The Recursive Type Definition
|
|
8
|
-
Target = Union[Type['Node'], 'Parallel', 'Schedule', 'Wait', None]
|
|
8
|
+
# What a node returns to steer the engine: a Node class (run it next), one of
# the Parallel / Schedule / Wait / Call / Interrupt directives, or None, which
# the engine treats as the terminal state for the subject.
Target = Union[Type['Node'], 'Parallel', 'Schedule', 'Wait', 'Call', 'Interrupt', None]
|
|
9
9
|
|
|
10
10
|
class ParallelTask(BaseModel):
|
|
11
11
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
@@ -31,6 +31,20 @@ class Wait(BaseModel):
|
|
|
31
31
|
resume_action: Target
|
|
32
32
|
sub_context_path: Optional[str] = None
|
|
33
33
|
|
|
34
|
+
class Call(BaseModel):
    """Directive to await another node, de-duplicated by an idempotency key.

    The engine registers the returning subject as a waiter for
    ``idempotency_key`` and parks it in the ``AWAITING_CALL`` state.  Only the
    first waiter for a key (the "leader") starts ``node_cls`` on the virtual
    subject ``call#<idempotency_key>``, using the caller's context and
    ``payload``.  When that virtual subject reaches its terminal state, every
    waiter is resumed at its own ``resume_action`` with the called subject's
    final context (as a dict) passed as the payload.
    """
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Node class executed once per key on behalf of all waiters.
    node_cls: Type['Node']
    # Key shared by every caller that wants the same piece of work.
    idempotency_key: str
    # Payload given to ``node_cls`` when the leader starts it.
    payload: Optional[Any] = None
    # After the called node finishes, the caller resumes at this node
    resume_action: Target = None
|
|
42
|
+
|
|
43
|
+
class Interrupt(BaseModel):
    """Directive that cancels a subject (which may differ from the caller).

    The engine forwards it to ``Engine.cancel_subject(subject_id, hard)``, which
    stores a cancel flag through persistence.  With ``hard=True`` it also
    cancels the subject's running task on this worker and publishes a
    ``HARD_CANCEL`` control event so other workers do the same.
    """
    # Subject to cancel; taken from this model, not from the returning subject.
    subject_id: str
    hard: bool = True  # True = kill task, False = flag for soft exit
|
|
47
|
+
|
|
34
48
|
# Alias used as the return annotation of ``Node.run``.
TransitionResult = Target
|
|
35
49
|
|
|
36
50
|
class Node(Generic[C, P]):
|
|
@@ -41,3 +55,11 @@ class Node(Generic[C, P]):
|
|
|
41
55
|
async def run(self, ctx: C, payload: Optional[P] = None) -> TransitionResult:
    """Run this node's logic and decide where the subject goes next.

    Subclasses override this.  ``ctx`` is the subject's context validated
    against the node's context type; ``payload`` is the event payload,
    validated against the node's payload type when one is present.

    Returns:
        The next ``Target``: a Node class to run next, a ``Parallel``,
        ``Schedule``, ``Wait``, ``Call`` or ``Interrupt`` directive, or
        ``None`` to finish the subject.  This base implementation returns
        ``None``.
    """
    return None
|
|
58
|
+
|
|
59
|
+
async def on_signal(self, ctx: C, signal_id: str, payload: Any) -> Target:
    """Hook for signals sent straight to a subject via ``Engine.signal_node``.

    The engine instantiates the class of the subject's current node, passes
    the re-validated context together with ``signal_id`` and ``payload``, and
    applies whatever ``Target`` comes back.  Returning ``None`` is applied
    like any other target, i.e. it moves the subject to its terminal state.

    Resolving a ``Wait`` through ``Engine.release_signal`` does NOT call this
    hook; that path resumes at the ``Wait``'s stored ``resume_action``.

    The default implementation ignores the signal and returns ``None``.
    """
    pass
|
|
65
|
+
|
|
@@ -0,0 +1,407 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from datetime import datetime, timezone, timedelta
|
|
4
|
+
from typing import Optional, Type, Iterable, Dict, Any, Union, List
|
|
5
|
+
|
|
6
|
+
from ..core.models import Event
|
|
7
|
+
from ..core.node import (
|
|
8
|
+
Node,
|
|
9
|
+
Parallel,
|
|
10
|
+
Schedule,
|
|
11
|
+
Wait,
|
|
12
|
+
Target,
|
|
13
|
+
ParallelTask,
|
|
14
|
+
Call,
|
|
15
|
+
Interrupt,
|
|
16
|
+
)
|
|
17
|
+
from ..core.graph import GraphAnalyzer
|
|
18
|
+
from ..interfaces.persistence import Persistence
|
|
19
|
+
from ..interfaces.event_bus import EventBus
|
|
20
|
+
from ..interfaces.observer import Observer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Engine:
    """Distributed runtime that drives subjects through a graph of ``Node`` classes.

    Each subject has a persisted ``(node_name, context)`` pair.  Events on the
    bus name the node a subject should run; a worker locks the subject, runs
    the node, and turns the returned ``Target`` (Node class, ``Parallel``,
    ``Schedule``, ``Wait``, ``Call``, ``Interrupt`` or ``None``) into new
    persisted state and/or new events.

    Subject-id conventions used here:

    * ``"<parent>#<path>"`` -- a parallel sub-task (or sub-context ``Wait``).
    * ``"call#<key>"`` -- the virtual subject executing a shared ``Call``.
    """

    # Prefix of the virtual subjects created for ``Call`` targets.
    _CALL_PREFIX = "call#"

    def __init__(
        self,
        persistence: Persistence,
        event_bus: EventBus,
        nodes: Iterable[Type[Node]],
        observer: Optional[Observer] = None,
    ):
        self.db = persistence
        self.bus = event_bus
        self.observer = observer or Observer()
        self.logger = logging.getLogger("CommandNet")
        self._scheduler_task: Optional[asyncio.Task] = None
        # Node executions running on this worker, keyed by subject id, so a
        # hard cancel can interrupt them.
        self._active_tasks: Dict[str, asyncio.Task] = {}
        self._registry: Dict[str, Type[Node]] = {n.get_node_name(): n for n in nodes}

    # --- HELPERS ---

    def _dump_ctx(self, ctx: Any) -> dict:
        """Dump a model via ``model_dump``; return any other value unchanged."""
        return ctx.model_dump() if hasattr(ctx, "model_dump") else ctx

    def _get_path(self, obj: Any, path: str) -> Any:
        """Read ``path`` from a dict (``None`` if missing) or as an attribute."""
        if isinstance(obj, dict):
            return obj.get(path)
        return getattr(obj, path)

    def _is_call_subject(self, subject_id: str) -> bool:
        """True for the virtual subjects that execute a shared ``Call``."""
        return subject_id.startswith(self._CALL_PREFIX)

    def _is_sub_task(self, subject_id: str) -> bool:
        """True for ``parent#path`` sub-task ids (``call#`` subjects excluded)."""
        return "#" in subject_id and not self._is_call_subject(subject_id)

    # --- CORE WORKER LOGIC ---

    async def process_event(self, event: Event):
        """Bus handler: run the node named by ``event`` for its subject.

        ``__CONTROL__`` events carry engine-internal commands (currently only
        ``HARD_CANCEL``).  Events for cancelled subjects are dropped.  The node
        runs in its own task, registered in ``_active_tasks`` so
        ``cancel_subject`` can interrupt it.
        """
        # 1. Internal control (cross-worker hard cancel).
        if event.node_name == "__CONTROL__":
            action = (event.payload or {}).get("action")
            if action == "HARD_CANCEL" and event.subject_id in self._active_tasks:
                self._active_tasks[event.subject_id].cancel()
            return

        # 2. Drop work for cancelled subjects.
        if await self.db.is_cancelled(event.subject_id):
            return

        # 3. Track the execution so a local hard cancel can reach it.
        task = asyncio.create_task(self._run_node_logic(event))
        self._active_tasks[event.subject_id] = task
        try:
            await task
        except asyncio.CancelledError:
            self.logger.warning(f"Task {event.subject_id} hard-cancelled.")
        finally:
            # Unregister only our own task: a newer execution for the same
            # subject may have replaced the entry while we were awaiting.
            if self._active_tasks.get(event.subject_id) is task:
                del self._active_tasks[event.subject_id]

    async def _run_node_logic(self, event: Event):
        """Lock the subject, run the event's node, and apply the returned target.

        Stale events (persisted node differs from ``event.node_name``) are
        ignored unless the subject is in ``AWAITING_CALL``.  Exceptions are
        reported to the observer and re-raised; the subject is always unlocked.
        """
        subject_id = event.subject_id
        # "AWAITING_CALL" is accepted because a subject wakes up from that
        # state when a 'Call' resolves.
        node_name, ctx_dict = await self.db.lock_and_load(subject_id)
        if not node_name or (
            node_name != event.node_name and node_name != "AWAITING_CALL"
        ):
            if node_name:
                await self.db.unlock_subject(subject_id)
            return

        try:
            node_cls = self._registry.get(event.node_name)
            if not node_cls:
                raise RuntimeError(f"Node '{event.node_name}' not found.")

            ctx_type = GraphAnalyzer.get_context_type(node_cls)
            payload_type = GraphAnalyzer.get_payload_type(node_cls)

            ctx = (
                ctx_type.model_validate(ctx_dict)
                if hasattr(ctx_type, "model_validate")
                else ctx_dict
            )
            payload = (
                payload_type.model_validate(event.payload)
                if (event.payload and hasattr(payload_type, "model_validate"))
                else event.payload
            )

            # Support soft-cancel checks inside Node.run for contexts that
            # declare an ``is_cancelled`` attribute.
            if hasattr(ctx, "is_cancelled"):
                ctx.is_cancelled = await self.db.is_cancelled(subject_id)

            loop = asyncio.get_running_loop()
            start_t = loop.time()
            result = await node_cls().run(ctx, payload)

            await self._apply_target(
                subject_id,
                ctx,
                result,
                (loop.time() - start_t) * 1000,
            )
        except Exception as e:
            await self.observer.on_error(subject_id, event.node_name, e)
            raise
        finally:
            await self.db.unlock_subject(subject_id)

    # --- RECURSIVE TARGET RESOLVER ---

    async def _apply_target(
        self,
        subject_id: str,
        context: Any,
        target: Target,
        duration: float = 0.0,
        payload: Any = None,
    ):
        """Translate ``target`` into persisted state and/or published events.

        Args:
            subject_id: Subject, ``parent#path`` sub-task or ``call#key``
                virtual subject the target applies to.
            context: The subject's context, as a model or a plain dict.
            target: The value a node returned (see ``Target``).
            duration: Node execution time in milliseconds (for the observer).
            payload: Payload for the next node when ``target`` is a Node class.

        Raises:
            TypeError: if a ``Schedule`` action is not a Node class.

        Targets of any other type fall through without effect.
        """
        # 1. Interrupt (cancellation) -- may target a different subject.
        if isinstance(target, Interrupt):
            await self.cancel_subject(target.subject_id, target.hard)
            return

        # 2. Call (idempotent await): register as waiter, park, and let only
        #    the first waiter for the key start the shared virtual subject.
        if isinstance(target, Call):
            is_leader = await self.db.add_call_waiter(
                target.idempotency_key,
                subject_id,
                target.resume_action,
                self._dump_ctx(context),
            )
            await self.observer.on_transition(
                subject_id, "RUN", f"CALLING:{target.node_cls.get_node_name()}", duration
            )
            await self.db.save_state(
                subject_id, "AWAITING_CALL", self._dump_ctx(context), None
            )
            if is_leader:
                await self.trigger_subject(
                    f"{self._CALL_PREFIX}{target.idempotency_key}",
                    target.node_cls,
                    context,
                    target.payload,
                )
            return

        # 3. Terminal state
        if target is None:
            await self.observer.on_transition(subject_id, "RUN", "TERMINAL", duration)
            ctx_dict = self._dump_ctx(context)

            # Resume callers if this was a shared virtual subject.  Strip only
            # the prefix so keys that themselves contain '#' stay intact.
            if self._is_call_subject(subject_id):
                key = subject_id[len(self._CALL_PREFIX):]
                waiters = await self.db.resolve_call_group(key)
                for w in waiters:
                    # The shared node's final context becomes the waiter's payload.
                    await self._apply_target(
                        w["subject_id"],
                        w["context"],
                        w["resume_target"],
                        payload=ctx_dict,
                    )

            # Parallel sub-task completion: may trigger the parent's join node.
            if self._is_sub_task(subject_id):
                parent_id = subject_id.split("#")[0]
                await self.db.save_sub_state(
                    subject_id, parent_id, "TERMINAL", ctx_dict, None
                )
                join_node_name = await self.db.register_sub_task_completion(subject_id)
                if join_node_name:
                    await self._trigger_recompose(parent_id, join_node_name)
            else:
                await self.db.save_state(subject_id, "TERMINAL", ctx_dict, None)
            return

        # 4. Standard node transition
        if isinstance(target, type) and issubclass(target, Node):
            node_name = target.get_node_name()
            await self.observer.on_transition(subject_id, "RUN", node_name, duration)

            p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload

            headers = {}
            if hasattr(context, "trace_headers") and context.trace_headers:
                headers = context.trace_headers

            evt = Event(subject_id=subject_id, node_name=node_name, payload=p_load, headers=headers)
            ctx_dict = self._dump_ctx(context)

            # ``call#`` virtual subjects are top-level state, not sub-tasks of
            # a parent named "call".
            if self._is_sub_task(subject_id):
                await self.db.save_sub_state(
                    subject_id, subject_id.split("#")[0], node_name, ctx_dict, evt
                )
            else:
                await self.db.save_state(subject_id, node_name, ctx_dict, evt)

            await self.bus.publish(evt)
            return

        # 5. Wait / signal parking
        if isinstance(target, Wait):
            actual_id = subject_id
            actual_ctx = context
            if target.sub_context_path and "#" not in subject_id:
                actual_id = f"{subject_id}#{target.sub_context_path}"
                actual_ctx = self._get_path(context, target.sub_context_path)

            await self.observer.on_transition(
                actual_id, "RUN", f"WAIT:{target.signal_id}", duration
            )
            await self.db.park_subject(
                actual_id,
                target.signal_id,
                target.resume_action,
                self._dump_ctx(actual_ctx),
            )
            return

        # 6. Parallel fan-out
        if isinstance(target, Parallel):
            join_name = target.join_node.get_node_name() if target.join_node else "FORK"
            await self.observer.on_transition(
                subject_id, "RUN", f"PARALLEL:{join_name}", duration
            )

            if target.join_node:
                await self.db.create_task_group(
                    subject_id, join_name, len(target.branches)
                )

            for branch in target.branches:
                # Normalize branch to ParallelTask without stripping Wait wrappers
                if isinstance(branch, ParallelTask):
                    task = branch
                else:
                    # NOTE(review): every such branch without a Wait
                    # sub_context_path shares the id "<subject>#default".
                    path = (
                        branch.sub_context_path
                        if isinstance(branch, Wait) and branch.sub_context_path
                        else "default"
                    )
                    task = ParallelTask(action=branch, sub_context_path=path)

                sub_ctx = self._get_path(context, task.sub_context_path)
                await self._apply_target(
                    f"{subject_id}#{task.sub_context_path}",
                    sub_ctx,
                    task.action,
                    payload=task.payload,
                )

            if target.join_node:
                await self.db.save_state(
                    subject_id, "WAITING_FOR_JOIN", self._dump_ctx(context), None
                )
            else:
                await self._apply_target(subject_id, context, None)
            return

        # 7. Schedule (delayed execution)
        if isinstance(target, Schedule):
            if not (
                isinstance(target.action, type) and issubclass(target.action, Node)
            ):
                raise TypeError("Schedule.action must be a Node class.")

            target_node_name = target.action.get_node_name()
            await self.observer.on_transition(
                subject_id, "RUN", f"SCHEDULED:{target_node_name}", duration
            )

            run_at_dt = datetime.now(timezone.utc) + timedelta(
                seconds=target.delay_seconds
            )
            p_load = (
                target.payload.model_dump()
                if hasattr(target.payload, "model_dump")
                else target.payload
            )

            scheduled_evt = Event(
                subject_id=subject_id,
                node_name=target_node_name,
                payload=p_load,
                run_at=run_at_dt.isoformat(),
                idempotency_key=target.idempotency_key,
            )

            await self.db.schedule_event(scheduled_evt)
            await self.db.save_state(
                subject_id, target_node_name, self._dump_ctx(context), None
            )
            return

    # --- EXTERNAL CONTROL & SIGNALS ---

    async def signal_node(self, subject_id: str, signal_id: str, payload: Any = None):
        """Deliver a signal to the subject's current node via ``Node.on_signal``.

        The returned target is applied like a normal transition.  Does nothing
        if the subject does not exist.

        Raises:
            RuntimeError: if the subject's current state is not a registered
                node (e.g. ``TERMINAL`` or ``AWAITING_CALL``).
        """
        node_name, ctx_dict = await self.db.lock_and_load(subject_id)
        if not node_name:
            return
        try:
            node_cls = self._registry.get(node_name)
            if not node_cls:
                raise RuntimeError(f"Node '{node_name}' not found.")
            node_inst = node_cls()
            ctx_type = GraphAnalyzer.get_context_type(node_cls)
            ctx = (
                ctx_type.model_validate(ctx_dict)
                if hasattr(ctx_type, "model_validate")
                else ctx_dict
            )

            result = await node_inst.on_signal(ctx, signal_id, payload)
            await self._apply_target(subject_id, ctx, result)
        finally:
            await self.db.unlock_subject(subject_id)

    async def release_signal(self, signal_id: str, payload: Any = None):
        """Standard mass-resume for parked subjects."""
        waiters = await self.db.get_and_clear_waiters(signal_id)
        for waiter in waiters:
            subject_id = waiter["subject_id"]
            await self.db.lock_and_load(subject_id)
            try:
                await self._apply_target(
                    subject_id=subject_id,
                    context=waiter["context"],
                    target=waiter["next_target"],
                    payload=payload,
                )
            finally:
                await self.db.unlock_subject(subject_id)

    async def cancel_subject(self, subject_id: str, hard: bool = True):
        """Flag ``subject_id`` as cancelled; a hard cancel also stops running work.

        A hard cancel cancels the local task (if any) and broadcasts a
        ``HARD_CANCEL`` control event for the other workers.
        """
        await self.db.set_cancel_flag(subject_id, hard)
        if hard:
            if subject_id in self._active_tasks:
                self._active_tasks[subject_id].cancel()
            await self.bus.publish(
                Event(
                    subject_id=subject_id,
                    node_name="__CONTROL__",
                    payload={"action": "HARD_CANCEL"},
                )
            )

    # --- LIFECYCLE & HELPERS ---

    async def _trigger_recompose(self, parent_id: str, join_node_name: str):
        """Merge finished sub-tasks into the parent and advance it to the join node.

        The parent lock taken here is always released afterwards, mirroring
        ``release_signal``; otherwise the parent would stay locked forever.
        """
        await self.db.lock_and_load(parent_id)
        try:
            merged_ctx_dict = await self.db.recompose_parent(parent_id)
            await self._apply_target(
                parent_id, merged_ctx_dict, self._registry[join_node_name]
            )
        finally:
            await self.db.unlock_subject(parent_id)

    async def start_worker(self, poll_interval: float = 1.0):
        """Subscribe to the bus and start polling for due scheduled events."""
        await self.bus.subscribe(self.process_event)
        self._scheduler_task = asyncio.create_task(self._scheduler_loop(poll_interval))

    async def _scheduler_loop(self, poll_interval: float):
        """Publish due scheduled events every ``poll_interval`` seconds."""
        while True:
            try:
                due = await self.db.pop_due_events()
                for evt in due:
                    await self.bus.publish(evt)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep polling; a transient persistence/bus error must not
                # kill the scheduler.
                self.logger.error(f"Scheduler: {e}")
            await asyncio.sleep(poll_interval)

    async def trigger_subject(
        self,
        subject_id: str,
        start_node: Type[Node],
        initial_context: Any,
        payload: Any = None,
    ):
        """Create/overwrite ``subject_id`` at ``start_node`` and publish its first event."""
        node_name = start_node.get_node_name()
        p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload
        evt = Event(subject_id=subject_id, node_name=node_name, payload=p_load)
        await self.db.save_state(
            subject_id, node_name, self._dump_ctx(initial_context), evt
        )
        await self.bus.publish(evt)

    async def stop(self):
        """Stop the scheduler (waiting for it to finish) and cancel running nodes."""
        if self._scheduler_task:
            self._scheduler_task.cancel()
            try:
                await self._scheduler_task
            except asyncio.CancelledError:
                pass
        # Snapshot: tasks remove themselves from the dict as they finish.
        for task in list(self._active_tasks.values()):
            task.cancel()

    def validate_graph(self, start_node: Type[Node]):
        """Validate the graph reachable from ``start_node`` against the registry."""
        return GraphAnalyzer.validate(start_node, self._registry)
|
|
407
|
+
|
|
@@ -4,16 +4,15 @@ from ..core.models import Event
|
|
|
4
4
|
|
|
5
5
|
class Persistence(ABC):
|
|
6
6
|
@abstractmethod
async def lock_and_load(self, subject_id: str) -> Tuple[Optional[str], Optional[Dict]]:
    """Lock ``subject_id`` and return its persisted ``(node_name, context)``.

    A falsy ``node_name`` means "nothing to run": the engine's event handler
    and ``signal_node`` then return without calling ``unlock_subject``, so
    presumably no lock should be left held in that case — confirm per backend.
    The locking mechanism is implementation-defined (e.g. a database row lock).
    """
    pass

@abstractmethod
async def unlock_subject(self, subject_id: str):
    """Release the lock taken by ``lock_and_load``.

    The engine calls this from ``finally`` blocks, i.e. after both successful
    and failed node executions (including when an exception occurs).
    """
    pass

@abstractmethod
async def save_state(self, subject_id: str, node_name: str, context: Dict, event: Optional[Event]):
    """Persist ``node_name`` and ``context`` for a top-level subject.

    ``event`` is the event being published for the new node, or ``None`` when
    no event is published alongside the state (e.g. ``"TERMINAL"``,
    ``"AWAITING_CALL"``, ``"WAITING_FOR_JOIN"``, or a node reached later via a
    scheduled event).
    """
    pass
|
|
18
17
|
|
|
19
18
|
@abstractmethod
|
|
@@ -41,9 +40,27 @@ class Persistence(ABC):
|
|
|
41
40
|
pass
|
|
42
41
|
|
|
43
42
|
@abstractmethod
async def park_subject(self, subject_id: str, signal_id: str, next_target: Any, context: Dict):
    """Park ``subject_id`` until ``signal_id`` is released.

    ``next_target`` is the ``Wait.resume_action``; it and ``context`` must be
    returned later by ``get_and_clear_waiters``.
    """
    pass

@abstractmethod
async def get_and_clear_waiters(self, signal_id: str) -> List[Dict]:
    """Return and remove the subjects parked on ``signal_id``.

    Each dict must provide ``"subject_id"``, ``"context"`` and
    ``"next_target"`` (the keys read by ``Engine.release_signal``).
    """
    pass

@abstractmethod
async def set_cancel_flag(self, subject_id: str, hard: bool):
    """Record that ``subject_id`` was cancelled; ``hard`` is the cancel mode."""
    pass

@abstractmethod
async def is_cancelled(self, subject_id: str) -> bool:
    """Return whether a cancel flag is set for ``subject_id``.

    The engine drops incoming events when this is true and copies the value
    onto contexts that have an ``is_cancelled`` attribute.
    """
    pass

@abstractmethod
async def add_call_waiter(self, key: str, subject_id: str, resume_target: Any, context: dict) -> bool:
    """Returns True if this is the first waiter for this key (the 'leader').

    Registers ``subject_id`` as waiting on the shared call ``key``, storing
    ``resume_target`` and ``context`` for ``resolve_call_group``.  Only the
    leader makes the engine start the ``call#<key>`` virtual subject.
    """
    pass

@abstractmethod
async def resolve_call_group(self, key: str) -> List[Dict[str, Any]]:
    """Return the waiters registered for ``key`` once the shared call finished.

    Each dict must provide ``"subject_id"``, ``"context"`` and
    ``"resume_target"`` (note: not ``"next_target"`` as for signal waiters).
    Presumably the group is cleared as well — TODO confirm per backend.
    """
    pass
|
|
66
|
+
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "commandnet"
|
|
3
|
-
version = "0.
|
|
3
|
+
version = "0.4.1"
|
|
4
4
|
description = "A lightweight, Pydantic-powered, distributed event-driven state machine and typed node graph runtime."
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Christopher Vaz", email = "christophervaz160@gmail.com" }
|
|
@@ -41,3 +41,4 @@ build-backend = "poetry.core.masonry.api"
|
|
|
41
41
|
asyncio_mode = "strict"
|
|
42
42
|
asyncio_default_test_loop_scope = "function"
|
|
43
43
|
asyncio_default_fixture_loop_scope = "function"
|
|
44
|
+
|
|
@@ -1,292 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import logging
|
|
3
|
-
from datetime import datetime, timezone, timedelta
|
|
4
|
-
from typing import Optional, Type, Iterable, Dict, Any, Union
|
|
5
|
-
from pydantic import BaseModel
|
|
6
|
-
|
|
7
|
-
from ..core.models import Event
|
|
8
|
-
from ..core.node import Node, Parallel, Schedule, Wait, Target, ParallelTask
|
|
9
|
-
from ..core.graph import GraphAnalyzer
|
|
10
|
-
from ..interfaces.persistence import Persistence
|
|
11
|
-
from ..interfaces.event_bus import EventBus
|
|
12
|
-
from ..interfaces.observer import Observer
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class Engine:
    """Event-driven runtime that moves agents through a graph of ``Node`` classes.

    Each agent has a persisted ``(node_name, context)`` pair; events on the bus
    name the node an agent should run next.  Agent ids of the form
    ``"<parent>#<path>"`` denote parallel sub-tasks of ``<parent>``.
    """

    def __init__(
        self,
        persistence: Persistence,
        event_bus: EventBus,
        nodes: Iterable[Type[Node]],
        observer: Optional[Observer] = None,
    ):
        self.db = persistence
        self.bus = event_bus
        self.observer = observer or Observer()
        self.logger = logging.getLogger("CommandNet")
        self._scheduler_task: Optional[asyncio.Task] = None

        self._registry: Dict[str, Type[Node]] = {n.get_node_name(): n for n in nodes}

    # --- HELPER UTILITIES ---

    def _dump_ctx(self, ctx: Any) -> dict:
        """Helper to safely dump context whether it's a model or a dict."""
        if hasattr(ctx, "model_dump"):
            return ctx.model_dump()
        return ctx

    def _get_path(self, obj: Any, path: str) -> Any:
        """Helper to get a value from an object or a dict."""
        if isinstance(obj, dict):
            return obj.get(path)
        return getattr(obj, path)

    # --- CORE RECURSIVE RESOLVER ---

    async def _apply_target(
        self,
        agent_id: str,
        context: Any,
        target: Target,
        duration: float = 0.0,
        payload: Any = None,
    ):
        """Translate ``target`` into persisted state and/or published events.

        Raises:
            TypeError: if a ``Schedule`` action is not a Node class.
        """
        # 1. Terminal State
        if target is None:
            await self.observer.on_transition(agent_id, "RUN", "TERMINAL", duration)
            ctx_dict = self._dump_ctx(context)
            if "#" in agent_id:
                parent_id = agent_id.split("#")[0]
                await self.db.save_sub_state(
                    agent_id, parent_id, "TERMINAL", ctx_dict, None
                )
                join_node_name = await self.db.register_sub_task_completion(agent_id)
                if join_node_name:
                    await self._trigger_recompose(parent_id, join_node_name)
            else:
                await self.db.save_state(agent_id, "TERMINAL", ctx_dict, None)
            return

        # 2. Node Class Transition
        if isinstance(target, type) and issubclass(target, Node):
            node_name = target.get_node_name()
            await self.observer.on_transition(agent_id, "RUN", node_name, duration)

            p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload
            evt = Event(agent_id=agent_id, node_name=node_name, payload=p_load)
            ctx_dict = self._dump_ctx(context)

            if "#" in agent_id:
                await self.db.save_sub_state(
                    agent_id, agent_id.split("#")[0], node_name, ctx_dict, evt
                )
            else:
                await self.db.save_state(agent_id, node_name, ctx_dict, evt)

            await self.bus.publish(evt)
            return

        # 3. Wait Directive
        if isinstance(target, Wait):
            actual_id = agent_id
            actual_ctx = context
            if target.sub_context_path and "#" not in agent_id:
                actual_id = f"{agent_id}#{target.sub_context_path}"
                actual_ctx = self._get_path(context, target.sub_context_path)

            await self.observer.on_transition(
                actual_id, "RUN", f"WAIT:{target.signal_id}", duration
            )
            await self.db.park_agent(
                actual_id,
                target.signal_id,
                target.resume_action,
                self._dump_ctx(actual_ctx),
            )
            return

        # 4. Parallel Directive
        if isinstance(target, Parallel):
            join_name = target.join_node.get_node_name() if target.join_node else "FORK"
            await self.observer.on_transition(
                agent_id, "RUN", f"PARALLEL:{join_name}", duration
            )

            if target.join_node:
                await self.db.create_task_group(
                    agent_id, join_name, len(target.branches)
                )

            for branch in target.branches:
                # NOTE(review): a bare Node-class branch has no
                # ``sub_context_path`` and will raise AttributeError here.
                task = (
                    branch
                    if isinstance(branch, ParallelTask)
                    else ParallelTask(
                        action=branch,
                        sub_context_path=branch.sub_context_path,  # type: ignore
                    )
                )
                sub_ctx = self._get_path(context, task.sub_context_path)
                await self._apply_target(
                    f"{agent_id}#{task.sub_context_path}",
                    sub_ctx,
                    task.action,
                    duration=0,
                    payload=task.payload,
                )

            if target.join_node:
                await self.db.save_state(
                    agent_id, "WAITING_FOR_JOIN", self._dump_ctx(context), None
                )
            else:
                await self._apply_target(agent_id, context, None)
            return

        # 5. Schedule Directive
        if isinstance(target, Schedule):
            if isinstance(target.action, type) and issubclass(target.action, Node):
                target_node_name = target.action.get_node_name()
            else:
                raise TypeError("Schedule.action must be a Node class.")

            await self.observer.on_transition(
                agent_id, "RUN", f"SCHEDULED:{target_node_name}", duration
            )
            run_at_dt = datetime.now(timezone.utc) + timedelta(
                seconds=target.delay_seconds
            )
            p_load = (
                target.payload.model_dump()
                if hasattr(target.payload, "model_dump")
                else target.payload
            )

            scheduled_evt = Event(
                agent_id=agent_id,
                node_name=target_node_name,
                payload=p_load,
                run_at=run_at_dt.isoformat(),
                idempotency_key=target.idempotency_key,
            )

            await self.db.schedule_event(scheduled_evt)
            await self.db.save_state(
                agent_id, target_node_name, self._dump_ctx(context), None
            )
            return

    # --- WORKER & EXTERNAL TRIGGERS ---

    async def process_event(self, event: Event):
        """Bus handler: lock the agent, run the event's node, apply its result.

        Stale events (persisted node differs from ``event.node_name``) are
        ignored.  Exceptions are reported to the observer and re-raised; the
        agent is always unlocked once it was loaded.
        """
        # get_running_loop() is the supported way to reach the loop from a coroutine.
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        agent_id = event.agent_id
        current_node_name, ctx_dict = await self.db.load_and_lock_agent(agent_id)

        if not current_node_name:
            return

        try:
            if current_node_name != event.node_name:
                return

            node_cls = self._registry.get(current_node_name)
            if not node_cls:
                raise RuntimeError(f"Node '{current_node_name}' not found.")

            ctx_type = GraphAnalyzer.get_context_type(node_cls)
            payload_type = GraphAnalyzer.get_payload_type(node_cls)

            ctx = (
                ctx_type.model_validate(ctx_dict)
                if issubclass(ctx_type, BaseModel)
                else ctx_dict
            )
            payload = (
                payload_type.model_validate(event.payload)
                if (event.payload and issubclass(payload_type, BaseModel))
                else event.payload
            )

            result = await node_cls().run(ctx, payload)
            await self._apply_target(
                agent_id,
                ctx,
                result,
                (loop.time() - start_time) * 1000,
            )

        except Exception as e:
            await self.observer.on_error(agent_id, current_node_name, e)
            raise
        finally:
            await self.db.unlock_agent(agent_id)

    async def release_signal(self, signal_id: str, payload: Any = None):
        """Resume every agent parked on ``signal_id`` at its stored next target."""
        waiters = await self.db.get_and_clear_waiters(signal_id)
        for waiter in waiters:
            agent_id = waiter["agent_id"]
            await self.db.load_and_lock_agent(agent_id)
            try:
                await self._apply_target(
                    agent_id=agent_id,
                    context=waiter["context"],
                    target=waiter["next_target"],
                    payload=payload,
                )
            finally:
                await self.db.unlock_agent(agent_id)

    # --- LIFECYCLE METHODS ---

    async def start_worker(self, poll_interval: float = 1.0):
        """Subscribe to the bus and start polling for due scheduled events."""
        await self.bus.subscribe(self.process_event)
        self._scheduler_task = asyncio.create_task(self._scheduler_loop(poll_interval))
        self.logger.info("Worker started.")

    async def _scheduler_loop(self, poll_interval: float):
        """Publish due scheduled events every ``poll_interval`` seconds."""
        while True:
            try:
                due = await self.db.pop_due_events()
                for evt in due:
                    await self.bus.publish(evt)
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Scheduler Error: {e}")
            await asyncio.sleep(poll_interval)

    async def stop(self):
        """Cancel the scheduler task and wait for it to finish."""
        if self._scheduler_task:
            self._scheduler_task.cancel()
            try:
                await self._scheduler_task
            except asyncio.CancelledError:
                pass

    async def trigger_agent(
        self,
        agent_id: str,
        start_node: Type[Node],
        initial_context: BaseModel,
        payload: Optional[BaseModel] = None,
    ):
        """Create/overwrite ``agent_id`` at ``start_node`` and publish its first event."""
        node_name = start_node.get_node_name()
        evt = Event(
            agent_id=agent_id,
            node_name=node_name,
            payload=payload.model_dump() if hasattr(payload, "model_dump") else payload,
        )
        await self.db.save_state(agent_id, node_name, initial_context.model_dump(), evt)
        await self.bus.publish(evt)

    async def _trigger_recompose(self, parent_id: str, join_node_name: str):
        """Merge finished sub-tasks into the parent and advance it to the join node.

        The parent lock taken here is always released, matching the
        lock/try/finally pattern of ``release_signal``.
        """
        await self.db.load_and_lock_agent(parent_id)
        try:
            merged_ctx_dict = await self.db.recompose_parent(parent_id)
            await self._apply_target(
                parent_id, merged_ctx_dict, self._registry[join_node_name]
            )
        finally:
            await self.db.unlock_agent(parent_id)

    def validate_graph(self, start_node: Type[Node]):
        """Validate the graph reachable from ``start_node`` against the registry."""
        return GraphAnalyzer.validate(start_node, self._registry)
|