commandnet 0.3.0__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: commandnet
3
- Version: 0.3.0
3
+ Version: 0.4.1
4
4
  Summary: A lightweight, Pydantic-powered, distributed event-driven state machine and typed node graph runtime.
5
5
  Author: Christopher Vaz
6
6
  Author-email: christophervaz160@gmail.com
@@ -53,7 +53,7 @@ pip install commandnet
53
53
  ## 🛠️ Quick Start
54
54
 
55
55
  ### 1. Define Your Context
56
- The "Context" is the persistent state of your agent, defined using Pydantic.
56
+ The "Context" is the persistent state of your subject, defined using Pydantic.
57
57
 
58
58
  ```python
59
59
  from pydantic import BaseModel
@@ -139,7 +139,7 @@ engine = Engine(persistence=db, event_bus=bus)
139
139
  await engine.start_worker()
140
140
 
141
141
  # Trigger an execution
142
- await engine.trigger_agent("agent-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
142
+ await engine.trigger_subject("subject-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
143
143
  ```
144
144
 
145
145
  ---
@@ -172,3 +172,4 @@ print(dag) # {'CheckRisk': ['ProcessPayment'], 'ProcessPayment': []}
172
172
 
173
173
  MIT
174
174
 
175
+
@@ -30,7 +30,7 @@ pip install commandnet
30
30
  ## 🛠️ Quick Start
31
31
 
32
32
  ### 1. Define Your Context
33
- The "Context" is the persistent state of your agent, defined using Pydantic.
33
+ The "Context" is the persistent state of your subject, defined using Pydantic.
34
34
 
35
35
  ```python
36
36
  from pydantic import BaseModel
@@ -116,7 +116,7 @@ engine = Engine(persistence=db, event_bus=bus)
116
116
  await engine.start_worker()
117
117
 
118
118
  # Trigger an execution
119
- await engine.trigger_agent("agent-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
119
+ await engine.trigger_subject("subject-123", CheckRisk, WorkflowCtx(user_id="user_abc"))
120
120
  ```
121
121
 
122
122
  ---
@@ -148,3 +148,4 @@ print(dag) # {'CheckRisk': ['ProcessPayment'], 'ProcessPayment': []}
148
148
  ## 📄 License
149
149
 
150
150
  MIT
151
+
@@ -1,5 +1,5 @@
1
1
  from .core.models import Event
2
- from .core.node import Node, Parallel, ParallelTask, Schedule, Wait
2
+ from .core.node import Node, Parallel, ParallelTask, Schedule, Wait, Call, Interrupt
3
3
  from .core.graph import GraphAnalyzer
4
4
  from .interfaces.persistence import Persistence
5
5
  from .interfaces.event_bus import EventBus
@@ -9,4 +9,6 @@ from .engine.runtime import Engine
9
9
  __all__ = [
10
10
  "Event", "Node", "Parallel", "ParallelTask", "GraphAnalyzer",
11
11
  "Persistence", "EventBus", "Observer", "Engine", "Schedule",
12
+ "Call", "Interrupt",
12
13
  ]
14
+
@@ -96,3 +96,4 @@ class GraphAnalyzer:
96
96
  f"({source_ctx.__name__} -> {target_ctx.__name__})"
97
97
  )
98
98
  return True
99
+
@@ -8,10 +8,12 @@ def utcnow_iso() -> str:
8
8
 
9
9
  class Event(BaseModel):
10
10
  event_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
11
- agent_id: str
11
+ subject_id: str
12
12
  node_name: str
13
13
  payload: Optional[Dict[str, Any]] = None
14
+ headers: Dict[str, str] = Field(default_factory=dict)
14
15
 
15
16
  timestamp: str = Field(default_factory=utcnow_iso)
16
17
  run_at: str = Field(default_factory=utcnow_iso)
17
18
  idempotency_key: Optional[str] = None
19
+
@@ -5,7 +5,7 @@ C = TypeVar('C', bound=BaseModel) # Context
5
5
  P = TypeVar('P', bound=BaseModel) # Payload
6
6
 
7
7
  # The Recursive Type Definition
8
- Target = Union[Type['Node'], 'Parallel', 'Schedule', 'Wait', None]
8
+ Target = Union[Type['Node'], 'Parallel', 'Schedule', 'Wait', 'Call', 'Interrupt', None]
9
9
 
10
10
  class ParallelTask(BaseModel):
11
11
  model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -31,6 +31,20 @@ class Wait(BaseModel):
31
31
  resume_action: Target
32
32
  sub_context_path: Optional[str] = None
