commandnet 0.2.2__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: commandnet
3
- Version: 0.2.2
3
+ Version: 0.4.0
4
4
  Summary: A lightweight, Pydantic-powered, distributed event-driven state machine and typed node graph runtime.
5
5
  Author: Christopher Vaz
6
6
  Author-email: christophervaz160@gmail.com
@@ -53,7 +53,7 @@ pip install commandnet
53
53
  ## 🛠️ Quick Start
54
54
 
55
55
  ### 1. Define Your Context
56
- The "Context" is the persistent state of your agent, defined using Pydantic.
56
+ The "Context" is the persistent state of your subject, defined using Pydantic.
57
57
 
58
58
  ```python
59
59
  from pydantic import BaseModel
@@ -139,7 +139,7 @@ engine = Engine(persistence=db, event_bus=bus)
139
139
  await engine.start_worker()
140
140
 
141
141
  # Trigger an execution
142
- await engine.trigger_agent("agent-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
142
+ await engine.trigger_subject("subject-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
143
143
  ```
144
144
 
145
145
  ---
@@ -172,3 +172,4 @@ print(dag) # {'CheckRisk': ['ProcessPayment'], 'ProcessPayment': []}
172
172
 
173
173
  MIT
174
174
 
175
+
@@ -30,7 +30,7 @@ pip install commandnet
30
30
  ## 🛠️ Quick Start
31
31
 
32
32
  ### 1. Define Your Context
33
- The "Context" is the persistent state of your agent, defined using Pydantic.
33
+ The "Context" is the persistent state of your subject, defined using Pydantic.
34
34
 
35
35
  ```python
36
36
  from pydantic import BaseModel
@@ -116,7 +116,7 @@ engine = Engine(persistence=db, event_bus=bus)
116
116
  await engine.start_worker()
117
117
 
118
118
  # Trigger an execution
119
- await engine.trigger_agent("agent-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
119
+ await engine.trigger_subject("subject-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
120
120
  ```
121
121
 
122
122
  ---
@@ -148,3 +148,4 @@ print(dag) # {'CheckRisk': ['ProcessPayment'], 'ProcessPayment': []}
148
148
  ## 📄 License
149
149
 
150
150
  MIT
151
+
@@ -1,5 +1,5 @@
1
1
  from .core.models import Event
2
- from .core.node import Node, Parallel, ParallelTask, Schedule
2
+ from .core.node import Node, Parallel, ParallelTask, Schedule, Wait, Call, Interrupt
3
3
  from .core.graph import GraphAnalyzer
4
4
  from .interfaces.persistence import Persistence
5
5
  from .interfaces.event_bus import EventBus
@@ -9,4 +9,6 @@ from .engine.runtime import Engine
9
9
  __all__ = [
10
10
  "Event", "Node", "Parallel", "ParallelTask", "GraphAnalyzer",
11
11
  "Persistence", "EventBus", "Observer", "Engine", "Schedule",
12
+ "Call", "Interrupt",
12
13
  ]
14
+
@@ -96,3 +96,4 @@ class GraphAnalyzer:
96
96
  f"({source_ctx.__name__} -> {target_ctx.__name__})"
97
97
  )
98
98
  return True
99
+
@@ -8,10 +8,11 @@ def utcnow_iso() -> str:
8
8
 
9
9
  class Event(BaseModel):
10
10
  event_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
11
- agent_id: str
11
+ subject_id: str
12
12
  node_name: str
13
13
  payload: Optional[Dict[str, Any]] = None
14
14
 
15
15
  timestamp: str = Field(default_factory=utcnow_iso)
16
16
  run_at: str = Field(default_factory=utcnow_iso)
17
17
  idempotency_key: Optional[str] = None
