digitalkin 0.2.25rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/grpc_servers/_base_server.py +1 -1
- digitalkin/grpc_servers/module_server.py +26 -42
- digitalkin/grpc_servers/module_servicer.py +30 -24
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +3 -3
- digitalkin/grpc_servers/utils/models.py +1 -1
- digitalkin/logger.py +60 -23
- digitalkin/mixins/__init__.py +19 -0
- digitalkin/mixins/base_mixin.py +10 -0
- digitalkin/mixins/callback_mixin.py +24 -0
- digitalkin/mixins/chat_history_mixin.py +108 -0
- digitalkin/mixins/cost_mixin.py +76 -0
- digitalkin/mixins/file_history_mixin.py +99 -0
- digitalkin/mixins/filesystem_mixin.py +47 -0
- digitalkin/mixins/logger_mixin.py +59 -0
- digitalkin/mixins/storage_mixin.py +79 -0
- digitalkin/models/module/__init__.py +2 -0
- digitalkin/models/module/module.py +9 -1
- digitalkin/models/module/module_context.py +90 -6
- digitalkin/models/module/module_types.py +5 -5
- digitalkin/models/module/task_monitor.py +51 -0
- digitalkin/models/services/__init__.py +9 -0
- digitalkin/models/services/storage.py +39 -5
- digitalkin/modules/_base_module.py +105 -74
- digitalkin/modules/job_manager/base_job_manager.py +12 -8
- digitalkin/modules/job_manager/single_job_manager.py +84 -78
- digitalkin/modules/job_manager/surrealdb_repository.py +225 -0
- digitalkin/modules/job_manager/task_manager.py +391 -0
- digitalkin/modules/job_manager/task_session.py +276 -0
- digitalkin/modules/job_manager/taskiq_job_manager.py +2 -2
- digitalkin/modules/tool_module.py +10 -2
- digitalkin/modules/trigger_handler.py +7 -6
- digitalkin/services/cost/__init__.py +9 -2
- digitalkin/services/storage/grpc_storage.py +1 -1
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/METADATA +18 -18
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/RECORD +39 -26
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/WHEEL +0 -0
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/top_level.txt +0 -0
digitalkin/modules/job_manager/task_manager.py (new file, +391 lines):

```python
"""Task manager with comprehensive lifecycle management."""

import asyncio
import contextlib
import datetime
from collections.abc import Coroutine
from typing import Any

from digitalkin.logger import logger
from digitalkin.models.module.task_monitor import SignalMessage, SignalType, TaskStatus
from digitalkin.modules._base_module import BaseModule
from digitalkin.modules.job_manager.task_session import SurrealDBConnection, TaskSession


class TaskManager:
    """Task manager with comprehensive lifecycle management."""

    tasks: dict[str, asyncio.Task]
    tasks_sessions: dict[str, TaskSession]
    channel: SurrealDBConnection
    default_timeout: float
    max_concurrent_tasks: int
    _shutdown_event: asyncio.Event

    def __init__(self, default_timeout: float = 10.0, max_concurrent_tasks: int = 100) -> None:
        """Initialize the task manager."""
        self.tasks = {}
        self.tasks_sessions = {}
        self.default_timeout = default_timeout
        self.max_concurrent_tasks = max_concurrent_tasks
        self._shutdown_event = asyncio.Event()

        logger.info(
            "TaskManager initialized with max_concurrent_tasks: %d, default_timeout: %.1f",
            max_concurrent_tasks,
            default_timeout,
            extra={"max_concurrent_tasks": max_concurrent_tasks, "default_timeout": default_timeout},
        )

    @property
    def task_count(self) -> int:
        """Return the number of tracked task sessions."""
        return len(self.tasks_sessions)

    @property
    def running_tasks(self) -> set[str]:
        """Return the IDs of tasks that are not yet done."""
        return {task_id for task_id, task in self.tasks.items() if not task.done()}

    async def _cleanup_task(self, task_id: str) -> None:
        """Clean up task resources."""
        logger.debug("Cleaning up resources for task: '%s'", task_id, extra={"task_id": task_id})
        if task_id in self.tasks_sessions:
            await self.tasks_sessions[task_id].db.close()
        # Remove from collections

    async def _task_wrapper(  # noqa: C901, PLR0915
        self,
        task_id: str,
        coro: Coroutine[Any, Any, None],
        session: TaskSession,
    ) -> asyncio.Task[None]:
        """Task wrapper that runs main, heartbeat, and listener concurrently.

        The first to finish determines the outcome. Returns a Task that the
        caller can await externally.

        Returns:
            asyncio.Task[None]: The supervisor task managing the lifecycle.
        """

        async def signal_wrapper() -> None:
            try:
                await self.channel.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        status=session.status,
                        action=SignalType.START,
                    ).model_dump(),
                )
                await session.listen_signals()
            except asyncio.CancelledError:
                logger.debug("Signal listener cancelled", extra={"task_id": task_id})
            finally:
                await self.channel.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        status=session.status,
                        action=SignalType.STOP,
                    ).model_dump(),
                )
                logger.info("Signal listener ended", extra={"task_id": task_id})

        async def heartbeat_wrapper() -> None:
            try:
                await session.generate_heartbeats()
            except asyncio.CancelledError:
                logger.debug("Heartbeat task cancelled", extra={"task_id": task_id})
            finally:
                logger.info("Heartbeat task ended", extra={"task_id": task_id})

        async def supervisor() -> None:
            session.started_at = datetime.datetime.now(datetime.timezone.utc)
            session.status = TaskStatus.RUNNING

            main_task = asyncio.create_task(coro, name=f"{task_id}_main")
            hb_task = asyncio.create_task(heartbeat_wrapper(), name=f"{task_id}_heartbeat")
            sig_task = asyncio.create_task(signal_wrapper(), name=f"{task_id}_listener")

            try:
                done, pending = await asyncio.wait(
                    [main_task, sig_task, hb_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )

                # One task completed -> cancel the others
                for t in pending:
                    t.cancel()

                # Propagate exception/result from the finished task
                completed = next(iter(done))
                await completed

                logger.critical(f"{completed=} | {main_task=} | {hb_task=} | {sig_task=}")

                if completed is main_task:
                    session.status = TaskStatus.COMPLETED
                elif completed is sig_task or (completed is hb_task and sig_task.done()):
                    logger.critical(f"{sig_task=}")
                    session.status = TaskStatus.CANCELLED
                elif completed is hb_task:
                    session.status = TaskStatus.FAILED
                    msg = f"Heartbeat stopped for {task_id}"
                    raise RuntimeError(msg)  # noqa: TRY301

            except asyncio.CancelledError:
                session.status = TaskStatus.CANCELLED
                raise
            except Exception:
                session.status = TaskStatus.FAILED
                raise
            finally:
                session.completed_at = datetime.datetime.now(datetime.timezone.utc)
                # Ensure all tasks are cleaned up
                for t in [main_task, hb_task, sig_task]:
                    if not t.done():
                        t.cancel()
                await asyncio.gather(main_task, hb_task, sig_task, return_exceptions=True)

        # Return the supervisor task to be awaited outside
        return asyncio.create_task(supervisor(), name=f"{task_id}_supervisor")

    async def create_task(
        self,
        task_id: str,
        module: BaseModule,
        coro: Coroutine[Any, Any, None],
        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
        connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Create and start a new managed task.

        Raises:
            ValueError: task_id duplicated
            RuntimeError: task overload
        """
        if task_id in self.tasks:
            # close Coroutine during runtime
            coro.close()
            logger.warning("Task creation failed - task already exists: '%s'", task_id, extra={"task_id": task_id})
            msg = f"Task {task_id} already exists"
            raise ValueError(msg)

        if len(self.tasks) >= self.max_concurrent_tasks:
            coro.close()
            logger.error(
                "Task creation failed - max concurrent tasks reached: %d",
                self.max_concurrent_tasks,
                extra={
                    "task_id": task_id,
                    "current_count": len(self.tasks),
                    "max_concurrent": self.max_concurrent_tasks,
                },
            )
            msg = f"Maximum concurrent tasks ({self.max_concurrent_tasks}) reached"
            raise RuntimeError(msg)

        logger.info(
            "Creating new task: '%s'",
            task_id,
            extra={
                "task_id": task_id,
                "heartbeat_interval": heartbeat_interval,
                "connection_timeout": connection_timeout,
            },
        )

        try:
            # Initialize components
            channel: SurrealDBConnection = SurrealDBConnection("task_manager", connection_timeout)
            await channel.init_surreal_instance()
            session = TaskSession(task_id, channel, module, heartbeat_interval)

            self.tasks_sessions[task_id] = session

            # Create wrapper task
            self.tasks[task_id] = asyncio.create_task(self._task_wrapper(task_id, coro, session), name=task_id)

            logger.info(
                "Task created successfully: '%s'",
                task_id,
                extra={
                    "task_id": task_id,
                    "total_tasks": len(self.tasks),
                },
            )

        except Exception as e:
            logger.error(
                "Failed to create task: '%s'", task_id, extra={"task_id": task_id, "error": str(e)}, exc_info=True
            )
            # Cleanup on failure
            await self._cleanup_task(task_id)
            raise

    async def send_signal(self, task_id: str, signal_type: str, payload: dict) -> bool:
        """Send signal to a specific task.

        Returns:
            bool: True if the signal was sent to the task successfully, False otherwise.
        """
        if task_id not in self.tasks_sessions:
            logger.warning(
                "Cannot send signal - task not found: '%s'",
                task_id,
                extra={"task_id": task_id, "signal_type": signal_type},
            )
            return False

        logger.info(
            "Sending signal '%s' to task: '%s'",
            signal_type,
            task_id,
            extra={"task_id": task_id, "signal_type": signal_type, "payload": payload},
        )

        await self.channel.update("tasks", signal_type, payload)
        return True

    async def cancel_task(self, task_id: str, timeout: float | None = None) -> bool:
        """Cancel a task with graceful shutdown and fallback.

        Returns:
            bool: True if the task was cancelled successfully, False otherwise.
        """
        if task_id not in self.tasks:
            logger.warning("Cannot cancel - task not found: '%s'", task_id, extra={"task_id": task_id})
            return True

        timeout = timeout or self.default_timeout
        task = self.tasks[task_id]

        logger.info(
            "Initiating task cancellation: '%s', timeout: %.1fs",
            task_id,
            timeout,
            extra={"task_id": task_id, "timeout": timeout},
        )

        try:
            # Phase 1: Cooperative cancellation
            # await self.send_signal(task_id, "cancel")  # noqa: ERA001

            # Wait for graceful shutdown
            await asyncio.wait_for(task, timeout=timeout)

            logger.info("Task cancelled gracefully: '%s'", task_id, extra={"task_id": task_id})

        except asyncio.TimeoutError:
            logger.warning(
                "Graceful cancellation timed out for task: '%s', forcing cancellation",
                task_id,
                extra={"task_id": task_id, "timeout": timeout},
            )

            # Phase 2: Force cancellation
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

            logger.warning("Task force-cancelled: '%s'", task_id, extra={"task_id": task_id})
            return True

        except Exception as e:
            logger.error(
                "Error during task cancellation: '%s'",
                task_id,
                extra={"task_id": task_id, "error": str(e)},
                exc_info=True,
            )
            return False
        return True

    async def clean_session(self, task_id: str) -> bool:
        """Clean up task session without cancelling the task.

        Returns:
            bool: True if the task was cleaned successfully, False otherwise.
        """
        if task_id not in self.tasks_sessions:
            logger.warning("Cannot clean session - task not found: '%s'", task_id, extra={"task_id": task_id})
            return False

        await self.tasks_sessions[task_id].module.stop()
        await self.cancel_task(task_id)

        logger.info("Cleaning up session for task: '%s'", task_id, extra={"task_id": task_id})
        self.tasks_sessions.pop(task_id, None)
        return True

    async def pause_task(self, task_id: str) -> bool:
        """Pause a running task.

        Returns:
            bool: True if the task was paused successfully, False otherwise.
        """
        return await self.send_signal(task_id, "pause", {})

    async def resume_task(self, task_id: str) -> bool:
        """Resume a paused task.

        Returns:
            bool: True if the task was resumed successfully, False otherwise.
        """
        return await self.send_signal(task_id, "resume", {})

    async def get_task_status(self, task_id: str) -> bool:
        """Request status from a task.

        Returns:
            bool: True if the status request was sent successfully, False otherwise.
        """
        return await self.send_signal(task_id, "status", {})

    async def cancel_all_tasks(self, timeout: float | None = None) -> dict[str, bool]:
        """Cancel all running tasks.

        Returns:
            dict[str, bool]: Mapping of task ID to True if that task was cancelled successfully, False otherwise.
        """
        timeout = timeout or self.default_timeout
        task_ids = list(self.running_tasks)

        logger.info(
            "Cancelling all tasks: %d tasks", len(task_ids), extra={"task_count": len(task_ids), "timeout": timeout}
        )

        results = {}
        for task_id in task_ids:
            results[task_id] = await self.cancel_task(task_id, timeout)

        return results

    async def shutdown(self, timeout: float = 30.0) -> None:
        """Graceful shutdown of all tasks."""
        logger.info(
            "TaskManager shutdown initiated, timeout: %.1fs",
            timeout,
            extra={"timeout": timeout, "active_tasks": len(self.running_tasks)},
        )

        self._shutdown_event.set()
        results = await self.cancel_all_tasks(timeout)

        failed_tasks = [task_id for task_id, success in results.items() if not success]
        if failed_tasks:
            logger.error(
                "Failed to cancel %d tasks during shutdown: %s",
                len(failed_tasks),
                failed_tasks,
                extra={"failed_tasks": failed_tasks, "failed_count": len(failed_tasks)},
            )

        logger.info(
            "TaskManager shutdown completed, cancelled: %d, failed: %d",
            len(results) - len(failed_tasks),
            len(failed_tasks),
            extra={"cancelled_count": len(results) - len(failed_tasks), "failed_count": len(failed_tasks)},
        )
```
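Taken together, a minimal usage sketch of this API (illustrative, not code from the package): it assumes a reachable SurrealDB instance, since `create_task` opens a `SurrealDBConnection` per task, and a concrete `BaseModule` implementation from the caller's codebase. `demo_job`, `my_module`, and the `"job-1"` ID are hypothetical names.

```python
import asyncio
import datetime

from digitalkin.modules.job_manager.task_manager import TaskManager


async def demo_job() -> None:
    """Stand-in workload; a real caller passes a module's run coroutine."""
    await asyncio.sleep(30)


async def main(my_module) -> None:
    # my_module: a BaseModule instance supplied by the application.
    manager = TaskManager(default_timeout=5.0, max_concurrent_tasks=10)

    # Registers a session, opens a SurrealDB channel, and starts the supervisor
    # that races the workload against the heartbeat and signal-listener tasks.
    await manager.create_task(
        "job-1",
        module=my_module,
        coro=demo_job(),
        heartbeat_interval=datetime.timedelta(seconds=2),
    )

    # pause_task / resume_task / get_task_status are thin wrappers over send_signal().
    await manager.pause_task("job-1")
    await manager.resume_task("job-1")

    # Waits up to `timeout` seconds for a graceful exit, then force-cancels.
    await manager.cancel_task("job-1", timeout=5.0)
    await manager.shutdown(timeout=10.0)
```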
digitalkin/modules/job_manager/task_session.py (new file, +276 lines):

```python
"""Task session lifecycle management."""

import asyncio
import datetime
from collections.abc import AsyncGenerator

from digitalkin.logger import logger
from digitalkin.models.module.task_monitor import HeartbeatMessage, SignalMessage, SignalType, TaskStatus
from digitalkin.modules._base_module import BaseModule
from digitalkin.modules.job_manager.surrealdb_repository import SurrealDBConnection


class TaskSession:
    """Task Session with lifecycle management.

    The session defines the whole lifecycle of a task as an ephemeral context.
    """

    db: SurrealDBConnection
    module: BaseModule

    status: TaskStatus
    signal_queue: AsyncGenerator | None

    task_id: str
    signal_record_id: str | None
    heartbeat_record_id: str | None

    started_at: datetime.datetime | None
    completed_at: datetime.datetime | None

    is_cancelled: asyncio.Event
    _paused: asyncio.Event
    _heartbeat_interval: datetime.timedelta
    _last_heartbeat: datetime.datetime

    def __init__(
        self,
        task_id: str,
        db: SurrealDBConnection,
        module: BaseModule,
        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
    ) -> None:
        """Initialize the task session."""
        self.db = db
        self.module = module

        self.status = TaskStatus.PENDING
        self.queue: asyncio.Queue = asyncio.Queue()

        self.task_id = task_id
        self.heartbeat = None
        self.started_at = None
        self.completed_at = None

        self.signal_record_id = None
        self.heartbeat_record_id = None

        self.is_cancelled = asyncio.Event()
        self._paused = asyncio.Event()
        self._heartbeat_interval = heartbeat_interval

        logger.info(
            "TaskContext initialized for task: '%s'",
            task_id,
            extra={"task_id": task_id, "heartbeat_interval": heartbeat_interval},
        )

    @property
    def cancelled(self) -> bool:
        """Return True if the task has been cancelled."""
        return self.is_cancelled.is_set()

    @property
    def paused(self) -> bool:
        """Return True if the task is currently paused."""
        return self._paused.is_set()

    async def send_heartbeat(self) -> bool:
        """Rate-limited heartbeat with connection resilience.

        Returns:
            bool: True if the heartbeat was successful, False otherwise.
        """
        heartbeat = HeartbeatMessage(
            task_id=self.task_id,
            timestamp=datetime.datetime.now(datetime.timezone.utc),
        )

        if self.heartbeat_record_id is None:
            try:
                success = await self.db.create("heartbeats", heartbeat.model_dump())
                logger.critical(f"{success=} | {'code' not in success}")
                if "code" not in success:
                    self.heartbeat_record_id = success.get("id")  # type: ignore
                    self._last_heartbeat = heartbeat.timestamp
                    return True
            except Exception as e:
                logger.error(
                    "Heartbeat exception for task: '%s'",
                    self.task_id,
                    extra={"task_id": self.task_id, "error": str(e)},
                    exc_info=True,
                )
            logger.error(
                "Initial heartbeat failed for task: '%s'",
                self.task_id,
                extra={"task_id": self.task_id},
            )
            return False

        if (heartbeat.timestamp - self._last_heartbeat) < self._heartbeat_interval:
            logger.debug(
                "Heartbeat skipped due to rate limiting for task: '%s' | delta=%s",
                self.task_id,
                heartbeat.timestamp - self._last_heartbeat,
            )
            return True

        try:
            success = await self.db.merge("heartbeats", self.heartbeat_record_id, heartbeat.model_dump())
            if "code" not in success:
                self._last_heartbeat = heartbeat.timestamp
                return True
        except Exception as e:
            logger.error(
                "Heartbeat exception for task: '%s'",
                self.task_id,
                extra={"task_id": self.task_id, "error": str(e)},
                exc_info=True,
            )
        logger.warning(
            "Heartbeat failed for task: '%s'",
            self.task_id,
            extra={"task_id": self.task_id},
        )
        return False

    async def generate_heartbeats(self) -> None:
        """Periodic heartbeat generator with cancellation support."""
        logger.critical("Heartbeat started")
        while not self.cancelled:
            logger.debug(f"Heartbeat tick for task: '{self.task_id}' | {self.cancelled=}")
            success = await self.send_heartbeat()
            if not success:
                logger.error("Heartbeat failed, cancelling task: '%s'", self.task_id, extra={"task_id": self.task_id})
                await self._handle_cancel()
                break
            await asyncio.sleep(self._heartbeat_interval.total_seconds())

    async def wait_if_paused(self) -> None:
        """Block execution if task is paused."""
        if self._paused.is_set():
            logger.info("Task paused, waiting for resume: '%s'", self.task_id, extra={"task_id": self.task_id})
            await self._paused.wait()

    async def listen_signals(self) -> None:  # noqa: C901
        """Enhanced signal listener with comprehensive handling.

        Raises:
            asyncio.CancelledError: Propagated when the task is being cancelled.
        """
        logger.info("Signal listener started for task: '%s'", self.task_id, extra={"task_id": self.task_id})
        if self.signal_record_id is None:
            self.signal_record_id = (await self.db.select_by_task_id("tasks", self.task_id)).get("id")

        live_id, live_signals = await self.db.start_live("tasks")
        try:
            async for signal in live_signals:
                logger.critical("Signal received for task '%s': %s", self.task_id, signal)
                if self.cancelled:
                    break

                if signal is None or signal["id"] == self.signal_record_id or "payload" not in signal:
                    continue

                if signal["action"] == "cancel":
                    await self._handle_cancel()
                elif signal["action"] == "pause":
                    await self._handle_pause()
                elif signal["action"] == "resume":
                    await self._handle_resume()
                elif signal["action"] == "status":
                    await self._handle_status_request()

        except asyncio.CancelledError:
            logger.debug("Signal listener cancelled for task: '%s'", self.task_id, extra={"task_id": self.task_id})
            raise
        except Exception as e:
            logger.error(
                "Signal listener fatal error for task: '%s'",
                self.task_id,
                extra={"task_id": self.task_id, "error": str(e)},
                exc_info=True,
            )
        finally:
            await self.db.stop_live(live_id)
            logger.info("Signal listener stopped for task: '%s'", self.task_id, extra={"task_id": self.task_id})

    async def _handle_cancel(self) -> None:
        """Idempotent cancellation with acknowledgment."""
        logger.critical("Handle cancel called")
        if self.is_cancelled.is_set():
            logger.debug(
                "Cancel signal ignored - task already cancelled: '%s'", self.task_id, extra={"task_id": self.task_id}
            )
            return

        logger.info("Cancelling task: '%s'", self.task_id, extra={"task_id": self.task_id})

        self.status = TaskStatus.CANCELLED
        self.is_cancelled.set()

        # Resume if paused so cancellation can proceed
        if self._paused.is_set():
            self._paused.set()

        await self.db.update(
            "tasks",
            self.signal_record_id,  # type: ignore
            SignalMessage(
                task_id=self.task_id,
                action=SignalType.ACK_CANCEL,
                status=self.status,
            ).model_dump(),
        )

    async def _handle_pause(self) -> None:
        """Pause task execution."""
        if not self._paused.is_set():
            logger.info("Pausing task: '%s'", self.task_id, extra={"task_id": self.task_id})
            self._paused.set()

        await self.db.update(
            "tasks",
            self.signal_record_id,  # type: ignore
            SignalMessage(
                task_id=self.task_id,
                action=SignalType.ACK_PAUSE,
                status=self.status,
            ).model_dump(),
        )

    async def _handle_resume(self) -> None:
        """Resume paused task."""
        if self._paused.is_set():
            logger.info("Resuming task: '%s'", self.task_id, extra={"task_id": self.task_id})
            self._paused.clear()

        await self.db.update(
            "tasks",
            self.signal_record_id,  # type: ignore
            SignalMessage(
                task_id=self.task_id,
                action=SignalType.ACK_RESUME,
                status=self.status,
            ).model_dump(),
        )

    async def _handle_status_request(self) -> None:
        """Send current task status."""
        await self.db.update(
            "tasks",
            self.signal_record_id,  # type: ignore
            SignalMessage(
                action=SignalType.ACK_STATUS,
                task_id=self.task_id,
                status=self.status,
            ).model_dump(),
        )

        logger.debug(
            "Status report sent for task: '%s'",
            self.task_id,
            extra={"task_id": self.task_id},
        )
```
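On the workload side, a coroutine managed by the TaskManager is expected to cooperate with these controls between units of work. A sketch under that assumption (illustrative, not code from the package; `process_items` and `items` are hypothetical names, and the pause behaviour described follows the `wait_if_paused` docstring):

```python
import asyncio

from digitalkin.modules.job_manager.task_session import TaskSession


async def process_items(session: TaskSession, items: list[str]) -> None:
    """Hypothetical workload that honours the session's pause/cancel controls."""
    for item in items:
        # `cancelled` mirrors the `is_cancelled` event set by _handle_cancel();
        # checking it between work units lets a cancel signal take effect promptly.
        if session.cancelled:
            break
        # Per its docstring, blocks while the task is paused.
        await session.wait_if_paused()
        await asyncio.sleep(0.1)  # stand-in for one unit of real work on `item`
```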
digitalkin/modules/job_manager/taskiq_job_manager.py (+2 -2):

```diff
@@ -19,7 +19,7 @@ from rstream import Consumer, ConsumerOffsetSpecification, MessageContext, Offse
 
 from digitalkin.logger import logger
 from digitalkin.models.module import InputModelT, SetupModelT
-from digitalkin.models.module.
+from digitalkin.models.module.task_monitor import TaskStatus
 from digitalkin.modules._base_module import BaseModule
 from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
 from digitalkin.modules.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER
@@ -279,7 +279,7 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
         msg = "stop_all_modules not implemented in TaskiqJobManager"
         raise NotImplementedError(msg)
 
-    async def get_module_status(self, job_id: str) ->
+    async def get_module_status(self, job_id: str) -> TaskStatus:
         """Query a module status."""
         msg = "get_module_status not implemented in TaskiqJobManager"
         raise NotImplementedError(msg)
```