33
33
 
34
+ class Call(BaseModel):
35
+ """The 'Await' type: Deduplicates execution based on an idempotency key."""
36
+ model_config = ConfigDict(arbitrary_types_allowed=True)
37
+ node_cls: Type['Node']
38
+ idempotency_key: str
39
+ payload: Optional[Any] = None
40
+ # After the called node finishes, the caller resumes at this node
41
+ resume_action: Target = None
42
+
43
+ class Interrupt(BaseModel):
44
+ """Signals the engine to stop execution of the target subject."""
45
+ subject_id: str
46
+ hard: bool = True # True = kill task, False = flag for soft exit
47
+
34
48
  TransitionResult = Target
35
49
 
36
50
  class Node(Generic[C, P]):
@@ -41,3 +55,11 @@ class Node(Generic[C, P]):
41
55
  async def run(self, ctx: C, payload: Optional[P] = None) -> TransitionResult:
42
56
  """Executes node logic. Returns a Target (Node class, Parallel, Schedule, Wait, or None)."""
43
57
  pass
58
+
59
+ async def on_signal(self, ctx: C, signal_id: str, payload: Any) -> Target:
60
+ """
61
+ CALLBACK: Called when a 'Wait' state is resolved or an external signal
62
+ is sent to this node specifically.
63
+ """
64
+ return None
65
+
@@ -0,0 +1,407 @@
1
+ import asyncio
2
+ import logging
3
+ from datetime import datetime, timezone, timedelta
4
+ from typing import Optional, Type, Iterable, Dict, Any, Union, List
5
+
6
+ from ..core.models import Event
7
+ from ..core.node import (
8
+ Node,
9
+ Parallel,
10
+ Schedule,
11
+ Wait,
12
+ Target,
13
+ ParallelTask,
14
+ Call,
15
+ Interrupt,
16
+ )
17
+ from ..core.graph import GraphAnalyzer
18
+ from ..interfaces.persistence import Persistence
19
+ from ..interfaces.event_bus import EventBus
20
+ from ..interfaces.observer import Observer
21
+
22
+
23
+ class Engine:
24
+ def __init__(
25
+ self,
26
+ persistence: Persistence,
27
+ event_bus: EventBus,
28
+ nodes: Iterable[Type[Node]],
29
+ observer: Optional[Observer] = None,
30
+ ):
31
+ self.db = persistence
32
+ self.bus = event_bus
33
+ self.observer = observer or Observer()
34
+ self.logger = logging.getLogger("CommandNet")
35
+ self._scheduler_task: Optional[asyncio.Task] = None
36
+ self._active_tasks: Dict[str, asyncio.Task] = {}
37
+ self._registry: Dict[str, Type[Node]] = {n.get_node_name(): n for n in nodes}
38
+
39
+ def _dump_ctx(self, ctx: Any) -> dict:
40
+ return ctx.model_dump() if hasattr(ctx, "model_dump") else ctx
41
+
42
+ def _get_path(self, obj: Any, path: str) -> Any:
43
+ if isinstance(obj, dict):
44
+ return obj.get(path)
45
+ return getattr(obj, path)
46
+
47
+ # --- CORE WORKER LOGIC ---
48
+
49
+ async def process_event(self, event: Event):
50
+ # 1. Internal Control (Cross-worker Hard Cancel)
51
+ if event.node_name == "__CONTROL__":
52
+ action = (event.payload or {}).get("action")
53
+ if action == "HARD_CANCEL" and event.subject_id in self._active_tasks:
54
+ self._active_tasks[event.subject_id].cancel()
55
+ return
56
+
57
+ # 2. Check cancellation status
58
+ if await self.db.is_cancelled(event.subject_id):
59
+ return
60
+
61
+ # 3. Task Tracking for Local Hard Cancel
62
+ task = asyncio.create_task(self._run_node_logic(event))
63
+ self._active_tasks[event.subject_id] = task
64
+ try:
65
+ await task
66
+ except asyncio.CancelledError:
67
+ self.logger.warning(f"Task {event.subject_id} hard-cancelled.")
68
+ finally:
69
+ self._active_tasks.pop(event.subject_id, None)
70
+
71
+ async def _run_node_logic(self, event: Event):
72
+ subject_id = event.subject_id
73
+ # We allow "AWAITING_CALL" because a subject wakes up from that state when a 'Call' resolves
74
+ node_name, ctx_dict = await self.db.lock_and_load(subject_id)
75
+ if not node_name or (
76
+ node_name != event.node_name and node_name != "AWAITING_CALL"
77
+ ):
78
+ if node_name:
79
+ await self.db.unlock_subject(subject_id)
80
+ return
81
+
82
+ try:
83
+ node_cls = self._registry.get(event.node_name)
84
+ if not node_cls:
85
+ raise RuntimeError(f"Node '{event.node_name}' not found.")
86
+
87
+ ctx_type = GraphAnalyzer.get_context_type(node_cls)
88
+ payload_type = GraphAnalyzer.get_payload_type(node_cls)
89
+
90
+ ctx = (
91
+ ctx_type.model_validate(ctx_dict)
92
+ if hasattr(ctx_type, "model_validate")
93
+ else ctx_dict
94
+ )
95
+ payload = (
96
+ payload_type.model_validate(event.payload)
97
+ if (event.payload and hasattr(payload_type, "model_validate"))
98
+ else event.payload
99
+ )
100
+
101
+ # Support Soft-Cancel checking inside Node.run
102
+ if hasattr(ctx, "is_cancelled"):
103
+ ctx.is_cancelled = await self.db.is_cancelled(subject_id)
104
+
105
+ start_t = asyncio.get_event_loop().time()
106
+ result = await node_cls().run(ctx, payload)
107
+
108
+ await self._apply_target(
109
+ subject_id,
110
+ ctx,
111
+ result,
112
+ (asyncio.get_event_loop().time() - start_t) * 1000,
113
+ )
114
+ except Exception as e:
115
+ await self.observer.on_error(subject_id, event.node_name, e)
116
+ raise
117
+ finally:
118
+ await self.db.unlock_subject(subject_id)
119
+
120
+ # --- RECURSIVE TARGET RESOLVER ---
121
+
122
+ async def _apply_target(
123
+ self,
124
+ subject_id: str,
125
+ context: Any,
126
+ target: Target,
127
+ duration: float = 0.0,
128
+ payload: Any = None,
129
+ ):
130
+ # 1. Interrupt (Cancellation)
131
+ if isinstance(target, Interrupt):
132
+ await self.cancel_subject(target.subject_id, target.hard)
133
+ return
134
+
135
+ # 2. Call (Idempotent Await)
136
+ if isinstance(target, Call):
137
+ is_leader = await self.db.add_call_waiter(
138
+ target.idempotency_key,
139
+ subject_id,
140
+ target.resume_action,
141
+ self._dump_ctx(context),
142
+ )
143
+ await self.observer.on_transition(
144
+ subject_id, "RUN", f"CALLING:{target.node_cls.get_node_name()}", duration
145
+ )
146
+ await self.db.save_state(
147
+ subject_id, "AWAITING_CALL", self._dump_ctx(context), None
148
+ )
149
+ if is_leader:
150
+ await self.trigger_subject(
151
+ f"call#{target.idempotency_key}",
152
+ target.node_cls,
153
+ context,
154
+ target.payload,
155
+ )
156
+ return
157
+
158
+ # 3. Terminal State
159
+ if target is None:
160
+ await self.observer.on_transition(subject_id, "RUN", "TERMINAL", duration)
161
+ ctx_dict = self._dump_ctx(context)
162
+
163
+ # Resolve Callers if this was a shared Virtual subject
164
+ if subject_id.startswith("call#"):
165
+ key = subject_id.split("#")[1]
166
+ waiters = await self.db.resolve_call_group(key)
167
+ for w in waiters:
168
+ # Resume waiters using the final context of the shared node as payload
169
+ await self._apply_target(
170
+ w["subject_id"],
171
+ w["context"],
172
+ w["resume_target"],
173
+ payload=ctx_dict,
174
+ )
175
+
176
+ # Parallel Sub-task Completion Logic
177
+ if "#" in subject_id and not subject_id.startswith("call#"):
178
+ parent_id = subject_id.split("#")[0]
179
+ await self.db.save_sub_state(
180
+ subject_id, parent_id, "TERMINAL", ctx_dict, None
181
+ )
182
+ join_node_name = await self.db.register_sub_task_completion(subject_id)
183
+ if join_node_name:
184
+ await self._trigger_recompose(parent_id, join_node_name)
185
+ else:
186
+ await self.db.save_state(subject_id, "TERMINAL", ctx_dict, None)
187
+ return
188
+
189
+ # 4. Standard Node Transition
190
+ if isinstance(target, type) and issubclass(target, Node):
191
+ node_name = target.get_node_name()
192
+ await self.observer.on_transition(subject_id, "RUN", node_name, duration)
193
+
194
+ p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload
195
+
196
+ headers = {}
197
+ if hasattr(context, "trace_headers") and context.trace_headers:
198
+ headers = context.trace_headers
199
+
200
+ evt = Event(subject_id=subject_id, node_name=node_name, payload=p_load, headers=headers)
201
+ ctx_dict = self._dump_ctx(context)
202
+
203
+ if "#" in subject_id:
204
+ await self.db.save_sub_state(
205
+ subject_id, subject_id.split("#")[0], node_name, ctx_dict, evt
206
+ )
207
+ else:
208
+ await self.db.save_state(subject_id, node_name, ctx_dict, evt)
209
+
210
+ await self.bus.publish(evt)
211
+ return
212
+
213
+ # 5. Wait / Signal Parking
214
+ if isinstance(target, Wait):
215
+ actual_id = subject_id
216
+ actual_ctx = context
217
+ if target.sub_context_path and "#" not in subject_id:
218
+ actual_id = f"{subject_id}#{target.sub_context_path}"
219
+ actual_ctx = self._get_path(context, target.sub_context_path)
220
+
221
+ await self.observer.on_transition(
222
+ actual_id, "RUN", f"WAIT:{target.signal_id}", duration
223
+ )
224
+ await self.db.park_subject(
225
+ actual_id,
226
+ target.signal_id,
227
+ target.resume_action,
228
+ self._dump_ctx(actual_ctx),
229
+ )
230
+ return
231
+
232
+ # 6. Parallel Fan-out
233
+ if isinstance(target, Parallel):
234
+ join_name = target.join_node.get_node_name() if target.join_node else "FORK"
235
+ await self.observer.on_transition(
236
+ subject_id, "RUN", f"PARALLEL:{join_name}", duration
237
+ )
238
+
239
+ if target.join_node:
240
+ await self.db.create_task_group(
241
+ subject_id, join_name, len(target.branches)
242
+ )
243
+
244
+ for branch in target.branches:
245
+ # Normalize branch to ParallelTask without stripping Wait wrappers
246
+ if isinstance(branch, ParallelTask):
247
+ task = branch
248
+ else:
249
+ # If it's a Wait or a Node class, wrap it but keep the action intact
250
+ path = (
251
+ branch.sub_context_path
252
+ if isinstance(branch, Wait) and branch.sub_context_path
253
+ else "default"
254
+ )
255
+ task = ParallelTask(action=branch, sub_context_path=path)
256
+
257
+ sub_ctx = self._get_path(context, task.sub_context_path)
258
+ await self._apply_target(
259
+ f"{subject_id}#{task.sub_context_path}",
260
+ sub_ctx,
261
+ task.action,
262
+ payload=task.payload,
263
+ )
264
+
265
+ if target.join_node:
266
+ await self.db.save_state(
267
+ subject_id, "WAITING_FOR_JOIN", self._dump_ctx(context), None
268
+ )
269
+ else:
270
+ await self._apply_target(subject_id, context, None)
271
+ return
272
+
273
+ # 7. Schedule (Delayed Execution)
274
+ if isinstance(target, Schedule):
275
+ if not (
276
+ isinstance(target.action, type) and issubclass(target.action, Node)
277
+ ):
278
+ raise TypeError("Schedule.action must be a Node class.")
279
+
280
+ target_node_name = target.action.get_node_name()
281
+ await self.observer.on_transition(
282
+ subject_id, "RUN", f"SCHEDULED:{target_node_name}", duration
283
+ )
284
+
285
+ run_at_dt = datetime.now(timezone.utc) + timedelta(
286
+ seconds=target.delay_seconds
287
+ )
288
+ p_load = (
289
+ target.payload.model_dump()
290
+ if hasattr(target.payload, "model_dump")
291
+ else target.payload
292
+ )
293
+
294
+ scheduled_evt = Event(
295
+ subject_id=subject_id,
296
+ node_name=target_node_name,
297
+ payload=p_load,
298
+ run_at=run_at_dt.isoformat(),
299
+ idempotency_key=target.idempotency_key,
300
+ )
301
+
302
+ await self.db.schedule_event(scheduled_evt)
303
+ await self.db.save_state(
304
+ subject_id, target_node_name, self._dump_ctx(context), None
305
+ )
306
+ return
307
+
308
+ # --- EXTERNAL CONTROL & SIGNALS ---
309
+
310
+ async def signal_node(self, subject_id: str, signal_id: str, payload: Any = None):
311
+ """Resumes the subject and triggers its specific on_signal instance method."""
312
+ node_name, ctx_dict = await self.db.lock_and_load(subject_id)
313
+ if not node_name:
314
+ return
315
+ try:
316
+ node_cls = self._registry.get(node_name)
317
+ node_inst = node_cls()
318
+ ctx_type = GraphAnalyzer.get_context_type(node_cls)
319
+ ctx = (
320
+ ctx_type.model_validate(ctx_dict)
321
+ if hasattr(ctx_type, "model_validate")
322
+ else ctx_dict
323
+ )
324
+
325
+ result = await node_inst.on_signal(ctx, signal_id, payload)
326
+ await self._apply_target(subject_id, ctx, result)
327
+ finally:
328
+ await self.db.unlock_subject(subject_id)
329
+
330
+ async def release_signal(self, signal_id: str, payload: Any = None):
331
+ """Standard mass-resume for parked subjects."""
332
+ waiters = await self.db.get_and_clear_waiters(signal_id)
333
+ for waiter in waiters:
334
+ subject_id = waiter["subject_id"]
335
+ await self.db.lock_and_load(subject_id)
336
+ try:
337
+ await self._apply_target(
338
+ subject_id=subject_id,
339
+ context=waiter["context"],
340
+ target=waiter["next_target"],
341
+ payload=payload,
342
+ )
343
+ finally:
344
+ await self.db.unlock_subject(subject_id)
345
+
346
+ async def cancel_subject(self, subject_id: str, hard: bool = True):
347
+ await self.db.set_cancel_flag(subject_id, hard)
348
+ if hard:
349
+ if subject_id in self._active_tasks:
350
+ self._active_tasks[subject_id].cancel()
351
+ await self.bus.publish(
352
+ Event(
353
+ subject_id=subject_id,
354
+ node_name="__CONTROL__",
355
+ payload={"action": "HARD_CANCEL"},
356
+ )
357
+ )
358
+
359
+ # --- LIFECYCLE & HELPERS ---
360
+
361
+ async def _trigger_recompose(self, parent_id: str, join_node_name: str):
362
+ await self.db.lock_and_load(parent_id)
363
+ merged_ctx_dict = await self.db.recompose_parent(parent_id)
364
+ await self._apply_target(
365
+ parent_id, merged_ctx_dict, self._registry[join_node_name]
366
+ )
367
+
368
+ async def start_worker(self, poll_interval: float = 1.0):
369
+ await self.bus.subscribe(self.process_event)
370
+ self._scheduler_task = asyncio.create_task(self._scheduler_loop(poll_interval))
371
+
372
+ async def _scheduler_loop(self, poll_interval: float):
373
+ while True:
374
+ try:
375
+ due = await self.db.pop_due_events()
376
+ for evt in due:
377
+ await self.bus.publish(evt)
378
+ except asyncio.CancelledError:
379
+ break
380
+ except Exception as e:
381
+ self.logger.error(f"Scheduler: {e}")
382
+ await asyncio.sleep(poll_interval)
383
+
384
+ async def trigger_subject(
385
+ self,
386
+ subject_id: str,
387
+ start_node: Type[Node],
388
+ initial_context: Any,
389
+ payload: Any = None,
390
+ ):
391
+ node_name = start_node.get_node_name()
392
+ p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload
393
+ evt = Event(subject_id=subject_id, node_name=node_name, payload=p_load)
394
+ await self.db.save_state(
395
+ subject_id, node_name, self._dump_ctx(initial_context), evt
396
+ )
397
+ await self.bus.publish(evt)
398
+
399
+ async def stop(self):
400
+ if self._scheduler_task:
401
+ self._scheduler_task.cancel()
402
+ for task in self._active_tasks.values():
403
+ task.cancel()
404
+
405
+ def validate_graph(self, start_node: Type[Node]):
406
+ return GraphAnalyzer.validate(start_node, self._registry)
407
+
@@ -10,3 +10,4 @@ class EventBus(ABC):
10
10
  @abstractmethod