18
+
@@ -0,0 +1,65 @@
1
+ from typing import Generic, TypeVar, Type, Optional, Union, List, Any
2
+ from pydantic import BaseModel, ConfigDict
3
+
4
+ C = TypeVar('C', bound=BaseModel) # Context
5
+ P = TypeVar('P', bound=BaseModel) # Payload
6
+
7
+ # The Recursive Type Definition
8
+ Target = Union[Type['Node'], 'Parallel', 'Schedule', 'Wait', 'Call', 'Interrupt', None]
9
+
10
+ class ParallelTask(BaseModel):
11
+ model_config = ConfigDict(arbitrary_types_allowed=True)
12
+ action: Target
13
+ sub_context_path: str
14
+ payload: Optional[Any] = None
15
+
16
+ class Parallel(BaseModel):
17
+ model_config = ConfigDict(arbitrary_types_allowed=True)
18
+ branches: List[Union[ParallelTask, 'Wait']]
19
+ join_node: Optional[Type['Node']] = None
20
+
21
+ class Schedule(BaseModel):
22
+ model_config = ConfigDict(arbitrary_types_allowed=True)
23
+ action: Target
24
+ delay_seconds: int
25
+ payload: Optional[Any] = None
26
+ idempotency_key: Optional[str] = None
27
+
28
+ class Wait(BaseModel):
29
+ model_config = ConfigDict(arbitrary_types_allowed=True)
30
+ signal_id: str
31
+ resume_action: Target
32
+ sub_context_path: Optional[str] = None
33
+
34
+ class Call(BaseModel):
35
+ """The 'Await' type: Deduplicates execution based on an idempotency key."""
36
+ model_config = ConfigDict(arbitrary_types_allowed=True)
37
+ node_cls: Type['Node']
38
+ idempotency_key: str
39
+ payload: Optional[Any] = None
40
+ # After the called node finishes, the caller resumes at this node
41
+ resume_action: Target = None
42
+
43
+ class Interrupt(BaseModel):
44
+ """Signals the engine to stop execution."""
45
+ subject_id: str
46
+ hard: bool = True # True = kill task, False = flag for soft exit
47
+
48
+ TransitionResult = Target
49
+
50
+ class Node(Generic[C, P]):
51
+ @classmethod
52
+ def get_node_name(cls) -> str:
53
+ return cls.__name__
54
+
55
+ async def run(self, ctx: C, payload: Optional[P] = None) -> TransitionResult:
56
+ """Executes node logic. Returns a Target (Node class, Parallel, Schedule, Wait, or None)."""
57
+ pass
58
+
59
+ async def on_signal(self, ctx: C, signal_id: str, payload: Any) -> Target:
60
+ """
61
+ CALLBACK: Called when a 'Wait' state is resolved or an external signal
62
+ is sent to this node specifically.
63
+ """
64
+ return None
65
+
@@ -0,0 +1,402 @@
1
+ import asyncio
2
+ import logging
3
+ from datetime import datetime, timezone, timedelta
4
+ from typing import Optional, Type, Iterable, Dict, Any, Union, List
5
+
6
+ from ..core.models import Event
7
+ from ..core.node import (
8
+ Node,
9
+ Parallel,
10
+ Schedule,
11
+ Wait,
12
+ Target,
13
+ ParallelTask,
14
+ Call,
15
+ Interrupt,
16
+ )
17
+ from ..core.graph import GraphAnalyzer
18
+ from ..interfaces.persistence import Persistence
19
+ from ..interfaces.event_bus import EventBus
20
+ from ..interfaces.observer import Observer
21
+
22
+
23
+ class Engine:
24
+ def __init__(
25
+ self,
26
+ persistence: Persistence,
27
+ event_bus: EventBus,
28
+ nodes: Iterable[Type[Node]],
29
+ observer: Optional[Observer] = None,
30
+ ):
31
+ self.db = persistence
32
+ self.bus = event_bus
33
+ self.observer = observer or Observer()
34
+ self.logger = logging.getLogger("CommandNet")
35
+ self._scheduler_task: Optional[asyncio.Task] = None
36
+ self._active_tasks: Dict[str, asyncio.Task] = {}
37
+ self._registry: Dict[str, Type[Node]] = {n.get_node_name(): n for n in nodes}
38
+
39
+ def _dump_ctx(self, ctx: Any) -> dict:
40
+ return ctx.model_dump() if hasattr(ctx, "model_dump") else ctx
41
+
42
+ def _get_path(self, obj: Any, path: str) -> Any:
43
+ if isinstance(obj, dict):
44
+ return obj.get(path)
45
+ return getattr(obj, path)
46
+
47
+ # --- CORE WORKER LOGIC ---
48
+
49
+ async def process_event(self, event: Event):
50
+ # 1. Internal Control (Cross-worker Hard Cancel)
51
+ if event.node_name == "__CONTROL__":
52
+ action = (event.payload or {}).get("action")
53
+ if action == "HARD_CANCEL" and event.subject_id in self._active_tasks:
54
+ self._active_tasks[event.subject_id].cancel()
55
+ return
56
+
57
+ # 2. Check cancellation status
58
+ if await self.db.is_cancelled(event.subject_id):
59
+ return
60
+
61
+ # 3. Task Tracking for Local Hard Cancel
62
+ task = asyncio.create_task(self._run_node_logic(event))
63
+ self._active_tasks[event.subject_id] = task
64
+ try:
65
+ await task
66
+ except asyncio.CancelledError:
67
+ self.logger.warning(f"Task {event.subject_id} hard-cancelled.")
68
+ finally:
69
+ self._active_tasks.pop(event.subject_id, None)
70
+
71
+ async def _run_node_logic(self, event: Event):
72
+ subject_id = event.subject_id
73
+ # We allow "AWAITING_CALL" because a subject wakes up from that state when a 'Call' resolves
74
+ node_name, ctx_dict = await self.db.lock_and_load(subject_id)
75
+ if not node_name or (
76
+ node_name != event.node_name and node_name != "AWAITING_CALL"
77
+ ):
78
+ if node_name:
79
+ await self.db.unlock_subject(subject_id)
80
+ return
81
+
82
+ try:
83
+ node_cls = self._registry.get(event.node_name)
84
+ if not node_cls:
85
+ raise RuntimeError(f"Node '{event.node_name}' not found.")
86
+
87
+ ctx_type = GraphAnalyzer.get_context_type(node_cls)
88
+ payload_type = GraphAnalyzer.get_payload_type(node_cls)
89
+
90
+ ctx = (
91
+ ctx_type.model_validate(ctx_dict)
92
+ if hasattr(ctx_type, "model_validate")
93
+ else ctx_dict
94
+ )
95
+ payload = (
96
+ payload_type.model_validate(event.payload)
97
+ if (event.payload and hasattr(payload_type, "model_validate"))
98
+ else event.payload
99
+ )
100
+
101
+ # Support Soft-Cancel checking inside Node.run
102
+ if hasattr(ctx, "is_cancelled"):
103
+ ctx.is_cancelled = await self.db.is_cancelled(subject_id)
104
+
105
+ start_t = asyncio.get_event_loop().time()
106
+ result = await node_cls().run(ctx, payload)
107
+
108
+ await self._apply_target(
109
+ subject_id,
110
+ ctx,
111
+ result,
112
+ (asyncio.get_event_loop().time() - start_t) * 1000,
113
+ )
114
+ except Exception as e:
115
+ await self.observer.on_error(subject_id, event.node_name, e)
116
+ raise
117
+ finally:
118
+ await self.db.unlock_subject(subject_id)
119
+
120
+ # --- RECURSIVE TARGET RESOLVER ---
121
+
122
+ async def _apply_target(
123
+ self,
124
+ subject_id: str,
125
+ context: Any,
126
+ target: Target,
127
+ duration: float = 0.0,
128
+ payload: Any = None,
129
+ ):
130
+ # 1. Interrupt (Cancellation)
131
+ if isinstance(target, Interrupt):
132
+ await self.cancel_subject(target.subject_id, target.hard)
133
+ return
134
+
135
+ # 2. Call (Idempotent Await)
136
+ if isinstance(target, Call):
137
+ is_leader = await self.db.add_call_waiter(
138
+ target.idempotency_key,
139
+ subject_id,
140
+ target.resume_action,
141
+ self._dump_ctx(context),
142
+ )
143
+ await self.observer.on_transition(
144
+ subject_id, "RUN", f"CALLING:{target.node_cls.get_node_name()}", duration
145
+ )
146
+ await self.db.save_state(
147
+ subject_id, "AWAITING_CALL", self._dump_ctx(context), None
148
+ )
149
+ if is_leader:
150
+ await self.trigger_subject(
151
+ f"call#{target.idempotency_key}",
152
+ target.node_cls,
153
+ context,
154
+ target.payload,
155
+ )
156
+ return
157
+
158
+ # 3. Terminal State
159
+ if target is None:
160
+ await self.observer.on_transition(subject_id, "RUN", "TERMINAL", duration)
161
+ ctx_dict = self._dump_ctx(context)
162
+
163
+ # Resolve Callers if this was a shared Virtual subject
164
+ if subject_id.startswith("call#"):
165
+ key = subject_id.split("#")[1]
166
+ waiters = await self.db.resolve_call_group(key)
167
+ for w in waiters:
168
+ # Resume waiters using the final context of the shared node as payload
169
+ await self._apply_target(
170
+ w["subject_id"],
171
+ w["context"],
172
+ w["resume_target"],
173
+ payload=ctx_dict,
174
+ )
175
+
176
+ # Parallel Sub-task Completion Logic
177
+ if "#" in subject_id and not subject_id.startswith("call#"):
178
+ parent_id = subject_id.split("#")[0]
179
+ await self.db.save_sub_state(
180
+ subject_id, parent_id, "TERMINAL", ctx_dict, None
181
+ )
182
+ join_node_name = await self.db.register_sub_task_completion(subject_id)
183
+ if join_node_name:
184
+ await self._trigger_recompose(parent_id, join_node_name)
185
+ else:
186
+ await self.db.save_state(subject_id, "TERMINAL", ctx_dict, None)
187
+ return
188
+
189
+ # 4. Standard Node Transition
190
+ if isinstance(target, type) and issubclass(target, Node):
191
+ node_name = target.get_node_name()
192
+ await self.observer.on_transition(subject_id, "RUN", node_name, duration)
193
+
194
+ p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload
195
+ evt = Event(subject_id=subject_id, node_name=node_name, payload=p_load)
196
+ ctx_dict = self._dump_ctx(context)
197
+
198
+ if "#" in subject_id:
199
+ await self.db.save_sub_state(
200
+ subject_id, subject_id.split("#")[0], node_name, ctx_dict, evt
201
+ )
202
+ else:
203
+ await self.db.save_state(subject_id, node_name, ctx_dict, evt)
204
+
205
+ await self.bus.publish(evt)
206
+ return
207
+
208
+ # 5. Wait / Signal Parking
209
+ if isinstance(target, Wait):
210
+ actual_id = subject_id
211
+ actual_ctx = context
212
+ if target.sub_context_path and "#" not in subject_id:
213
+ actual_id = f"{subject_id}#{target.sub_context_path}"
214
+ actual_ctx = self._get_path(context, target.sub_context_path)
215
+
216
+ await self.observer.on_transition(
217
+ actual_id, "RUN", f"WAIT:{target.signal_id}", duration
218
+ )
219
+ await self.db.park_subject(
220
+ actual_id,
221
+ target.signal_id,
222
+ target.resume_action,
223
+ self._dump_ctx(actual_ctx),
224
+ )
225
+ return
226
+
227
+ # 6. Parallel Fan-out
228
+ if isinstance(target, Parallel):
229
+ join_name = target.join_node.get_node_name() if target.join_node else "FORK"
230
+ await self.observer.on_transition(
231
+ subject_id, "RUN", f"PARALLEL:{join_name}", duration
232
+ )
233
+
234
+ if target.join_node:
235
+ await self.db.create_task_group(
236
+ subject_id, join_name, len(target.branches)
237
+ )
238
+
239
+ for branch in target.branches:
240
+ # Normalize branch to ParallelTask without stripping Wait wrappers
241
+ if isinstance(branch, ParallelTask):
242
+ task = branch
243
+ else:
244
+ # If it's a Wait or a Node class, wrap it but keep the action intact
245
+ path = (
246
+ branch.sub_context_path
247
+ if isinstance(branch, Wait) and branch.sub_context_path
248
+ else "default"
249
+ )
250
+ task = ParallelTask(action=branch, sub_context_path=path)
251
+
252
+ sub_ctx = self._get_path(context, task.sub_context_path)
253
+ await self._apply_target(
254
+ f"{subject_id}#{task.sub_context_path}",
255
+ sub_ctx,
256
+ task.action,
257
+ payload=task.payload,
258
+ )
259
+
260
+ if target.join_node:
261
+ await self.db.save_state(
262
+ subject_id, "WAITING_FOR_JOIN", self._dump_ctx(context), None
263
+ )
264
+ else:
265
+ await self._apply_target(subject_id, context, None)
266
+ return
267
+
268
+ # 7. Schedule (Delayed Execution)
269
+ if isinstance(target, Schedule):
270
+ if not (
271
+ isinstance(target.action, type) and issubclass(target.action, Node)
272
+ ):
273
+ raise TypeError("Schedule.action must be a Node class.")
274
+
275
+ target_node_name = target.action.get_node_name()
276
+ await self.observer.on_transition(
277
+ subject_id, "RUN", f"SCHEDULED:{target_node_name}", duration
278
+ )
279
+
280
+ run_at_dt = datetime.now(timezone.utc) + timedelta(
281
+ seconds=target.delay_seconds
282
+ )
283
+ p_load = (
284
+ target.payload.model_dump()
285
+ if hasattr(target.payload, "model_dump")
286
+ else target.payload
287
+ )
288
+
289
+ scheduled_evt = Event(
290
+ subject_id=subject_id,
291
+ node_name=target_node_name,
292
+ payload=p_load,
293
+ run_at=run_at_dt.isoformat(),
294
+ idempotency_key=target.idempotency_key,
295
+ )
296
+
297
+ await self.db.schedule_event(scheduled_evt)
298
+ await self.db.save_state(
299
+ subject_id, target_node_name, self._dump_ctx(context), None
300
+ )
301
+ return
302
+
303
+ # --- EXTERNAL CONTROL & SIGNALS ---
304
+
305
+ async def signal_node(self, subject_id: str, signal_id: str, payload: Any = None):
306
+ """Resumes the subject and triggers its specific on_signal instance method."""
307
+ node_name, ctx_dict = await self.db.lock_and_load(subject_id)
308
+ if not node_name:
309
+ return
310
+ try:
311
+ node_cls = self._registry.get(node_name)
312
+ node_inst = node_cls()
313
+ ctx_type = GraphAnalyzer.get_context_type(node_cls)
314
+ ctx = (
315
+ ctx_type.model_validate(ctx_dict)
316
+ if hasattr(ctx_type, "model_validate")
317
+ else ctx_dict
318
+ )
319
+
320
+ result = await node_inst.on_signal(ctx, signal_id, payload)
321
+ await self._apply_target(subject_id, ctx, result)
322
+ finally:
323
+ await self.db.unlock_subject(subject_id)
324
+
325
+ async def release_signal(self, signal_id: str, payload: Any = None):
326
+ """Standard mass-resume for parked subjects."""
327
+ waiters = await self.db.get_and_clear_waiters(signal_id)
328
+ for waiter in waiters:
329
+ subject_id = waiter["subject_id"]
330
+ await self.db.lock_and_load(subject_id)
331
+ try:
332
+ await self._apply_target(
333
+ subject_id=subject_id,
334
+ context=waiter["context"],
335
+ target=waiter["next_target"],
336
+ payload=payload,
337
+ )
338
+ finally:
339
+ await self.db.unlock_subject(subject_id)
340
+
341
+ async def cancel_subject(self, subject_id: str, hard: bool = True):
342
+ await self.db.set_cancel_flag(subject_id, hard)
343
+ if hard:
344
+ if subject_id in self._active_tasks:
345
+ self._active_tasks[subject_id].cancel()
346
+ await self.bus.publish(
347
+ Event(
348
+ subject_id=subject_id,
349
+ node_name="__CONTROL__",
350
+ payload={"action": "HARD_CANCEL"},
351
+ )
352
+ )
353
+
354
+ # --- LIFECYCLE & HELPERS ---
355
+
356
+ async def _trigger_recompose(self, parent_id: str, join_node_name: str):
357
+ await self.db.lock_and_load(parent_id)
358
+ merged_ctx_dict = await self.db.recompose_parent(parent_id)
359
+ await self._apply_target(
360
+ parent_id, merged_ctx_dict, self._registry[join_node_name]
361
+ )
362
+
363
+ async def start_worker(self, poll_interval: float = 1.0):
364
+ await self.bus.subscribe(self.process_event)
365
+ self._scheduler_task = asyncio.create_task(self._scheduler_loop(poll_interval))
366
+
367
+ async def _scheduler_loop(self, poll_interval: float):
368
+ while True:
369
+ try:
370
+ due = await self.db.pop_due_events()
371
+ for evt in due:
372
+ await self.bus.publish(evt)
373
+ except asyncio.CancelledError:
374
+ break
375
+ except Exception as e:
376
+ self.logger.error(f"Scheduler: {e}")
377
+ await asyncio.sleep(poll_interval)
378
+
379
+ async def trigger_subject(
380
+ self,
381
+ subject_id: str,
382
+ start_node: Type[Node],
383
+ initial_context: Any,
384
+ payload: Any = None,
385
+ ):
386
+ node_name = start_node.get_node_name()
387
+ p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload
388
+ evt = Event(subject_id=subject_id, node_name=node_name, payload=p_load)
389
+ await self.db.save_state(
390
+ subject_id, node_name, self._dump_ctx(initial_context), evt
391
+ )
392
+ await self.bus.publish(evt)
393
+
394
+ async def stop(self):
395
+ if self._scheduler_task:
396
+ self._scheduler_task.cancel()
397
+ for task in self._active_tasks.values():
398
+ task.cancel()
399
+
400
+ def validate_graph(self, start_node: Type[Node]):
401
+ return GraphAnalyzer.validate(start_node, self._registry)
402
+
@@ -10,3 +10,4 @@ class EventBus(ABC):
10
10
  @abstractmethod
