digitalkin 0.2.25rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/grpc_servers/_base_server.py +1 -1
  3. digitalkin/grpc_servers/module_server.py +26 -42
  4. digitalkin/grpc_servers/module_servicer.py +30 -24
  5. digitalkin/grpc_servers/utils/grpc_client_wrapper.py +3 -3
  6. digitalkin/grpc_servers/utils/models.py +1 -1
  7. digitalkin/logger.py +60 -23
  8. digitalkin/mixins/__init__.py +19 -0
  9. digitalkin/mixins/base_mixin.py +10 -0
  10. digitalkin/mixins/callback_mixin.py +24 -0
  11. digitalkin/mixins/chat_history_mixin.py +108 -0
  12. digitalkin/mixins/cost_mixin.py +76 -0
  13. digitalkin/mixins/file_history_mixin.py +99 -0
  14. digitalkin/mixins/filesystem_mixin.py +47 -0
  15. digitalkin/mixins/logger_mixin.py +59 -0
  16. digitalkin/mixins/storage_mixin.py +79 -0
  17. digitalkin/models/module/__init__.py +2 -0
  18. digitalkin/models/module/module.py +9 -1
  19. digitalkin/models/module/module_context.py +90 -6
  20. digitalkin/models/module/module_types.py +5 -5
  21. digitalkin/models/module/task_monitor.py +51 -0
  22. digitalkin/models/services/__init__.py +9 -0
  23. digitalkin/models/services/storage.py +39 -5
  24. digitalkin/modules/_base_module.py +105 -74
  25. digitalkin/modules/job_manager/base_job_manager.py +12 -8
  26. digitalkin/modules/job_manager/single_job_manager.py +84 -78
  27. digitalkin/modules/job_manager/surrealdb_repository.py +225 -0
  28. digitalkin/modules/job_manager/task_manager.py +391 -0
  29. digitalkin/modules/job_manager/task_session.py +276 -0
  30. digitalkin/modules/job_manager/taskiq_job_manager.py +2 -2
  31. digitalkin/modules/tool_module.py +10 -2
  32. digitalkin/modules/trigger_handler.py +7 -6
  33. digitalkin/services/cost/__init__.py +9 -2
  34. digitalkin/services/storage/grpc_storage.py +1 -1
  35. {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/METADATA +18 -18
  36. {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/RECORD +39 -26
  37. {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/WHEEL +0 -0
  38. {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/licenses/LICENSE +0 -0
  39. {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,391 @@
1
+ """Task manager with comprehensive lifecycle management."""
2
+
3
+ import asyncio
4
+ import contextlib
5
+ import datetime
6
+ from collections.abc import Coroutine
7
+ from typing import Any
8
+
9
+ from digitalkin.logger import logger
10
+ from digitalkin.models.module.task_monitor import SignalMessage, SignalType, TaskStatus
11
+ from digitalkin.modules._base_module import BaseModule
12
+ from digitalkin.modules.job_manager.task_session import SurrealDBConnection, TaskSession
13
+
14
+
15
class TaskManager:
    """Task manager with comprehensive lifecycle management.

    Tracks asyncio tasks and their :class:`TaskSession` state by task id,
    enforces a concurrency cap, and exposes signal-based control
    (pause / resume / cancel / status) plus graceful shutdown.
    """

    tasks: dict[str, asyncio.Task]
    tasks_sessions: dict[str, TaskSession]
    default_timeout: float
    max_concurrent_tasks: int
    _shutdown_event: asyncio.Event

    def __init__(self, default_timeout: float = 10.0, max_concurrent_tasks: int = 100) -> None:
        """Initialize an empty manager.

        Args:
            default_timeout: Seconds to wait for a task to end gracefully before forcing cancellation.
            max_concurrent_tasks: Hard cap on simultaneously managed tasks.
        """
        self.tasks = {}
        self.tasks_sessions = {}
        self.default_timeout = default_timeout
        self.max_concurrent_tasks = max_concurrent_tasks
        self._shutdown_event = asyncio.Event()

        logger.info(
            "TaskManager initialized with max_concurrent_tasks: %d, default_timeout: %.1f",
            max_concurrent_tasks,
            default_timeout,
            extra={"max_concurrent_tasks": max_concurrent_tasks, "default_timeout": default_timeout},
        )

    @property
    def task_count(self) -> int:
        """Number of tracked task sessions (running or not)."""
        return len(self.tasks_sessions)

    @property
    def running_tasks(self) -> set[str]:
        """Ids of tasks whose asyncio.Task has not finished yet."""
        return {task_id for task_id, task in self.tasks.items() if not task.done()}

    async def _cleanup_task(self, task_id: str) -> None:
        """Close the session's DB connection and drop the task from the registries."""
        logger.debug("Cleaning up resources for task: '%s'", task_id, extra={"task_id": task_id})
        if task_id in self.tasks_sessions:
            await self.tasks_sessions[task_id].db.close()
        # Remove from collections (fix: the removal was documented but never performed)
        self.tasks.pop(task_id, None)
        self.tasks_sessions.pop(task_id, None)

    async def _task_wrapper(  # noqa: C901, PLR0915
        self,
        task_id: str,
        coro: Coroutine[Any, Any, None],
        session: TaskSession,
    ) -> asyncio.Task[None]:
        """Task wrapper that runs main, heartbeat, and listener concurrently.

        The first to finish determines the outcome. Returns a Task that the
        caller can await externally.

        Returns:
            asyncio.Task[None]: The supervisor task managing the lifecycle.
        """

        async def signal_wrapper() -> None:
            # Publishes lifecycle START/STOP markers around the signal listener.
            # Fix: uses the session's own connection — `self.channel` was never
            # assigned on the manager and raised AttributeError at runtime.
            try:
                await session.db.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        status=session.status,
                        action=SignalType.START,
                    ).model_dump(),
                )
                await session.listen_signals()
            except asyncio.CancelledError:
                logger.debug("Signal listener cancelled", extra={"task_id": task_id})
            finally:
                await session.db.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        status=session.status,
                        action=SignalType.STOP,
                    ).model_dump(),
                )
                logger.info("Signal listener ended", extra={"task_id": task_id})

        async def heartbeat_wrapper() -> None:
            try:
                await session.generate_heartbeats()
            except asyncio.CancelledError:
                # Fix: message previously said "Signal listener cancelled" (copy-paste).
                logger.debug("Heartbeat task cancelled", extra={"task_id": task_id})
            finally:
                logger.info("Heartbeat task ended", extra={"task_id": task_id})

        async def supervisor() -> None:
            session.started_at = datetime.datetime.now(datetime.timezone.utc)
            session.status = TaskStatus.RUNNING

            main_task = asyncio.create_task(coro, name=f"{task_id}_main")
            hb_task = asyncio.create_task(heartbeat_wrapper(), name=f"{task_id}_heartbeat")
            sig_task = asyncio.create_task(signal_wrapper(), name=f"{task_id}_listener")

            try:
                done, pending = await asyncio.wait(
                    [main_task, sig_task, hb_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )

                # One task completed -> cancel the others
                for pending_task in pending:
                    pending_task.cancel()

                # Propagate exception/result from the finished task
                completed = next(iter(done))
                await completed

                # Demoted from logger.critical: purely diagnostic output.
                logger.debug(f"{completed=} | {main_task=} | {hb_task=} | {sig_task=}")

                if completed is main_task:
                    session.status = TaskStatus.COMPLETED
                elif completed is sig_task or (completed is hb_task and sig_task.done()):
                    # Listener finished first (or heartbeat raced it): treat as cancellation.
                    session.status = TaskStatus.CANCELLED
                elif completed is hb_task:
                    session.status = TaskStatus.FAILED
                    msg = f"Heartbeat stopped for {task_id}"
                    raise RuntimeError(msg)  # noqa: TRY301

            except asyncio.CancelledError:
                session.status = TaskStatus.CANCELLED
                raise
            except Exception:
                session.status = TaskStatus.FAILED
                raise
            finally:
                session.completed_at = datetime.datetime.now(datetime.timezone.utc)
                # Ensure all tasks are cleaned up
                for child in (main_task, hb_task, sig_task):
                    if not child.done():
                        child.cancel()
                await asyncio.gather(main_task, hb_task, sig_task, return_exceptions=True)

        # Return the supervisor task to be awaited outside
        return asyncio.create_task(supervisor(), name=f"{task_id}_supervisor")

    async def create_task(
        self,
        task_id: str,
        module: BaseModule,
        coro: Coroutine[Any, Any, None],
        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
        connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Create and start a new managed task.

        Args:
            task_id: Unique identifier for the task.
            module: The module whose lifecycle is bound to this task.
            coro: The main coroutine to run under supervision.
            heartbeat_interval: Minimum delay between heartbeat writes.
            connection_timeout: Timeout for the SurrealDB connection.

        Raises:
            ValueError: task_id duplicated
            RuntimeError: task overload
        """
        if task_id in self.tasks:
            # close Coroutine during runtime
            coro.close()
            logger.warning("Task creation failed - task already exists: '%s'", task_id, extra={"task_id": task_id})
            msg = f"Task {task_id} already exists"
            raise ValueError(msg)

        if len(self.tasks) >= self.max_concurrent_tasks:
            coro.close()
            logger.error(
                "Task creation failed - max concurrent tasks reached: %d",
                self.max_concurrent_tasks,
                extra={
                    "task_id": task_id,
                    "current_count": len(self.tasks),
                    "max_concurrent": self.max_concurrent_tasks,
                },
            )
            msg = f"Maximum concurrent tasks ({self.max_concurrent_tasks}) reached"
            raise RuntimeError(msg)

        logger.info(
            "Creating new task: '%s'",
            task_id,
            extra={
                "task_id": task_id,
                "heartbeat_interval": heartbeat_interval,
                "connection_timeout": connection_timeout,
            },
        )

        try:
            # Initialize components: one DB connection per task, owned by its session.
            channel: SurrealDBConnection = SurrealDBConnection("task_manager", connection_timeout)
            await channel.init_surreal_instance()
            session = TaskSession(task_id, channel, module, heartbeat_interval)

            self.tasks_sessions[task_id] = session

            # Create wrapper task
            self.tasks[task_id] = asyncio.create_task(self._task_wrapper(task_id, coro, session), name=task_id)

            logger.info(
                "Task created successfully: '%s'",
                task_id,
                extra={
                    "task_id": task_id,
                    "total_tasks": len(self.tasks),
                },
            )

        except Exception as e:
            logger.error(
                "Failed to create task: '%s'", task_id, extra={"task_id": task_id, "error": str(e)}, exc_info=True
            )
            # Cleanup on failure
            await self._cleanup_task(task_id)
            raise

    async def send_signal(self, task_id: str, signal_type: str, payload: dict) -> bool:
        """Send a control signal to a specific task.

        Returns:
            bool: True if the signal was sent, False if the task is unknown.
        """
        if task_id not in self.tasks_sessions:
            logger.warning(
                "Cannot send signal - task not found: '%s'",
                task_id,
                extra={"task_id": task_id, "signal_type": signal_type},
            )
            return False

        logger.info(
            "Sending signal '%s' to task: '%s'",
            signal_type,
            task_id,
            extra={"task_id": task_id, "signal_type": signal_type, "payload": payload},
        )

        # Fix: route through the task's own connection — `self.channel` was never
        # assigned on the manager and raised AttributeError at runtime.
        await self.tasks_sessions[task_id].db.update("tasks", signal_type, payload)
        return True

    async def cancel_task(self, task_id: str, timeout: float | None = None) -> bool:
        """Cancel a task with graceful shutdown and fallback.

        Returns:
            bool: True if the task was cancelled successfully, False otherwise.
        """
        if task_id not in self.tasks:
            logger.warning("Cannot cancel - task not found: '%s'", task_id, extra={"task_id": task_id})
            return True

        timeout = timeout or self.default_timeout
        task = self.tasks[task_id]

        logger.info(
            "Initiating task cancellation: '%s', timeout: %.1fs",
            task_id,
            timeout,
            extra={"task_id": task_id, "timeout": timeout},
        )

        try:
            # Phase 1: Cooperative cancellation
            # await self.send_signal(task_id, "cancel") # noqa: ERA001

            # Wait for graceful shutdown
            await asyncio.wait_for(task, timeout=timeout)

            logger.info("Task cancelled gracefully: '%s'", task_id, extra={"task_id": task_id})

        except asyncio.TimeoutError:
            logger.warning(
                "Graceful cancellation timed out for task: '%s', forcing cancellation",
                task_id,
                extra={"task_id": task_id, "timeout": timeout},
            )

            # Phase 2: Force cancellation
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

            logger.warning("Task force-cancelled: '%s'", task_id, extra={"task_id": task_id})
            return True

        except Exception as e:
            logger.error(
                "Error during task cancellation: '%s'",
                task_id,
                extra={"task_id": task_id, "error": str(e)},
                exc_info=True,
            )
            return False
        return True

    async def clean_session(self, task_id: str) -> bool:
        """Stop the task's module, cancel the task, and drop its session.

        Returns:
            bool: True if the task was cleaned successfully, False otherwise.
        """
        if task_id not in self.tasks_sessions:
            logger.warning("Cannot clean session - task not found: '%s'", task_id, extra={"task_id": task_id})
            return False

        await self.tasks_sessions[task_id].module.stop()
        await self.cancel_task(task_id)

        logger.info("Cleaning up session for task: '%s'", task_id, extra={"task_id": task_id})
        self.tasks_sessions.pop(task_id, None)
        return True

    async def pause_task(self, task_id: str) -> bool:
        """Pause a running task.

        Returns:
            bool: True if the task was paused successfully, False otherwise.
        """
        return await self.send_signal(task_id, "pause", {})

    async def resume_task(self, task_id: str) -> bool:
        """Resume a paused task.

        Returns:
            bool: True if the resume signal was sent successfully, False otherwise.
        """
        return await self.send_signal(task_id, "resume", {})

    async def get_task_status(self, task_id: str) -> bool:
        """Request a status report from a task.

        Returns:
            bool: True if the status request was sent successfully, False otherwise.
        """
        return await self.send_signal(task_id, "status", {})

    async def cancel_all_tasks(self, timeout: float | None = None) -> dict[str, bool]:
        """Cancel all running tasks.

        Returns:
            dict[str, bool]: Per-task-id flag, True if that task was cancelled successfully.
        """
        timeout = timeout or self.default_timeout
        task_ids = list(self.running_tasks)

        logger.info(
            "Cancelling all tasks: %d tasks", len(task_ids), extra={"task_count": len(task_ids), "timeout": timeout}
        )

        results = {}
        for task_id in task_ids:
            results[task_id] = await self.cancel_task(task_id, timeout)

        return results

    async def shutdown(self, timeout: float = 30.0) -> None:
        """Graceful shutdown of all tasks."""
        logger.info(
            "TaskManager shutdown initiated, timeout: %.1fs",
            timeout,
            extra={"timeout": timeout, "active_tasks": len(self.running_tasks)},
        )

        self._shutdown_event.set()
        results = await self.cancel_all_tasks(timeout)

        failed_tasks = [task_id for task_id, success in results.items() if not success]
        if failed_tasks:
            logger.error(
                "Failed to cancel %d tasks during shutdown: %s",
                len(failed_tasks),
                failed_tasks,
                extra={"failed_tasks": failed_tasks, "failed_count": len(failed_tasks)},
            )

        logger.info(
            "TaskManager shutdown completed, cancelled: %d, failed: %d",
            len(results) - len(failed_tasks),
            len(failed_tasks),
            extra={"cancelled_count": len(results) - len(failed_tasks), "failed_count": len(failed_tasks)},
        )
@@ -0,0 +1,276 @@
1
+ """."""
2
+
3
+ import asyncio
4
+ import datetime
5
+ from collections.abc import AsyncGenerator
6
+
7
+ from digitalkin.logger import logger
8
+ from digitalkin.models.module.task_monitor import HeartbeatMessage, SignalMessage, SignalType, TaskStatus
9
+ from digitalkin.modules._base_module import BaseModule
10
+ from digitalkin.modules.job_manager.surrealdb_repository import SurrealDBConnection
11
+
12
+
13
class TaskSession:
    """Task Session with lifecycle management.

    The Session defines the whole lifecycle of a task as an ephemeral
    context: heartbeat emission, live signal listening, and
    pause/resume/cancel state.
    """

    db: SurrealDBConnection
    module: BaseModule

    status: TaskStatus
    queue: asyncio.Queue

    task_id: str
    signal_record_id: str | None
    heartbeat_record_id: str | None

    started_at: datetime.datetime | None
    completed_at: datetime.datetime | None

    is_cancelled: asyncio.Event
    _paused: asyncio.Event
    _resume: asyncio.Event
    _heartbeat_interval: datetime.timedelta
    _last_heartbeat: datetime.datetime

    def __init__(
        self,
        task_id: str,
        db: SurrealDBConnection,
        module: BaseModule,
        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
    ) -> None:
        """Initialize the session in PENDING state.

        Args:
            task_id: Unique identifier of the owning task.
            db: Per-task SurrealDB connection used for heartbeats and signals.
            module: The module whose lifecycle this session tracks.
            heartbeat_interval: Minimum delay between heartbeat writes.
        """
        self.db = db
        self.module = module

        self.status = TaskStatus.PENDING
        self.queue = asyncio.Queue()

        self.task_id = task_id
        self.heartbeat = None
        self.started_at = None
        self.completed_at = None

        self.signal_record_id = None
        self.heartbeat_record_id = None

        self.is_cancelled = asyncio.Event()
        self._paused = asyncio.Event()
        # Set while the task is allowed to run, cleared while paused.
        # Fix: the original awaited `self._paused.wait()`, which returns
        # immediately when the event is SET — pausing never actually blocked.
        self._resume = asyncio.Event()
        self._resume.set()
        self._heartbeat_interval = heartbeat_interval

        logger.info(
            "TaskContext initialized for task: '%s'",
            task_id,
            extra={"task_id": task_id, "heartbeat_interval": heartbeat_interval},
        )

    @property
    def cancelled(self) -> bool:
        """True once cancellation has been requested or acknowledged."""
        return self.is_cancelled.is_set()

    @property
    def paused(self) -> bool:
        """True while the task is paused."""
        return self._paused.is_set()

    async def send_heartbeat(self) -> bool:
        """Rate-limited heartbeat with connection resilience.

        Returns:
            bool: True if heartbeat was successful, False otherwise
        """
        heartbeat = HeartbeatMessage(
            task_id=self.task_id,
            timestamp=datetime.datetime.now(datetime.timezone.utc),
        )

        if self.heartbeat_record_id is None:
            # First heartbeat: create the record and remember its id for merges.
            try:
                success = await self.db.create("heartbeats", heartbeat.model_dump())
                # Demoted from logger.critical: purely diagnostic output.
                logger.debug(f"{success=} | {'code' not in success}")
                if "code" not in success:
                    self.heartbeat_record_id = success.get("id")  # type: ignore
                    self._last_heartbeat = heartbeat.timestamp
                    return True
            except Exception as e:
                logger.error(
                    "Heartbeat exception for task: '%s'",
                    self.task_id,
                    extra={"task_id": self.task_id, "error": str(e)},
                    exc_info=True,
                )
            logger.error(
                "Initial heartbeat failed for task: '%s'",
                self.task_id,
                extra={"task_id": self.task_id},
            )
            return False

        if (heartbeat.timestamp - self._last_heartbeat) < self._heartbeat_interval:
            # Rate limit: skip the write but report success.
            logger.debug(
                "Heartbeat skipped due to rate limiting for task: '%s' | delta=%s",
                self.task_id,
                heartbeat.timestamp - self._last_heartbeat,
            )
            return True

        try:
            success = await self.db.merge("heartbeats", self.heartbeat_record_id, heartbeat.model_dump())
            if "code" not in success:
                self._last_heartbeat = heartbeat.timestamp
                return True
        except Exception as e:
            logger.error(
                "Heartbeat exception for task: '%s'",
                self.task_id,
                extra={"task_id": self.task_id, "error": str(e)},
                exc_info=True,
            )
        logger.warning(
            "Heartbeat failed for task: '%s'",
            self.task_id,
            extra={"task_id": self.task_id},
        )
        return False

    async def generate_heartbeats(self) -> None:
        """Periodic heartbeat generator with cancellation support."""
        # Demoted from logger.critical: informational only.
        logger.info("Heartbeat started", extra={"task_id": self.task_id})
        while not self.cancelled:
            logger.debug("Heartbeat tick for task: '%s' | cancelled=%s", self.task_id, self.cancelled)
            success = await self.send_heartbeat()
            if not success:
                logger.error("Heartbeat failed, cancelling task: '%s'", self.task_id, extra={"task_id": self.task_id})
                await self._handle_cancel()
                break
            await asyncio.sleep(self._heartbeat_interval.total_seconds())

    async def wait_if_paused(self) -> None:
        """Block execution while the task is paused."""
        if self._paused.is_set():
            logger.info("Task paused, waiting for resume: '%s'", self.task_id, extra={"task_id": self.task_id})
            # Fix: wait for the resume event; awaiting the (set) pause event
            # returned immediately and never blocked.
            await self._resume.wait()

    async def listen_signals(self) -> None:  # noqa: C901
        """Enhanced signal listener with comprehensive handling.

        Raises:
            CancelledError: Asyncio when task cancelling
        """
        logger.info("Signal listener started for task: '%s'", self.task_id, extra={"task_id": self.task_id})
        if self.signal_record_id is None:
            self.signal_record_id = (await self.db.select_by_task_id("tasks", self.task_id)).get("id")

        live_id, live_signals = await self.db.start_live("tasks")
        try:
            async for signal in live_signals:
                # Demoted from logger.critical: purely diagnostic output.
                logger.debug("Signal received for task '%s': %s", self.task_id, signal)
                if self.cancelled:
                    break

                # `.get` avoids a KeyError when a live record carries no "id".
                if signal is None or signal.get("id") == self.signal_record_id or "payload" not in signal:
                    continue

                if signal["action"] == "cancel":
                    await self._handle_cancel()
                elif signal["action"] == "pause":
                    await self._handle_pause()
                elif signal["action"] == "resume":
                    await self._handle_resume()
                elif signal["action"] == "status":
                    await self._handle_status_request()

        except asyncio.CancelledError:
            logger.debug("Signal listener cancelled for task: '%s'", self.task_id, extra={"task_id": self.task_id})
            raise
        except Exception as e:
            logger.error(
                "Signal listener fatal error for task: '%s'",
                self.task_id,
                extra={"task_id": self.task_id, "error": str(e)},
                exc_info=True,
            )
        finally:
            await self.db.stop_live(live_id)
            logger.info("Signal listener stopped for task: '%s'", self.task_id, extra={"task_id": self.task_id})

    async def _handle_cancel(self) -> None:
        """Idempotent cancellation with acknowledgment."""
        logger.debug("Handle cancel called", extra={"task_id": self.task_id})
        if self.is_cancelled.is_set():
            logger.debug(
                "Cancel signal ignored - task already cancelled: '%s'", self.task_id, extra={"task_id": self.task_id}
            )
            return

        logger.info("Cancelling task: '%s'", self.task_id, extra={"task_id": self.task_id})

        self.status = TaskStatus.CANCELLED
        self.is_cancelled.set()

        # Resume if paused so cancellation can proceed.
        # Fix: the original called `self._paused.set()` here, which marked the
        # task as paused instead of releasing it.
        if self._paused.is_set():
            self._paused.clear()
        self._resume.set()

        await self.db.update(
            "tasks",
            self.signal_record_id,  # type: ignore
            SignalMessage(
                task_id=self.task_id,
                action=SignalType.ACK_CANCEL,
                status=self.status,
            ).model_dump(),
        )

    async def _handle_pause(self) -> None:
        """Pause task execution and acknowledge the signal."""
        if not self._paused.is_set():
            logger.info("Pausing task: '%s'", self.task_id, extra={"task_id": self.task_id})
            self._paused.set()
            self._resume.clear()

        await self.db.update(
            "tasks",
            self.signal_record_id,  # type: ignore
            SignalMessage(
                task_id=self.task_id,
                action=SignalType.ACK_PAUSE,
                status=self.status,
            ).model_dump(),
        )

    async def _handle_resume(self) -> None:
        """Resume a paused task and acknowledge the signal."""
        if self._paused.is_set():
            logger.info("Resuming task: '%s'", self.task_id, extra={"task_id": self.task_id})
            self._paused.clear()
        self._resume.set()

        await self.db.update(
            "tasks",
            self.signal_record_id,  # type: ignore
            SignalMessage(
                task_id=self.task_id,
                action=SignalType.ACK_RESUME,
                status=self.status,
            ).model_dump(),
        )

    async def _handle_status_request(self) -> None:
        """Send current task status."""
        await self.db.update(
            "tasks",
            self.signal_record_id,  # type: ignore
            SignalMessage(
                action=SignalType.ACK_STATUS,
                task_id=self.task_id,
                status=self.status,
            ).model_dump(),
        )

        logger.debug(
            "Status report sent for task: '%s'",
            self.task_id,
            extra={"task_id": self.task_id},
        )
@@ -19,7 +19,7 @@ from rstream import Consumer, ConsumerOffsetSpecification, MessageContext, Offse
19
19
 
20
20
  from digitalkin.logger import logger
21
21
  from digitalkin.models.module import InputModelT, SetupModelT
22
- from digitalkin.models.module.module import ModuleStatus
22
+ from digitalkin.models.module.task_monitor import TaskStatus
23
23
  from digitalkin.modules._base_module import BaseModule
24
24
  from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
25
25
  from digitalkin.modules.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER
@@ -279,7 +279,7 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
279
279
  msg = "stop_all_modules not implemented in TaskiqJobManager"
280
280
  raise NotImplementedError(msg)
281
281
 
282
- async def get_module_status(self, job_id: str) -> ModuleStatus | None:
282
+ async def get_module_status(self, job_id: str) -> TaskStatus:
283
283
  """Query a module status."""
284
284
  msg = "get_module_status not implemented in TaskiqJobManager"
285
285
  raise NotImplementedError(msg)