11
11
  async def subscribe(self, handler: Callable[[Event], Coroutine]):
12
12
  pass
13
+
@@ -0,0 +1,6 @@
1
+ from abc import ABC
2
+
3
+ class Observer(ABC):
4
+ async def on_transition(self, subject_id: str, from_node: str, to_node: str, duration_ms: float): pass
5
+ async def on_error(self, subject_id: str, node: str, error: Exception): pass
6
+
@@ -4,16 +4,15 @@ from ..core.models import Event
4
4
 
5
5
  class Persistence(ABC):
6
6
  @abstractmethod
7
- async def load_and_lock_agent(self, agent_id: str) -> Tuple[Optional[str], Optional[Dict]]:
7
+ async def lock_and_load(self, subject_id: str) -> Tuple[Optional[str], Optional[Dict]]:
8
8
  pass
9
9
 
10
10
  @abstractmethod
11
- async def unlock_agent(self, agent_id: str):
12
- """Releases the row lock. Called automatically by the engine if an exception occurs."""
11
+ async def unlock_subject(self, subject_id: str):
13
12
  pass
14
13
 
15
14
  @abstractmethod
16
- async def save_state(self, agent_id: str, node_name: str, context: Dict, event: Optional[Event]):
15
+ async def save_state(self, subject_id: str, node_name: str, context: Dict, event: Optional[Event]):
17
16
  pass