11
11
  async def subscribe(self, handler: Callable[[Event], Coroutine]):
12
12
  pass
13
+
@@ -0,0 +1,6 @@
1
+ from abc import ABC
2
+
3
+ class Observer(ABC):
4
+ async def on_transition(self, subject_id: str, from_node: str, to_node: str, duration_ms: float): pass
5
+ async def on_error(self, subject_id: str, node: str, error: Exception): pass
6
+
@@ -0,0 +1,66 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Optional, Tuple, List, Any
3
+ from ..core.models import Event
4
+
5
+ class Persistence(ABC):
6
+ @abstractmethod
7
+ async def lock_and_load(self, subject_id: str) -> Tuple[Optional[str], Optional[Dict]]:
8
+ pass
9
+
10
+ @abstractmethod
11
+ async def unlock_subject(self, subject_id: str):
12
+ pass
13
+
14
+ @abstractmethod
15
+ async def save_state(self, subject_id: str, node_name: str, context: Dict, event: Optional[Event]):
16
+ pass
17
+
18
+ @abstractmethod
19
+ async def save_sub_state(self, sub_id: str, parent_id: str, node_name: str, ctx: dict, evt: Optional[Event]):
20
+ pass
21
+
22
+ @abstractmethod
23
+ async def create_task_group(self, parent_id: str, join_node_name: str, task_count: int):
24
+ pass
25
+
26
+ @abstractmethod
27
+ async def register_sub_task_completion(self, sub_id: str) -> Optional[str]:
28
+ pass
29
+
30
+ @abstractmethod
31
+ async def recompose_parent(self, parent_id: str) -> dict:
32
+ pass
33
+
34
+ @abstractmethod
35
+ async def schedule_event(self, event: Event) -> bool:
36
+ pass
37
+
38
+ @abstractmethod
39
+ async def pop_due_events(self) -> List[Event]:
40
+ pass
41
+
42
+ @abstractmethod
43
+ async def park_subject(self, subject_id: str, signal_id: str, next_target: Any, context: Dict):
44
+ pass
45
+
46
+ @abstractmethod
47
+ async def get_and_clear_waiters(self, signal_id: str) -> List[Dict]:
48
+ pass
49
+
50
+ @abstractmethod
51
+ async def set_cancel_flag(self, subject_id: str, hard: bool):
52
+ pass
53
+
54
+ @abstractmethod
55
+ async def is_cancelled(self, subject_id: str) -> bool:
56
+ pass
57
+
58
+ @abstractmethod
59
+ async def add_call_waiter(self, key: str, subject_id: str, resume_target: Any, context: dict) -> bool:
60
+ """Returns True if this is the first waiter for this key (the 'leader')."""
61
+ pass
62
+
63
+ @abstractmethod
64
+ async def resolve_call_group(self, key: str) -> List[Dict[str, Any]]:
65
+ pass
66
+
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "commandnet"
3
- version = "0.2.2"
3
+ version = "0.4.0"
4
4
  description = "A lightweight, Pydantic-powered, distributed event-driven state machine and typed node graph runtime."
