digitalkin 0.3.0rc2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/core/common/__init__.py +9 -0
- digitalkin/core/common/factories.py +156 -0
- digitalkin/core/job_manager/base_job_manager.py +128 -28
- digitalkin/core/job_manager/single_job_manager.py +80 -25
- digitalkin/core/job_manager/taskiq_broker.py +114 -19
- digitalkin/core/job_manager/taskiq_job_manager.py +291 -39
- digitalkin/core/task_manager/base_task_manager.py +539 -0
- digitalkin/core/task_manager/local_task_manager.py +108 -0
- digitalkin/core/task_manager/remote_task_manager.py +87 -0
- digitalkin/core/task_manager/surrealdb_repository.py +43 -4
- digitalkin/core/task_manager/task_executor.py +249 -0
- digitalkin/core/task_manager/task_session.py +95 -17
- digitalkin/grpc_servers/module_server.py +2 -2
- digitalkin/grpc_servers/module_servicer.py +21 -12
- digitalkin/grpc_servers/registry_server.py +1 -1
- digitalkin/grpc_servers/registry_servicer.py +4 -4
- digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
- digitalkin/models/core/task_monitor.py +17 -0
- digitalkin/models/module/module_context.py +5 -0
- digitalkin/models/module/module_types.py +299 -15
- digitalkin/modules/_base_module.py +66 -28
- digitalkin/services/cost/grpc_cost.py +8 -41
- digitalkin/services/filesystem/grpc_filesystem.py +9 -38
- digitalkin/services/services_config.py +11 -0
- digitalkin/services/services_models.py +3 -1
- digitalkin/services/setup/default_setup.py +5 -6
- digitalkin/services/setup/grpc_setup.py +51 -14
- digitalkin/services/storage/grpc_storage.py +2 -2
- digitalkin/services/user_profile/__init__.py +12 -0
- digitalkin/services/user_profile/default_user_profile.py +55 -0
- digitalkin/services/user_profile/grpc_user_profile.py +69 -0
- digitalkin/services/user_profile/user_profile_strategy.py +40 -0
- digitalkin/utils/__init__.py +28 -0
- digitalkin/utils/dynamic_schema.py +483 -0
- {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dist-info}/METADATA +8 -8
- {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dist-info}/RECORD +41 -29
- modules/dynamic_setup_module.py +362 -0
- digitalkin/core/task_manager/task_manager.py +0 -442
- {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dist-info}/WHEEL +0 -0
- {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Remote task manager for distributed execution."""
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
from collections.abc import Coroutine
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from digitalkin.core.task_manager.base_task_manager import BaseTaskManager
|
|
8
|
+
from digitalkin.logger import logger
|
|
9
|
+
from digitalkin.modules._base_module import BaseModule
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RemoteTaskManager(BaseTaskManager):
    """Task manager for distributed/remote execution.

    Only manages task metadata and signals - actual execution happens in remote workers.
    Suitable for horizontally scaled deployments with Taskiq/Celery workers.
    """

    async def create_task(
        self,
        task_id: str,
        mission_id: str,
        module: BaseModule,
        coro: Coroutine[Any, Any, None],
        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
        connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Register task for remote execution (metadata only).

        Creates a session for signal handling and monitoring, but never runs the
        coroutine: a remote worker recreates and executes it.

        Args:
            task_id: Unique identifier for the task
            mission_id: Mission identifier
            module: Module instance for metadata (not executed here)
            coro: Coroutine that is never run here. It is always closed before this
                method returns or raises, whether registration succeeds or fails.
            heartbeat_interval: Interval between heartbeats
            connection_timeout: Connection timeout for SurrealDB

        Raises:
            ValueError: If task_id duplicated
            RuntimeError: If task overload
        """
        try:
            # Validation (failures here do not require session cleanup)
            await self._validate_task_creation(task_id, mission_id, coro)

            logger.info(
                "Registering remote task: '%s'",
                task_id,
                extra={
                    "mission_id": mission_id,
                    "task_id": task_id,
                    "heartbeat_interval": heartbeat_interval,
                    "connection_timeout": connection_timeout,
                },
            )

            try:
                # Create session for metadata and signal handling
                _channel, _session = await self._create_session(
                    task_id, mission_id, module, heartbeat_interval, connection_timeout
                )

                logger.info(
                    "Remote task registered: '%s'",
                    task_id,
                    extra={
                        "mission_id": mission_id,
                        "task_id": task_id,
                        "total_sessions": len(self.tasks_sessions),
                    },
                )

            except Exception as e:
                logger.error(
                    "Failed to register remote task: '%s'",
                    task_id,
                    extra={"mission_id": mission_id, "task_id": task_id, "error": str(e)},
                    exc_info=True,
                )
                # Cleanup on failure
                await self._cleanup_task(task_id, mission_id=mission_id)
                raise
        finally:
            # The coroutine is never awaited locally (a worker recreates it), so close it
            # on every path. Otherwise a failed validation/session creation leaves an
            # unclosed coroutine and Python emits a "was never awaited" RuntimeWarning.
            # close() is a no-op on an already-closed coroutine.
            coro.close()
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""SurrealDB connection management."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import datetime
|
|
4
5
|
import os
|
|
5
6
|
from collections.abc import AsyncGenerator
|
|
@@ -38,6 +39,7 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
38
39
|
|
|
39
40
|
db: TSurreal
|
|
40
41
|
timeout: datetime.timedelta
|
|
42
|
+
_live_queries: set[UUID] # Track active live queries for cleanup
|
|
41
43
|
|
|
42
44
|
@staticmethod
|
|
43
45
|
def _valid_id(raw_id: str, table_name: str) -> RecordID:
|
|
@@ -82,6 +84,7 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
82
84
|
self.password = os.getenv("SURREALDB_PASSWORD", "root")
|
|
83
85
|
self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
|
|
84
86
|
self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")
|
|
87
|
+
self._live_queries = set() # Initialize live queries tracker
|
|
85
88
|
|
|
86
89
|
async def init_surreal_instance(self) -> None:
|
|
87
90
|
"""Init a SurrealDB connection instance."""
|
|
@@ -92,7 +95,37 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
92
95
|
logger.debug("Successfully connected to SurrealDB")
|
|
93
96
|
|
|
94
97
|
async def close(self) -> None:
    """Kill all tracked live queries, then close the underlying SurrealDB client.

    Killing the live queries first prevents them from leaking. Queries whose kill
    fails (including a kill that was cancelled) stay in ``self._live_queries``.
    The failures are logged once in aggregate, not per query.
    """
    # Kill all tracked live queries before closing connection
    if self._live_queries:
        logger.debug("Killing %d active live queries before closing", len(self._live_queries))
        live_query_ids = list(self._live_queries)

        # Kill all queries concurrently, capturing any exceptions
        results = await asyncio.gather(
            *[self.db.kill(live_id) for live_id in live_query_ids], return_exceptions=True
        )

        # Process results and track failures
        failed_queries = []
        for live_id, result in zip(live_query_ids, results):
            # gather(return_exceptions=True) also returns CancelledError, which derives
            # from BaseException (not Exception), so check BaseException to avoid
            # treating a cancelled kill as a success.
            if isinstance(result, BaseException):
                failed_queries.append((live_id, str(result)))
            else:
                self._live_queries.discard(live_id)

        # Log aggregated failures once instead of per-query
        if failed_queries:
            logger.warning(
                "Failed to kill %d live queries: %s",
                len(failed_queries),
                failed_queries[:5],  # Only log first 5 to avoid log spam
                extra={"total_failed": len(failed_queries)},
            )

    logger.debug("Closing SurrealDB connection")
    await self.db.close()
|
|
98
131
|
|
|
@@ -208,20 +241,26 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
208
241
|
) -> tuple[UUID, AsyncGenerator[dict[str, Any], None]]:
|
|
209
242
|
"""Create and subscribe to a live SurrealQL query.
|
|
210
243
|
|
|
244
|
+
The live query ID is tracked to ensure proper cleanup on connection close.
|
|
245
|
+
|
|
211
246
|
Args:
|
|
212
247
|
table_name: Name of the table to insert into
|
|
213
248
|
|
|
214
249
|
Returns:
|
|
215
|
-
|
|
250
|
+
tuple[UUID, AsyncGenerator]: Live query ID and subscription generator
|
|
216
251
|
"""
|
|
217
252
|
live_id = await self.db.live(table_name, diff=False)
|
|
253
|
+
self._live_queries.add(live_id) # Track for cleanup
|
|
254
|
+
logger.debug("Started live query %s for table %s (total: %d)", live_id, table_name, len(self._live_queries))
|
|
218
255
|
return live_id, await self.db.subscribe_live(live_id)
|
|
219
256
|
|
|
220
257
|
async def stop_live(self, live_id: UUID) -> None:
    """Terminate a live SurrealQL query and stop tracking it.

    If the kill call raises, the id stays tracked so that ``close()`` can retry it.

    Args:
        live_id: Identifier of the live query to terminate
    """
    logger.debug("Killing live query: %s", live_id)
    await self.db.kill(live_id)

    tracked = self._live_queries
    tracked.discard(live_id)  # no longer needs killing when the connection closes
    remaining = len(tracked)
    logger.debug("Stopped live query %s (remaining: %d)", live_id, remaining)
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
"""Task executor for running tasks with full lifecycle management."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import datetime
|
|
5
|
+
from collections.abc import Coroutine
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
|
|
9
|
+
from digitalkin.core.task_manager.task_session import TaskSession
|
|
10
|
+
from digitalkin.logger import logger
|
|
11
|
+
from digitalkin.models.core.task_monitor import (
|
|
12
|
+
CancellationReason,
|
|
13
|
+
SignalMessage,
|
|
14
|
+
SignalType,
|
|
15
|
+
TaskStatus,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TaskExecutor:
    """Executes tasks with the supervisor pattern (main + heartbeat + signal listener).

    Pure execution logic - no task registry or orchestration.
    Used by workers to run distributed tasks or by TaskManager for local execution.
    """

    @staticmethod
    async def execute_task(  # noqa: C901, PLR0915
        task_id: str,
        mission_id: str,
        coro: Coroutine[Any, Any, None],
        session: TaskSession,
        channel: SurrealDBConnection,
    ) -> asyncio.Task[None]:
        """Execute a task using the supervisor pattern.

        Runs three concurrent sub-tasks:
        - Main coroutine (the actual work)
        - Heartbeat generator (sends heartbeats to SurrealDB)
        - Signal listener (watches for stop/pause/resume signals)

        The supervisor waits until at least one sub-task finishes. If several finish
        at the same time, the outcome is chosen by priority: main task first, then
        the signal listener, then the heartbeat.

        Args:
            task_id: Unique identifier for the task
            mission_id: Mission identifier for the task
            coro: The coroutine to execute (module.start(...))
            session: TaskSession for state management
            channel: SurrealDB connection for signals

        Returns:
            asyncio.Task: The supervisor task managing the lifecycle
        """

        async def signal_wrapper() -> None:
            """Create initial signal record and listen for signals."""
            try:
                await channel.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        mission_id=mission_id,
                        status=session.status,
                        action=SignalType.START,
                    ).model_dump(),
                )
                await session.listen_signals()
            except asyncio.CancelledError:
                logger.debug("Signal listener cancelled", extra={"mission_id": mission_id, "task_id": task_id})
            finally:
                await channel.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        mission_id=mission_id,
                        status=session.status,
                        action=SignalType.STOP,
                    ).model_dump(),
                )
                logger.info("Signal listener ended", extra={"mission_id": mission_id, "task_id": task_id})

        async def heartbeat_wrapper() -> None:
            """Generate heartbeats for task health monitoring."""
            try:
                await session.generate_heartbeats()
            except asyncio.CancelledError:
                logger.debug("Heartbeat cancelled", extra={"mission_id": mission_id, "task_id": task_id})
            finally:
                logger.info("Heartbeat task ended", extra={"mission_id": mission_id, "task_id": task_id})

        async def supervisor() -> None:  # noqa: C901, PLR0912, PLR0915
            """Supervise the three concurrent tasks and handle outcomes.

            Raises:
                RuntimeError: If the heartbeat task stops unexpectedly.
                asyncio.CancelledError: If the supervisor task is cancelled.
            """
            session.started_at = datetime.datetime.now(datetime.timezone.utc)
            session.status = TaskStatus.RUNNING

            # Create tasks with proper exception handling
            main_task = None
            hb_task = None
            sig_task = None
            cleanup_reason = CancellationReason.UNKNOWN

            try:
                main_task = asyncio.create_task(coro, name=f"{task_id}_main")
                hb_task = asyncio.create_task(heartbeat_wrapper(), name=f"{task_id}_heartbeat")
                sig_task = asyncio.create_task(signal_wrapper(), name=f"{task_id}_listener")
                done, pending = await asyncio.wait(
                    [main_task, sig_task, hb_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )

                # `done` is an unordered set and several sub-tasks can finish in the same
                # event-loop iteration, so choose the outcome by priority instead of by
                # arbitrary set order. A finished main task always wins, so its result or
                # exception is never misreported as a signal/heartbeat outcome.
                if main_task in done:
                    # Main task finished - cleanup is due to success
                    completed = main_task
                    cleanup_reason = CancellationReason.SUCCESS_CLEANUP
                elif sig_task in done:
                    # Signal task finished - external cancellation
                    completed = sig_task
                    cleanup_reason = CancellationReason.SIGNAL
                else:
                    # Heartbeat stopped - failure cleanup
                    completed = hb_task
                    cleanup_reason = CancellationReason.FAILURE_CLEANUP

                # Cancel pending tasks with proper reason logging
                if pending:
                    pending_names = [t.get_name() for t in pending]
                    logger.debug(
                        "Cancelling pending tasks: %s, reason: %s",
                        pending_names,
                        cleanup_reason.value,
                        extra={
                            "mission_id": mission_id,
                            "task_id": task_id,
                            "pending_tasks": pending_names,
                            "cancellation_reason": cleanup_reason.value,
                        },
                    )
                    for t in pending:
                        t.cancel()

                # Propagate exception/result from the finished task
                await completed

                # Determine final status based on which task completed
                if completed is main_task:
                    session.status = TaskStatus.COMPLETED
                    logger.info(
                        "Main task completed successfully",
                        extra={"mission_id": mission_id, "task_id": task_id},
                    )
                elif completed is sig_task:
                    session.status = TaskStatus.CANCELLED
                    session.cancellation_reason = CancellationReason.SIGNAL
                    logger.info(
                        "Task cancelled via external signal",
                        extra={
                            "mission_id": mission_id,
                            "task_id": task_id,
                            "cancellation_reason": CancellationReason.SIGNAL.value,
                        },
                    )
                else:
                    session.status = TaskStatus.FAILED
                    session.cancellation_reason = CancellationReason.HEARTBEAT_FAILURE
                    logger.error(
                        "Heartbeat stopped unexpectedly for task: '%s'",
                        task_id,
                        extra={
                            "mission_id": mission_id,
                            "task_id": task_id,
                            "cancellation_reason": CancellationReason.HEARTBEAT_FAILURE.value,
                        },
                    )
                    msg = f"Heartbeat stopped for {task_id}"
                    raise RuntimeError(msg)  # noqa: TRY301

            except asyncio.CancelledError:
                session.status = TaskStatus.CANCELLED
                # session.cancellation_reason is not modified here: whatever reason was
                # already recorded on the session is only reported.
                logger.info(
                    "Task cancelled externally: '%s', reason: %s",
                    task_id,
                    session.cancellation_reason.value,
                    extra={
                        "mission_id": mission_id,
                        "task_id": task_id,
                        "cancellation_reason": session.cancellation_reason.value,
                    },
                )
                cleanup_reason = CancellationReason.FAILURE_CLEANUP
                raise
            except Exception:
                session.status = TaskStatus.FAILED
                cleanup_reason = CancellationReason.FAILURE_CLEANUP
                logger.exception(
                    "Task failed with exception: '%s'",
                    task_id,
                    extra={"mission_id": mission_id, "task_id": task_id},
                )
                raise
            finally:
                session.completed_at = datetime.datetime.now(datetime.timezone.utc)
                # Ensure all tasks are cleaned up with proper reason
                created_tasks = [t for t in (main_task, hb_task, sig_task) if t is not None]
                tasks_to_cleanup = [t for t in created_tasks if not t.done()]
                if tasks_to_cleanup:
                    cleanup_names = [t.get_name() for t in tasks_to_cleanup]
                    logger.debug(
                        "Final cleanup of %d remaining tasks: %s, reason: %s",
                        len(tasks_to_cleanup),
                        cleanup_names,
                        cleanup_reason.value,
                        extra={
                            "mission_id": mission_id,
                            "task_id": task_id,
                            "cleanup_count": len(tasks_to_cleanup),
                            "cleanup_tasks": cleanup_names,
                            "cancellation_reason": cleanup_reason.value,
                        },
                    )
                    for t in tasks_to_cleanup:
                        t.cancel()
                # Await every created sub-task, including ones that already finished
                # alongside `completed`. This retrieves their exceptions, so they are not
                # reported as "Task exception was never retrieved".
                await asyncio.gather(*created_tasks, return_exceptions=True)

                duration = (
                    (session.completed_at - session.started_at).total_seconds()
                    if session.started_at and session.completed_at
                    else None
                )
                logger.info(
                    "Task execution completed: '%s', status: %s, reason: %s, duration: %.2fs",
                    task_id,
                    session.status.value,
                    session.cancellation_reason.value if session.status == TaskStatus.CANCELLED else "n/a",
                    duration or 0,
                    extra={
                        "mission_id": mission_id,
                        "task_id": task_id,
                        "status": session.status.value,
                        "cancellation_reason": session.cancellation_reason.value,
                        "duration": duration,
                    },
                )

        # Return the supervisor task to be awaited by caller
        return asyncio.create_task(supervisor(), name=f"{task_id}_supervisor")
|
|
@@ -6,7 +6,13 @@ from collections.abc import AsyncGenerator
|
|
|
6
6
|
|
|
7
7
|
from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
|
|
8
8
|
from digitalkin.logger import logger
|
|
9
|
-
from digitalkin.models.core.task_monitor import
|
|
9
|
+
from digitalkin.models.core.task_monitor import (
|
|
10
|
+
CancellationReason,
|
|
11
|
+
HeartbeatMessage,
|
|
12
|
+
SignalMessage,
|
|
13
|
+
SignalType,
|
|
14
|
+
TaskStatus,
|
|
15
|
+
)
|
|
10
16
|
from digitalkin.modules._base_module import BaseModule
|
|
11
17
|
|
|
12
18
|
|
|
@@ -31,6 +37,7 @@ class TaskSession:
|
|
|
31
37
|
completed_at: datetime.datetime | None
|
|
32
38
|
|
|
33
39
|
is_cancelled: asyncio.Event
|
|
40
|
+
cancellation_reason: CancellationReason
|
|
34
41
|
_paused: asyncio.Event
|
|
35
42
|
_heartbeat_interval: datetime.timedelta
|
|
36
43
|
_last_heartbeat: datetime.datetime
|
|
@@ -58,6 +65,7 @@ class TaskSession:
|
|
|
58
65
|
self.module = module
|
|
59
66
|
|
|
60
67
|
self.status = TaskStatus.PENDING
|
|
68
|
+
# Bounded queue to prevent unbounded memory growth (max 1000 items)
|
|
61
69
|
self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
|
|
62
70
|
|
|
63
71
|
self.task_id = task_id
|
|
@@ -71,6 +79,7 @@ class TaskSession:
|
|
|
71
79
|
self.heartbeat_record_id = None
|
|
72
80
|
|
|
73
81
|
self.is_cancelled = asyncio.Event()
|
|
82
|
+
self.cancellation_reason = CancellationReason.UNKNOWN
|
|
74
83
|
self._paused = asyncio.Event()
|
|
75
84
|
self._heartbeat_interval = heartbeat_interval
|
|
76
85
|
|
|
@@ -152,17 +161,26 @@ class TaskSession:
|
|
|
152
161
|
|
|
153
162
|
async def generate_heartbeats(self) -> None:
|
|
154
163
|
"""Periodic heartbeat generator with cancellation support."""
|
|
155
|
-
logger.debug(
|
|
164
|
+
logger.debug(
|
|
165
|
+
"Heartbeat generator started for task: '%s'",
|
|
166
|
+
self.task_id,
|
|
167
|
+
extra={"task_id": self.task_id, "mission_id": self.mission_id},
|
|
168
|
+
)
|
|
156
169
|
while not self.cancelled:
|
|
157
|
-
logger.debug(
|
|
170
|
+
logger.debug(
|
|
171
|
+
"Heartbeat tick for task: '%s', cancelled=%s",
|
|
172
|
+
self.task_id,
|
|
173
|
+
self.cancelled,
|
|
174
|
+
extra={"task_id": self.task_id, "mission_id": self.mission_id},
|
|
175
|
+
)
|
|
158
176
|
success = await self.send_heartbeat()
|
|
159
177
|
if not success:
|
|
160
178
|
logger.error(
|
|
161
179
|
"Heartbeat failed, cancelling task: '%s'",
|
|
162
180
|
self.task_id,
|
|
163
|
-
extra={"task_id": self.task_id},
|
|
181
|
+
extra={"task_id": self.task_id, "mission_id": self.mission_id},
|
|
164
182
|
)
|
|
165
|
-
await self._handle_cancel()
|
|
183
|
+
await self._handle_cancel(CancellationReason.HEARTBEAT_FAILURE)
|
|
166
184
|
break
|
|
167
185
|
await asyncio.sleep(self._heartbeat_interval.total_seconds())
|
|
168
186
|
|
|
@@ -201,7 +219,7 @@ class TaskSession:
|
|
|
201
219
|
continue
|
|
202
220
|
|
|
203
221
|
if signal["action"] == "cancel":
|
|
204
|
-
await self._handle_cancel()
|
|
222
|
+
await self._handle_cancel(CancellationReason.SIGNAL)
|
|
205
223
|
elif signal["action"] == "pause":
|
|
206
224
|
await self._handle_pause()
|
|
207
225
|
elif signal["action"] == "resume":
|
|
@@ -231,26 +249,55 @@ class TaskSession:
|
|
|
231
249
|
extra={"task_id": self.task_id},
|
|
232
250
|
)
|
|
233
251
|
|
|
234
|
-
async def _handle_cancel(self) -> None:
|
|
235
|
-
"""Idempotent cancellation with acknowledgment.
|
|
236
|
-
|
|
252
|
+
async def _handle_cancel(self, reason: CancellationReason = CancellationReason.UNKNOWN) -> None:
|
|
253
|
+
"""Idempotent cancellation with acknowledgment and reason tracking.
|
|
254
|
+
|
|
255
|
+
Args:
|
|
256
|
+
reason: The reason for cancellation (signal, heartbeat failure, cleanup, etc.)
|
|
257
|
+
"""
|
|
237
258
|
if self.is_cancelled.is_set():
|
|
238
259
|
logger.debug(
|
|
239
|
-
"Cancel
|
|
260
|
+
"Cancel ignored - task already cancelled: '%s' (existing reason: %s, new reason: %s)",
|
|
240
261
|
self.task_id,
|
|
241
|
-
|
|
262
|
+
self.cancellation_reason.value,
|
|
263
|
+
reason.value,
|
|
264
|
+
extra={
|
|
265
|
+
"task_id": self.task_id,
|
|
266
|
+
"mission_id": self.mission_id,
|
|
267
|
+
"existing_reason": self.cancellation_reason.value,
|
|
268
|
+
"new_reason": reason.value,
|
|
269
|
+
},
|
|
242
270
|
)
|
|
243
271
|
return
|
|
244
272
|
|
|
245
|
-
|
|
246
|
-
"Cancelling task: '%s'",
|
|
247
|
-
self.task_id,
|
|
248
|
-
extra={"task_id": self.task_id},
|
|
249
|
-
)
|
|
250
|
-
|
|
273
|
+
self.cancellation_reason = reason
|
|
251
274
|
self.status = TaskStatus.CANCELLED
|
|
252
275
|
self.is_cancelled.set()
|
|
253
276
|
|
|
277
|
+
# Log with appropriate level based on reason
|
|
278
|
+
if reason in {CancellationReason.SUCCESS_CLEANUP, CancellationReason.FAILURE_CLEANUP}:
|
|
279
|
+
logger.debug(
|
|
280
|
+
"Task cancelled (cleanup): '%s', reason: %s",
|
|
281
|
+
self.task_id,
|
|
282
|
+
reason.value,
|
|
283
|
+
extra={
|
|
284
|
+
"task_id": self.task_id,
|
|
285
|
+
"mission_id": self.mission_id,
|
|
286
|
+
"cancellation_reason": reason.value,
|
|
287
|
+
},
|
|
288
|
+
)
|
|
289
|
+
else:
|
|
290
|
+
logger.info(
|
|
291
|
+
"Task cancelled: '%s', reason: %s",
|
|
292
|
+
self.task_id,
|
|
293
|
+
reason.value,
|
|
294
|
+
extra={
|
|
295
|
+
"task_id": self.task_id,
|
|
296
|
+
"mission_id": self.mission_id,
|
|
297
|
+
"cancellation_reason": reason.value,
|
|
298
|
+
},
|
|
299
|
+
)
|
|
300
|
+
|
|
254
301
|
# Resume if paused so cancellation can proceed
|
|
255
302
|
if self._paused.is_set():
|
|
256
303
|
self._paused.set()
|
|
@@ -326,3 +373,34 @@ class TaskSession:
|
|
|
326
373
|
self.task_id,
|
|
327
374
|
extra={"task_id": self.task_id},
|
|
328
375
|
)
|
|
376
|
+
|
|
377
|
+
async def cleanup(self) -> None:
|
|
378
|
+
"""Clean up task session resources.
|
|
379
|
+
|
|
380
|
+
This includes:
|
|
381
|
+
- Clearing queue to free memory
|
|
382
|
+
- Stopping module
|
|
383
|
+
- Closing database connection
|
|
384
|
+
- Clearing module reference
|
|
385
|
+
"""
|
|
386
|
+
# Clear queue to free memory
|
|
387
|
+
try:
|
|
388
|
+
while not self.queue.empty():
|
|
389
|
+
self.queue.get_nowait()
|
|
390
|
+
except asyncio.QueueEmpty:
|
|
391
|
+
pass
|
|
392
|
+
|
|
393
|
+
# Stop module
|
|
394
|
+
try:
|
|
395
|
+
await self.module.stop()
|
|
396
|
+
except Exception:
|
|
397
|
+
logger.exception(
|
|
398
|
+
"Error stopping module during cleanup",
|
|
399
|
+
extra={"mission_id": self.mission_id, "task_id": self.task_id},
|
|
400
|
+
)
|
|
401
|
+
|
|
402
|
+
# Close DB connection (kills all live queries)
|
|
403
|
+
await self.db.close()
|
|
404
|
+
|
|
405
|
+
# Clear module reference to allow garbage collection
|
|
406
|
+
self.module = None # type: ignore
|
|
@@ -4,11 +4,11 @@ import uuid
|
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
|
|
6
6
|
import grpc
|
|
7
|
-
from digitalkin_proto.
|
|
7
|
+
from digitalkin_proto.agentic_mesh_protocol.module.v1 import (
|
|
8
8
|
module_service_pb2,
|
|
9
9
|
module_service_pb2_grpc,
|
|
10
10
|
)
|
|
11
|
-
from digitalkin_proto.
|
|
11
|
+
from digitalkin_proto.agentic_mesh_protocol.module_registry.v1 import (
|
|
12
12
|
metadata_pb2,
|
|
13
13
|
module_registry_service_pb2_grpc,
|
|
14
14
|
registration_pb2,
|