18
17
 
19
18
  @abstractmethod
@@ -41,9 +40,27 @@ class Persistence(ABC):
41
40
  pass
42
41
 
43
42
  @abstractmethod
44
- async def park_agent(self, agent_id: str, signal_id: str, next_target: Any, context: Dict):
43
+ async def park_subject(self, subject_id: str, signal_id: str, next_target: Any, context: Dict):
45
44
  pass
46
45
 
47
46
  @abstractmethod
48
47
  async def get_and_clear_waiters(self, signal_id: str) -> List[Dict]:
49
48
  pass
49
+
50
+ @abstractmethod
51
+ async def set_cancel_flag(self, subject_id: str, hard: bool):
52
+ pass
53
+
54
+ @abstractmethod
55
+ async def is_cancelled(self, subject_id: str) -> bool:
56
+ pass
57
+
58
+ @abstractmethod
59
+ async def add_call_waiter(self, key: str, subject_id: str, resume_target: Any, context: dict) -> bool:
60
+ """Returns True if this is the first waiter for this key (the 'leader')."""
61
+ pass
62
+
63
+ @abstractmethod
64
+ async def resolve_call_group(self, key: str) -> List[Dict[str, Any]]:
65
+ pass
66
+
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "commandnet"
3
- version = "0.3.0"
3
+ version = "0.4.1"
4
4
  description = "A lightweight, Pydantic-powered, distributed event-driven state machine and typed node graph runtime."