5
5
  authors = [
6
6
  { name = "Christopher Vaz", email = "christophervaz160@gmail.com" }
@@ -41,3 +41,4 @@ build-backend = "poetry.core.masonry.api"
41
41
  asyncio_mode = "strict"
42
42
  asyncio_default_test_loop_scope = "function"
43
43
  asyncio_default_fixture_loop_scope = "function"
44
+
@@ -1,36 +0,0 @@
1
- import inspect
2
- from typing import Generic, TypeVar, Type, Optional, Union, List, Dict, Any
3
- from pydantic import BaseModel, ConfigDict
4
-
5
- C = TypeVar('C', bound=BaseModel) # Context
6
- P = TypeVar('P', bound=BaseModel) # Payload (Optional)
7
-
8
- class ParallelTask(BaseModel):
9
- model_config = ConfigDict(arbitrary_types_allowed=True)
10
- node_cls: Type['Node']
11
- payload: Optional[Any] = None
12
- sub_context_path: str
13
-
14
- class Parallel(BaseModel):
15
- model_config = ConfigDict(arbitrary_types_allowed=True)
16
- branches: List[ParallelTask]
17
- join_node: Type['Node']
18
-
19
- class Schedule(BaseModel):
20
- model_config = ConfigDict(arbitrary_types_allowed=True)
21
- node_cls: Type['Node']
22
- delay_seconds: int
23
- payload: Optional[Any] = None
24
- idempotency_key: Optional[str] = None
25
-
26
- # Minor typing improvement for readability
27
- TransitionResult = Union[Type['Node'], Parallel, Schedule, None]
28
-
29
- class Node(Generic[C, P]):
30
- @classmethod
31
- def get_node_name(cls) -> str:
32
- return cls.__name__
33
-
34
- async def run(self, ctx: C, payload: Optional[P] = None) -> TransitionResult:
35
- """Executes node logic. Returns the Next Node, a Parallel request, a Schedule request, or None."""
36
- pass
@@ -1,211 +0,0 @@
1
- import asyncio
2
- import logging
3
- from datetime import datetime, timezone, timedelta
4
- from typing import Optional, Type, Iterable, Dict
5
- from pydantic import BaseModel
6
-
7
- from ..core.models import Event
8
- from ..core.node import Node, Parallel, Schedule
9
- from ..core.graph import GraphAnalyzer
10
- from ..interfaces.persistence import Persistence
11
- from ..interfaces.event_bus import EventBus
12
- from ..interfaces.observer import Observer
13
-
14
- class Engine:
15
- def __init__(
16
- self,
17
- persistence: Persistence,
18
- event_bus: EventBus,
19
- nodes: Iterable[Type[Node]],
20
- observer: Optional[Observer] = None
21
- ):
22
- self.db = persistence
23
- self.bus = event_bus
24
- self.observer = observer or Observer()
25
- self.logger = logging.getLogger("CommandNet")
26
- self._scheduler_task: Optional[asyncio.Task] = None
27
-
28
- # Build Engine-scoped registry
29
- self._registry: Dict[str, Type[Node]] = {}
30
- for node_cls in nodes:
31
- name = node_cls.get_node_name()
32
- if name in self._registry and self._registry[name] is not node_cls:
33
- raise RuntimeError(f"Node name collision in Engine: '{name}'")
34
- self._registry[name] = node_cls
35
-
36
- def validate_graph(self, start_node: Type[Node]):
37
- """Helper to validate the graph using this engine's registry."""
38
- return GraphAnalyzer.validate(start_node, self._registry)
39
-
40
- async def start_worker(self, poll_interval: float = 1.0):
41
- await self.bus.subscribe(self.process_event)
42
- self._scheduler_task = asyncio.create_task(self._scheduler_loop(poll_interval))
43
- self.logger.info("Worker and Scheduler started.")
44
-
45
- async def _scheduler_loop(self, poll_interval: float):
46
- while True:
47
- try:
48
- due_events = await self.db.pop_due_events()
49
- for evt in due_events:
50
- await self.bus.publish(evt)
51
- except asyncio.CancelledError:
52
- break
53
- except Exception as e:
54
- self.logger.error(f"Scheduler Error: {e}")
55
- await asyncio.sleep(poll_interval)
56
-
57
- async def stop(self):
58
- if self._scheduler_task:
59
- self._scheduler_task.cancel()
60
- try:
61
- await self._scheduler_task
62
- except asyncio.CancelledError:
63
- pass
64
-
65
- async def trigger_agent(self, agent_id: str, start_node: Type[Node], initial_context: BaseModel, payload: Optional[BaseModel] = None):
66
- node_name = start_node.get_node_name()
67
- if node_name not in self._registry:
68
- raise ValueError(f"Node '{node_name}' is not registered with this Engine.")
69
-
70
- start_event = Event(
71
- agent_id=agent_id,
72
- node_name=node_name,
73
- payload=payload.model_dump() if hasattr(payload, "model_dump") else payload
74
- )
75
- await self.db.save_state(agent_id, node_name, initial_context.model_dump(), start_event)
76
- await self.bus.publish(start_event)
77
-
78
- async def process_event(self, event: Event):
79
- start_time = asyncio.get_event_loop().time()
80
-
81
- current_node_name, ctx_dict = await self.db.load_and_lock_agent(event.agent_id)
82
- if not current_node_name:
83
- return
84
-
85
- locked = True
86
- try:
87
- if current_node_name != event.node_name:
88
- # Same logic as before for sub-state/state
89
- if "#" in event.agent_id:
90
- parent_id = event.agent_id.split("#")[0]
91
- await self.db.save_sub_state(event.agent_id, parent_id, current_node_name, ctx_dict, None)
92
- else:
93
- await self.db.save_state(event.agent_id, current_node_name, ctx_dict, None)
94
- locked = False
95
- return
96
-
97
- node_cls = self._registry.get(current_node_name)
98
- if not node_cls:
99
- raise RuntimeError(f"Node '{current_node_name}' not found in this Engine's registry.")
100
-
101
- ctx_type = GraphAnalyzer.get_context_type(node_cls)
102
- payload_type = GraphAnalyzer.get_payload_type(node_cls)
103
-
104
- ctx = ctx_type.model_validate(ctx_dict) if issubclass(ctx_type, BaseModel) else ctx_dict
105
-
106
- payload = None
107
- if event.payload is not None:
108
- if isinstance(payload_type, type) and issubclass(payload_type, BaseModel):
109
- payload = payload_type.model_validate(event.payload)
110
- else:
111
- payload = event.payload
112
-
113
- node_instance = node_cls()
114
- result = await node_instance.run(ctx, payload)
115
- duration = (asyncio.get_event_loop().time() - start_time) * 1000
116
-
117
- if isinstance(result, Parallel):
118
- await self._handle_parallel_start(event.agent_id, ctx, result, duration)
119
- elif isinstance(result, Schedule):
120
- await self._handle_schedule(event.agent_id, current_node_name, ctx, result, duration)
121
- elif result:
122
- await self._handle_transition(event.agent_id, current_node_name, result, ctx, duration)
123
- else:
124
- await self._handle_terminal(event.agent_id, current_node_name, ctx, duration)
125
-
126
- locked = False
127
-
128
- except Exception as e:
129
- await self.observer.on_error(event.agent_id, current_node_name, e)
130
- raise
131
- finally:
132
- if locked:
133
- await self.db.unlock_agent(event.agent_id)
134
-
135
- # Remaining _handle_* methods use self._registry and self.validate_graph
136
- async def _handle_transition(self, agent_id: str, from_node: str, next_node_cls: Type[Node], ctx: BaseModel, duration: float):
137
- next_name = next_node_cls.get_node_name()
138
- if next_name not in self._registry:
139
- raise RuntimeError(f"Transition target '{next_name}' not in registry.")
140
-
141
- await self.observer.on_transition(agent_id, from_node, next_name, duration)
142
- next_event = Event(agent_id=agent_id, node_name=next_name)
143
- await self.db.save_state(agent_id, next_name, ctx.model_dump(), next_event)
144
- await self.bus.publish(next_event)
145
-
146
- async def _handle_parallel_start(self, parent_id: str, parent_ctx: BaseModel, parallel: Parallel, duration: float):
147
- join_name = parallel.join_node.get_node_name()
148
- await self.observer.on_transition(parent_id, "ParallelStart", join_name, duration)
149
- await self.db.create_task_group(parent_id=parent_id, join_node_name=join_name, task_count=len(parallel.branches))
150
-
151
- for task in parallel.branches:
152
- if not hasattr(parent_ctx, task.sub_context_path):
153
- raise RuntimeError(f"Context missing path: '{task.sub_context_path}'.")
154
-
155
- sub_ctx = getattr(parent_ctx, task.sub_context_path)
156
- sub_id = f"{parent_id}#{task.sub_context_path}"
157
- node_name = task.node_cls.get_node_name()
158
-
159
- evt = Event(
160
- agent_id=sub_id,
161
- node_name=node_name,
162
- payload=task.payload.model_dump() if hasattr(task.payload, "model_dump") else task.payload
163
- )
164
- await self.db.save_sub_state(sub_id, parent_id, node_name, sub_ctx.model_dump(), evt)
165
- await self.bus.publish(evt)
166
-
167
- await self.db.save_state(parent_id, "WAITING_FOR_JOIN", parent_ctx.model_dump(), None)
168
-
169
- async def _handle_terminal(self, agent_id: str, from_node: str, ctx: BaseModel, duration: float):
170
- await self.observer.on_transition(agent_id, from_node, "TERMINAL", duration)
171
- if "#" in agent_id:
172
- parent_id = agent_id.split("#")[0]
173
- await self.db.save_sub_state(agent_id, parent_id, "TERMINAL", ctx.model_dump(), None)
174
- join_node_name = await self.db.register_sub_task_completion(agent_id)
175
- if join_node_name:
176
- await self._trigger_recompose(parent_id, join_node_name)
177
- else:
178
- await self.db.save_state(agent_id, "TERMINAL", ctx.model_dump(), None)
179
-
180
- async def _handle_schedule(self, agent_id: str, from_node: str, ctx: BaseModel, schedule: Schedule, duration: float):
181
- target_name = schedule.node_cls.get_node_name()
182
- await self.observer.on_transition(agent_id, from_node, f"SCHEDULED:{target_name}", duration)
183
- run_at_dt = datetime.now(timezone.utc) + timedelta(seconds=schedule.delay_seconds)
184
-
185
- evt = Event(
186
- agent_id=agent_id,
187
- node_name=target_name,
188
- payload=schedule.payload.model_dump() if hasattr(schedule.payload, "model_dump") else schedule.payload,
189
- run_at=run_at_dt.isoformat(),
190
- idempotency_key=schedule.idempotency_key
191
- )
192
-
193
- scheduled = await self.db.schedule_event(evt)
194
- next_node = target_name if scheduled else "TERMINAL"
195
-
196
- if "#" in agent_id:
197
- parent_id = agent_id.split("#")[0]
198
- await self.db.save_sub_state(agent_id, parent_id, next_node, ctx.model_dump(), None)
199
- if not scheduled:
200
- join_node_name = await self.db.register_sub_task_completion(agent_id)
201
- if join_node_name:
202
- await self._trigger_recompose(parent_id, join_node_name)
203
- else:
204
- await self.db.save_state(agent_id, next_node, ctx.model_dump(), None)
205
-
206
- async def _trigger_recompose(self, parent_id: str, join_node_name: str):
207
- await self.db.load_and_lock_agent(parent_id)
208
- merged_ctx_dict = await self.db.recompose_parent(parent_id)
209
- join_event = Event(agent_id=parent_id, node_name=join_node_name)
210
- await self.db.save_state(parent_id, join_node_name, merged_ctx_dict, join_event)
211
- await self.bus.publish(join_event)
@@ -1,5 +0,0 @@
1
- from abc import ABC
2
-
3
- class Observer(ABC):
4
- async def on_transition(self, agent_id: str, from_node: str, to_node: str, duration_ms: float): pass
5
- async def on_error(self, agent_id: str, node: str, error: Exception): pass
@@ -1,41 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- from typing import Dict, Optional, Tuple, List
3
- from ..core.models import Event
4
-
5
- class Persistence(ABC):
6
- @abstractmethod
7
- async def load_and_lock_agent(self, agent_id: str) -> Tuple[Optional[str], Optional[Dict]]:
8
- pass
9
-
10
- @abstractmethod
11
- async def unlock_agent(self, agent_id: str):
12
- """Releases the row lock. Called automatically by the engine if an exception occurs."""
13
- pass
14
-
15
- @abstractmethod
16
- async def save_state(self, agent_id: str, node_name: str, context: Dict, event: Optional[Event]):
17
- pass
18
-
19
- @abstractmethod
20
- async def save_sub_state(self, sub_id: str, parent_id: str, node_name: str, ctx: dict, evt: Optional[Event]):
21
- pass
22
-
23
- @abstractmethod
24
- async def create_task_group(self, parent_id: str, join_node_name: str, task_count: int):
25
- pass
26
-
27
- @abstractmethod
28
- async def register_sub_task_completion(self, sub_id: str) -> Optional[str]:
29
- pass
30
-
31
- @abstractmethod
32
- async def recompose_parent(self, parent_id: str) -> dict:
33
- pass
34
-
35
- @abstractmethod
36
- async def schedule_event(self, event: Event) -> bool:
37
- pass
38
-
39
- @abstractmethod
40
- async def pop_due_events(self) -> List[Event]:
41
- pass