5
5
  authors = [
6
6
  { name = "Christopher Vaz", email = "christophervaz160@gmail.com" }
@@ -41,3 +41,4 @@ build-backend = "poetry.core.masonry.api"
41
41
  asyncio_mode = "strict"
42
42
  asyncio_default_test_loop_scope = "function"
43
43
  asyncio_default_fixture_loop_scope = "function"
44
+
@@ -1,292 +0,0 @@
1
- import asyncio
2
- import logging
3
- from datetime import datetime, timezone, timedelta
4
- from typing import Optional, Type, Iterable, Dict, Any, Union
5
- from pydantic import BaseModel
6
-
7
- from ..core.models import Event
8
- from ..core.node import Node, Parallel, Schedule, Wait, Target, ParallelTask
9
- from ..core.graph import GraphAnalyzer
10
- from ..interfaces.persistence import Persistence
11
- from ..interfaces.event_bus import EventBus
12
- from ..interfaces.observer import Observer
13
-
14
-
15
- class Engine:
16
- def __init__(
17
- self,
18
- persistence: Persistence,
19
- event_bus: EventBus,
20
- nodes: Iterable[Type[Node]],
21
- observer: Optional[Observer] = None,
22
- ):
23
- self.db = persistence
24
- self.bus = event_bus
25
- self.observer = observer or Observer()
26
- self.logger = logging.getLogger("CommandNet")
27
- self._scheduler_task: Optional[asyncio.Task] = None
28
-
29
- self._registry: Dict[str, Type[Node]] = {n.get_node_name(): n for n in nodes}
30
-
31
- # --- HELPER UTILITIES ---
32
-
33
- def _dump_ctx(self, ctx: Any) -> dict:
34
- """Helper to safely dump context whether it's a model or a dict."""
35
- if hasattr(ctx, "model_dump"):
36
- return ctx.model_dump()
37
- return ctx
38
-
39
- def _get_path(self, obj: Any, path: str) -> Any:
40
- """Helper to get a value from an object or a dict."""
41
- if isinstance(obj, dict):
42
- return obj.get(path)
43
- return getattr(obj, path)
44
-
45
- # --- CORE RECURSIVE RESOLVER ---
46
-
47
- async def _apply_target(
48
- self,
49
- agent_id: str,
50
- context: Any,
51
- target: Target,
52
- duration: float = 0.0,
53
- payload: Any = None,
54
- ):
55
- # 1. Terminal State
56
- if target is None:
57
- await self.observer.on_transition(agent_id, "RUN", "TERMINAL", duration)
58
- ctx_dict = self._dump_ctx(context)
59
- if "#" in agent_id:
60
- parent_id = agent_id.split("#")[0]
61
- await self.db.save_sub_state(
62
- agent_id, parent_id, "TERMINAL", ctx_dict, None
63
- )
64
- join_node_name = await self.db.register_sub_task_completion(agent_id)
65
- if join_node_name:
66
- await self._trigger_recompose(parent_id, join_node_name)
67
- else:
68
- await self.db.save_state(agent_id, "TERMINAL", ctx_dict, None)
69
- return
70
-
71
- # 2. Node Class Transition
72
- if isinstance(target, type) and issubclass(target, Node):
73
- node_name = target.get_node_name()
74
- await self.observer.on_transition(agent_id, "RUN", node_name, duration)
75
-
76
- p_load = payload.model_dump() if hasattr(payload, "model_dump") else payload
77
- evt = Event(agent_id=agent_id, node_name=node_name, payload=p_load)
78
- ctx_dict = self._dump_ctx(context)
79
-
80
- if "#" in agent_id:
81
- await self.db.save_sub_state(
82
- agent_id, agent_id.split("#")[0], node_name, ctx_dict, evt
83
- )
84
- else:
85
- await self.db.save_state(agent_id, node_name, ctx_dict, evt)
86
-
87
- await self.bus.publish(evt)
88
- return
89
-
90
- # 3. Wait Directive
91
- if isinstance(target, Wait):
92
- actual_id = agent_id
93
- actual_ctx = context
94
- if target.sub_context_path and "#" not in agent_id:
95
- actual_id = f"{agent_id}#{target.sub_context_path}"
96
- actual_ctx = self._get_path(context, target.sub_context_path)
97
-
98
- await self.observer.on_transition(
99
- actual_id, "RUN", f"WAIT:{target.signal_id}", duration
100
- )
101
- await self.db.park_agent(
102
- actual_id,
103
- target.signal_id,
104
- target.resume_action,
105
- self._dump_ctx(actual_ctx),
106
- )
107
- return
108
-
109
- # 4. Parallel Directive
110
- if isinstance(target, Parallel):
111
- join_name = target.join_node.get_node_name() if target.join_node else "FORK"
112
- await self.observer.on_transition(
113
- agent_id, "RUN", f"PARALLEL:{join_name}", duration
114
- )
115
-
116
- if target.join_node:
117
- await self.db.create_task_group(
118
- agent_id, join_name, len(target.branches)
119
- )
120
-
121
- for branch in target.branches:
122
- task = (
123
- branch
124
- if isinstance(branch, ParallelTask)
125
- else ParallelTask(
126
- action=branch,
127
- sub_context_path=branch.sub_context_path, # type: ignore
128
- )
129
- )
130
- sub_ctx = self._get_path(context, task.sub_context_path)
131
- await self._apply_target(
132
- f"{agent_id}#{task.sub_context_path}",
133
- sub_ctx,
134
- task.action,
135
- duration=0,
136
- payload=task.payload,
137
- )
138
-
139
- if target.join_node:
140
- await self.db.save_state(
141
- agent_id, "WAITING_FOR_JOIN", self._dump_ctx(context), None
142
- )
143
- else:
144
- await self._apply_target(agent_id, context, None)
145
- return
146
-
147
- # 5. Schedule Directive
148
- if isinstance(target, Schedule):
149
- if isinstance(target.action, type) and issubclass(target.action, Node):
150
- target_node_name = target.action.get_node_name()
151
- else:
152
- raise TypeError("Schedule.action must be a Node class.")
153
-
154
- await self.observer.on_transition(
155
- agent_id, "RUN", f"SCHEDULED:{target_node_name}", duration
156
- )
157
- run_at_dt = datetime.now(timezone.utc) + timedelta(
158
- seconds=target.delay_seconds
159
- )
160
- p_load = (
161
- target.payload.model_dump()
162
- if hasattr(target.payload, "model_dump")
163
- else target.payload
164
- )
165
-
166
- scheduled_evt = Event(
167
- agent_id=agent_id,
168
- node_name=target_node_name,
169
- payload=p_load,
170
- run_at=run_at_dt.isoformat(),
171
- idempotency_key=target.idempotency_key,
172
- )
173
-
174
- await self.db.schedule_event(scheduled_evt)
175
- await self.db.save_state(
176
- agent_id, target_node_name, self._dump_ctx(context), None
177
- )
178
- return
179
-
180
- # --- WORKER & EXTERNAL TRIGGERS ---
181
-
182
- async def process_event(self, event: Event):
183
- start_time = asyncio.get_event_loop().time()
184
- agent_id = event.agent_id
185
- current_node_name, ctx_dict = await self.db.load_and_lock_agent(agent_id)
186
-
187
- if not current_node_name:
188
- return
189
-
190
- try:
191
- if current_node_name != event.node_name:
192
- return
193
-
194
- node_cls = self._registry.get(current_node_name)
195
- if not node_cls:
196
- raise RuntimeError(f"Node '{current_node_name}' not found.")
197
-
198
- ctx_type = GraphAnalyzer.get_context_type(node_cls)
199
- payload_type = GraphAnalyzer.get_payload_type(node_cls)
200
-
201
- ctx = (
202
- ctx_type.model_validate(ctx_dict)
203
- if issubclass(ctx_type, BaseModel)
204
- else ctx_dict
205
- )
206
- payload = (
207
- payload_type.model_validate(event.payload)
208
- if (event.payload and issubclass(payload_type, BaseModel))
209
- else event.payload
210
- )
211
-
212
- result = await node_cls().run(ctx, payload)
213
- await self._apply_target(
214
- agent_id,
215
- ctx,
216
- result,
217
- (asyncio.get_event_loop().time() - start_time) * 1000,
218
- )
219
-
220
- except Exception as e:
221
- await self.observer.on_error(agent_id, current_node_name, e)
222
- raise
223
- finally:
224
- await self.db.unlock_agent(agent_id)
225
-
226
- async def release_signal(self, signal_id: str, payload: Any = None):
227
- waiters = await self.db.get_and_clear_waiters(signal_id)
228
- for waiter in waiters:
229
- agent_id = waiter["agent_id"]
230
- await self.db.load_and_lock_agent(agent_id)
231
- try:
232
- await self._apply_target(
233
- agent_id=agent_id,
234
- context=waiter["context"],
235
- target=waiter["next_target"],
236
- payload=payload,
237
- )
238
- finally:
239
- await self.db.unlock_agent(agent_id)
240
-
241
- # --- LIFECYCLE METHODS ---
242
-
243
- async def start_worker(self, poll_interval: float = 1.0):
244
- await self.bus.subscribe(self.process_event)
245
- self._scheduler_task = asyncio.create_task(self._scheduler_loop(poll_interval))
246
- self.logger.info("Worker started.")
247
-
248
- async def _scheduler_loop(self, poll_interval: float):
249
- while True:
250
- try:
251
- due = await self.db.pop_due_events()
252
- for evt in due:
253
- await self.bus.publish(evt)
254
- except asyncio.CancelledError:
255
- break
256
- except Exception as e:
257
- self.logger.error(f"Scheduler Error: {e}")
258
- await asyncio.sleep(poll_interval)
259
-
260
- async def stop(self):
261
- if self._scheduler_task:
262
- self._scheduler_task.cancel()
263
- try:
264
- await self._scheduler_task
265
- except asyncio.CancelledError:
266
- pass
267
-
268
- async def trigger_agent(
269
- self,
270
- agent_id: str,
271
- start_node: Type[Node],
272
- initial_context: BaseModel,
273
- payload: Optional[BaseModel] = None,
274
- ):
275
- node_name = start_node.get_node_name()
276
- evt = Event(
277
- agent_id=agent_id,
278
- node_name=node_name,
279
- payload=payload.model_dump() if hasattr(payload, "model_dump") else payload,
280
- )
281
- await self.db.save_state(agent_id, node_name, initial_context.model_dump(), evt)
282
- await self.bus.publish(evt)
283
-
284
- async def _trigger_recompose(self, parent_id: str, join_node_name: str):
285
- await self.db.load_and_lock_agent(parent_id)
286
- merged_ctx_dict = await self.db.recompose_parent(parent_id)
287
- await self._apply_target(
288
- parent_id, merged_ctx_dict, self._registry[join_node_name]
289
- )
290
-
291
- def validate_graph(self, start_node: Type[Node]):
292
- return GraphAnalyzer.validate(start_node, self._registry)
@@ -1,5 +0,0 @@
1
- from abc import ABC
2
-
3
- class Observer(ABC):
4
- async def on_transition(self, agent_id: str, from_node: str, to_node: str, duration_ms: float): pass
5
- async def on_error(self, agent_id: str, node: str, error: Exception): pass