digitalkin-0.2.23-py3-none-any.whl → digitalkin-0.3.1.dev2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/core/__init__.py +1 -0
- digitalkin/core/common/__init__.py +9 -0
- digitalkin/core/common/factories.py +156 -0
- digitalkin/core/job_manager/__init__.py +1 -0
- digitalkin/{modules → core}/job_manager/base_job_manager.py +137 -31
- digitalkin/core/job_manager/single_job_manager.py +354 -0
- digitalkin/{modules → core}/job_manager/taskiq_broker.py +116 -22
- digitalkin/core/job_manager/taskiq_job_manager.py +541 -0
- digitalkin/core/task_manager/__init__.py +1 -0
- digitalkin/core/task_manager/base_task_manager.py +539 -0
- digitalkin/core/task_manager/local_task_manager.py +108 -0
- digitalkin/core/task_manager/remote_task_manager.py +87 -0
- digitalkin/core/task_manager/surrealdb_repository.py +266 -0
- digitalkin/core/task_manager/task_executor.py +249 -0
- digitalkin/core/task_manager/task_session.py +406 -0
- digitalkin/grpc_servers/__init__.py +1 -19
- digitalkin/grpc_servers/_base_server.py +3 -3
- digitalkin/grpc_servers/module_server.py +27 -43
- digitalkin/grpc_servers/module_servicer.py +51 -36
- digitalkin/grpc_servers/registry_server.py +2 -2
- digitalkin/grpc_servers/registry_servicer.py +4 -4
- digitalkin/grpc_servers/utils/__init__.py +1 -0
- digitalkin/grpc_servers/utils/exceptions.py +0 -8
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +4 -4
- digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
- digitalkin/logger.py +73 -24
- digitalkin/mixins/__init__.py +19 -0
- digitalkin/mixins/base_mixin.py +10 -0
- digitalkin/mixins/callback_mixin.py +24 -0
- digitalkin/mixins/chat_history_mixin.py +110 -0
- digitalkin/mixins/cost_mixin.py +76 -0
- digitalkin/mixins/file_history_mixin.py +93 -0
- digitalkin/mixins/filesystem_mixin.py +46 -0
- digitalkin/mixins/logger_mixin.py +51 -0
- digitalkin/mixins/storage_mixin.py +79 -0
- digitalkin/models/core/__init__.py +1 -0
- digitalkin/{modules/job_manager → models/core}/job_manager_models.py +3 -3
- digitalkin/models/core/task_monitor.py +70 -0
- digitalkin/models/grpc_servers/__init__.py +1 -0
- digitalkin/{grpc_servers/utils → models/grpc_servers}/models.py +5 -5
- digitalkin/models/module/__init__.py +2 -0
- digitalkin/models/module/module.py +9 -1
- digitalkin/models/module/module_context.py +122 -6
- digitalkin/models/module/module_types.py +307 -19
- digitalkin/models/services/__init__.py +9 -0
- digitalkin/models/services/cost.py +1 -0
- digitalkin/models/services/storage.py +39 -5
- digitalkin/modules/_base_module.py +123 -118
- digitalkin/modules/tool_module.py +10 -2
- digitalkin/modules/trigger_handler.py +7 -6
- digitalkin/services/cost/__init__.py +9 -2
- digitalkin/services/cost/grpc_cost.py +9 -42
- digitalkin/services/filesystem/default_filesystem.py +0 -2
- digitalkin/services/filesystem/grpc_filesystem.py +10 -39
- digitalkin/services/setup/default_setup.py +5 -6
- digitalkin/services/setup/grpc_setup.py +52 -15
- digitalkin/services/storage/grpc_storage.py +4 -4
- digitalkin/services/user_profile/__init__.py +1 -0
- digitalkin/services/user_profile/default_user_profile.py +55 -0
- digitalkin/services/user_profile/grpc_user_profile.py +69 -0
- digitalkin/services/user_profile/user_profile_strategy.py +40 -0
- digitalkin/utils/__init__.py +28 -0
- digitalkin/utils/arg_parser.py +1 -1
- digitalkin/utils/development_mode_action.py +2 -2
- digitalkin/utils/dynamic_schema.py +483 -0
- digitalkin/utils/package_discover.py +1 -2
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/METADATA +11 -30
- digitalkin-0.3.1.dev2.dist-info/RECORD +119 -0
- modules/dynamic_setup_module.py +362 -0
- digitalkin/grpc_servers/utils/factory.py +0 -180
- digitalkin/modules/job_manager/single_job_manager.py +0 -294
- digitalkin/modules/job_manager/taskiq_job_manager.py +0 -290
- digitalkin-0.2.23.dist-info/RECORD +0 -89
- /digitalkin/{grpc_servers/utils → models/grpc_servers}/types.py +0 -0
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/WHEEL +0 -0
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/top_level.txt +0 -0
digitalkin/core/task_manager/base_task_manager.py (new file)
@@ -0,0 +1,539 @@
+"""Base task manager with common lifecycle management."""
+
+import asyncio
+import contextlib
+import datetime
+import types
+from abc import ABC, abstractmethod
+from collections.abc import Coroutine
+from typing import Any
+
+from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
+from digitalkin.core.task_manager.task_session import TaskSession
+from digitalkin.logger import logger
+from digitalkin.models.core.task_monitor import CancellationReason
+from digitalkin.modules._base_module import BaseModule
+
+
+class BaseTaskManager(ABC):
+    """Base task manager with common lifecycle management.
+
+    Provides shared functionality for task orchestration, monitoring, signaling, and cancellation.
+    Subclasses implement specific execution strategies (local or remote).
+
+    Supports async context manager protocol for automatic resource cleanup:
+        async with LocalTaskManager() as manager:
+            await manager.create_task(...)
+            # Resources automatically cleaned up on exit
+    """
+
+    tasks: dict[str, asyncio.Task]
+    tasks_sessions: dict[str, TaskSession]
+    default_timeout: float
+    max_concurrent_tasks: int
+    _shutdown_event: asyncio.Event
+
+    def __init__(
+        self,
+        default_timeout: float = 10.0,
+        max_concurrent_tasks: int = 100,
+    ) -> None:
+        """Initialize task manager properties.
+
+        Args:
+            default_timeout: Default timeout for task operations in seconds
+            max_concurrent_tasks: Maximum number of concurrent tasks
+        """
+        self.tasks = {}
+        self.tasks_sessions = {}
+        self.default_timeout = default_timeout
+        self.max_concurrent_tasks = max_concurrent_tasks
+        self._shutdown_event = asyncio.Event()
+
+        logger.info(
+            "%s initialized with max_concurrent_tasks: %d, default_timeout: %.1f",
+            self.__class__.__name__,
+            max_concurrent_tasks,
+            default_timeout,
+            extra={
+                "max_concurrent_tasks": max_concurrent_tasks,
+                "default_timeout": default_timeout,
+            },
+        )
+
+    @property
+    def task_count(self) -> int:
+        """Number of managed tasks."""
+        return len(self.tasks_sessions)
+
+    @property
+    def running_tasks(self) -> set[str]:
+        """Get IDs of currently running tasks."""
+        return {task_id for task_id, task in self.tasks.items() if not task.done()}
+
+    async def _cleanup_task(self, task_id: str, mission_id: str) -> None:
+        """Clean up task resources.
+
+        Delegates cleanup to TaskSession which handles:
+        - Clearing queue items to free memory
+        - Stopping module (if not already stopped)
+        - Closing database connection (which kills live queries)
+
+        Then removes task from tracking dictionaries.
+
+        Args:
+            task_id: The ID of the task to clean up
+            mission_id: The ID of the mission associated with the task
+        """
+        session = self.tasks_sessions.get(task_id)
+        cancellation_reason = session.cancellation_reason.value if session else "no_session"
+        final_status = session.status.value if session else "unknown"
+
+        logger.debug(
+            "Cleaning up resources",
+            extra={
+                "mission_id": mission_id,
+                "task_id": task_id,
+                "final_status": final_status,
+                "cancellation_reason": cancellation_reason,
+            },
+        )
+
+        if session:
+            await session.cleanup()
+            self.tasks_sessions.pop(task_id, None)
+            logger.debug(
+                "Task session cleanup completed",
+                extra={
+                    "mission_id": mission_id,
+                    "task_id": task_id,
+                    "final_status": final_status,
+                    "cancellation_reason": cancellation_reason,
+                },
+            )
+
+        self.tasks.pop(task_id, None)
+
+    async def _validate_task_creation(self, task_id: str, mission_id: str, coro: Coroutine[Any, Any, None]) -> None:
+        """Validate task creation preconditions.
+
+        Args:
+            task_id: The ID of the task to create
+            mission_id: The ID of the mission associated with the task
+            coro: The coroutine to execute
+
+        Raises:
+            ValueError: If task_id already exists
+            RuntimeError: If max concurrent tasks reached
+        """
+        if task_id in self.tasks_sessions:
+            coro.close()
+            logger.warning(
+                "Task creation failed - task already exists: '%s'",
+                task_id,
+                extra={"mission_id": mission_id, "task_id": task_id},
+            )
+            msg = f"Task {task_id} already exists"
+            raise ValueError(msg)
+
+        if len(self.tasks_sessions) >= self.max_concurrent_tasks:
+            coro.close()
+            logger.error(
+                "Task creation failed - max concurrent tasks reached: %d",
+                self.max_concurrent_tasks,
+                extra={
+                    "mission_id": mission_id,
+                    "task_id": task_id,
+                    "current_count": len(self.tasks_sessions),
+                    "max_concurrent": self.max_concurrent_tasks,
+                },
+            )
+            msg = f"Maximum concurrent tasks ({self.max_concurrent_tasks}) reached"
+            raise RuntimeError(msg)
+
+    async def _create_session(
+        self,
+        task_id: str,
+        mission_id: str,
+        module: BaseModule,
+        heartbeat_interval: datetime.timedelta,
+        connection_timeout: datetime.timedelta,
+    ) -> tuple[SurrealDBConnection, TaskSession]:
+        """Create SurrealDB connection and task session.
+
+        Args:
+            task_id: The ID of the task
+            mission_id: The ID of the mission
+            module: The module instance
+            heartbeat_interval: Interval between heartbeats
+            connection_timeout: Connection timeout for SurrealDB
+
+        Returns:
+            Tuple of (channel, session)
+        """
+        channel: SurrealDBConnection = SurrealDBConnection("task_manager", connection_timeout)
+        await channel.init_surreal_instance()
+        session = TaskSession(
+            task_id=task_id,
+            mission_id=mission_id,
+            db=channel,
+            module=module,
+            heartbeat_interval=heartbeat_interval,
+        )
+        self.tasks_sessions[task_id] = session
+        return channel, session
+
+    @abstractmethod
+    async def create_task(
+        self,
+        task_id: str,
+        mission_id: str,
+        module: BaseModule,
+        coro: Coroutine[Any, Any, None],
+        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
+        connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
+    ) -> None:
+        """Create and manage a new task.
+
+        Subclasses implement specific execution strategies.
+
+        Args:
+            task_id: Unique identifier for the task
+            mission_id: Mission identifier
+            module: Module instance to execute
+            coro: Coroutine to execute
+            heartbeat_interval: Interval between heartbeats
+            connection_timeout: Connection timeout for SurrealDB
+
+        Raises:
+            ValueError: If task_id duplicated
+            RuntimeError: If task overload
+        """
+        ...
+
+    async def send_signal(self, task_id: str, mission_id: str, signal_type: str, payload: dict) -> bool:
+        """Send signal to a specific task.
+
+        Args:
+            task_id: The ID of the task
+            mission_id: The ID of the mission
+            signal_type: Type of signal to send
+            payload: Signal payload
+
+        Returns:
+            True if the signal was sent successfully, False otherwise
+        """
+        if task_id not in self.tasks_sessions:
+            logger.warning(
+                "Cannot send signal - task not found: '%s'",
+                task_id,
+                extra={"mission_id": mission_id, "task_id": task_id, "signal_type": signal_type},
+            )
+            return False
+
+        logger.info(
+            "Sending signal '%s' to task: '%s'",
+            signal_type,
+            task_id,
+            extra={"mission_id": mission_id, "task_id": task_id, "signal_type": signal_type, "payload": payload},
+        )
+
+        # Use the task session's db connection to send the signal
+        session = self.tasks_sessions[task_id]
+        await session.db.update("signals", task_id, {"type": signal_type, "payload": payload})
+        return True
+
+    async def cancel_task(self, task_id: str, mission_id: str, timeout: float | None = None) -> bool:
+        """Cancel a task with graceful shutdown and fallback.
+
+        Args:
+            task_id: The ID of the task to cancel
+            mission_id: The ID of the mission
+            timeout: Optional timeout for cancellation
+
+        Returns:
+            True if the task was cancelled successfully, False otherwise
+        """
+        if task_id not in self.tasks:
+            logger.warning(
+                "Cannot cancel - task not found: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id}
+            )
+            # Still cleanup any orphaned session
+            await self._cleanup_task(task_id, mission_id)
+            return True
+
+        timeout = timeout or self.default_timeout
+        task = self.tasks[task_id]
+
+        logger.info(
+            "Initiating task cancellation: '%s', timeout: %.1fs",
+            task_id,
+            timeout,
+            extra={"mission_id": mission_id, "task_id": task_id, "timeout": timeout},
+        )
+
+        try:
+            # Wait for graceful shutdown
+            await asyncio.wait_for(task, timeout=timeout)
+
+            logger.info(
+                "Task cancelled gracefully: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id}
+            )
+
+        except asyncio.TimeoutError:
+            # Set timeout as cancellation reason
+            if task_id in self.tasks_sessions:
+                session = self.tasks_sessions[task_id]
+                if session.cancellation_reason == CancellationReason.UNKNOWN:
+                    session.cancellation_reason = CancellationReason.TIMEOUT
+
+            logger.warning(
+                "Graceful cancellation timed out for task: '%s', forcing cancellation",
+                task_id,
+                extra={
+                    "mission_id": mission_id,
+                    "task_id": task_id,
+                    "timeout": timeout,
+                    "cancellation_reason": CancellationReason.TIMEOUT.value,
+                },
+            )
+
+            # Phase 2: Force cancellation
+            task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await task
+
+            logger.warning(
+                "Task force-cancelled: '%s', reason: %s",
+                task_id,
+                CancellationReason.TIMEOUT.value,
+                extra={
+                    "mission_id": mission_id,
+                    "task_id": task_id,
+                    "cancellation_reason": CancellationReason.TIMEOUT.value,
+                },
+            )
+            return True
+
+        except Exception as e:
+            logger.error(
+                "Error during task cancellation: '%s'",
+                task_id,
+                extra={"mission_id": mission_id, "task_id": task_id, "error": str(e)},
+                exc_info=True,
+            )
+            return False
+        finally:
+            await self._cleanup_task(task_id, mission_id)
+        return True
+
+    async def clean_session(self, task_id: str, mission_id: str) -> bool:
+        """Clean up task session without cancelling the task.
+
+        Args:
+            task_id: The ID of the task
+            mission_id: The ID of the mission
+
+        Returns:
+            True if the task session was cleaned successfully, False otherwise.
+        """
+        if task_id not in self.tasks_sessions:
+            logger.warning(
+                "Cannot clean session - task not found: '%s'",
+                task_id,
+                extra={"mission_id": mission_id, "task_id": task_id},
+            )
+            return False
+
+        await self.cancel_task(mission_id=mission_id, task_id=task_id)
+
+        logger.info("Cleaning up session for task: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id})
+        return True
+
+    async def pause_task(self, task_id: str, mission_id: str) -> bool:
+        """Pause a running task.
+
+        Args:
+            task_id: The ID of the task
+            mission_id: The ID of the mission
+
+        Returns:
+            True if the task was paused successfully, False otherwise
+        """
+        return await self.send_signal(task_id=task_id, mission_id=mission_id, signal_type="pause", payload={})
+
+    async def resume_task(self, task_id: str, mission_id: str) -> bool:
+        """Resume a paused task.
+
+        Args:
+            task_id: The ID of the task
+            mission_id: The ID of the mission
+
+        Returns:
+            True if the task was resumed successfully, False otherwise
+        """
+        return await self.send_signal(task_id=task_id, mission_id=mission_id, signal_type="resume", payload={})
+
+    async def get_task_status(self, task_id: str, mission_id: str) -> bool:
+        """Request status from a task.
+
+        Args:
+            task_id: The ID of the task
+            mission_id: The ID of the mission
+
+        Returns:
+            True if the status request was sent successfully, False otherwise
+        """
+        return await self.send_signal(task_id=task_id, mission_id=mission_id, signal_type="status", payload={})
+
+    async def cancel_all_tasks(self, mission_id: str, timeout: float | None = None) -> dict[str, bool | BaseException]:
+        """Cancel all running tasks.
+
+        Args:
+            mission_id: The ID of the mission
+            timeout: Optional timeout for cancellation
+
+        Returns:
+            Dictionary mapping task_id to cancellation success status
+        """
+        timeout = timeout or self.default_timeout
+        task_ids = list(self.running_tasks)
+
+        logger.info(
+            "Cancelling all tasks in parallel: %d tasks",
+            len(task_ids),
+            extra={"mission_id": mission_id, "task_count": len(task_ids), "timeout": timeout},
+        )
+
+        # Cancel all tasks in parallel to reduce latency
+        cancel_coros = [
+            self.cancel_task(
+                task_id=task_id,
+                mission_id=mission_id,
+                timeout=timeout,
+            )
+            for task_id in task_ids
+        ]
+        results_list = await asyncio.gather(*cancel_coros, return_exceptions=True)
+
+        # Build results dictionary
+        results: dict[str, bool | BaseException] = {}
+        for task_id, result in zip(task_ids, results_list):
+            if isinstance(result, Exception):
+                logger.error(
+                    "Exception cancelling task: '%s', error: %s",
+                    task_id,
+                    result,
+                    extra={
+                        "mission_id": mission_id,
+                        "task_id": task_id,
+                        "error": str(result),
+                    },
+                )
+                results[task_id] = False
+            else:
+                results[task_id] = result
+
+        return results
+
+    async def shutdown(self, mission_id: str, timeout: float = 30.0) -> None:
+        """Graceful shutdown of all tasks.
+
+        Args:
+            mission_id: The ID of the mission
+            timeout: Timeout for shutdown operations
+        """
+        logger.info(
+            "TaskManager shutdown initiated, timeout: %.1fs",
+            timeout,
+            extra={"mission_id": mission_id, "timeout": timeout, "active_tasks": len(self.running_tasks)},
+        )
+
+        self._shutdown_event.set()
+
+        # Mark all sessions with shutdown reason before cancellation
+        for task_id, session in self.tasks_sessions.items():
+            if session.cancellation_reason == CancellationReason.UNKNOWN:
+                session.cancellation_reason = CancellationReason.SHUTDOWN
+                logger.debug(
+                    "Marking task for shutdown: '%s'",
+                    task_id,
+                    extra={
+                        "mission_id": mission_id,
+                        "task_id": task_id,
+                        "cancellation_reason": CancellationReason.SHUTDOWN.value,
+                    },
+                )
+
+        results = await self.cancel_all_tasks(mission_id, timeout)
+
+        failed_tasks = [task_id for task_id, success in results.items() if not success]
+        if failed_tasks:
+            logger.error(
+                "Failed to cancel %d tasks during shutdown: %s",
+                len(failed_tasks),
+                failed_tasks,
+                extra={
+                    "mission_id": mission_id,
+                    "failed_tasks": failed_tasks,
+                    "failed_count": len(failed_tasks),
+                    "cancellation_reason": CancellationReason.SHUTDOWN.value,
+                },
+            )
+
+        # Clean up any remaining sessions (in case cancellation didn't clean them)
+        remaining_sessions = list(self.tasks_sessions.keys())
+        if remaining_sessions:
+            logger.info(
+                "Cleaning up %d remaining task sessions after shutdown",
+                len(remaining_sessions),
+                extra={
+                    "mission_id": mission_id,
+                    "remaining_sessions": remaining_sessions,
+                    "remaining_count": len(remaining_sessions),
+                },
+            )
+            cleanup_coros = [self._cleanup_task(task_id, mission_id) for task_id in remaining_sessions]
+            await asyncio.gather(*cleanup_coros, return_exceptions=True)
+
+        logger.info(
+            "TaskManager shutdown completed, cancelled: %d, failed: %d",
+            len(results) - len(failed_tasks),
+            len(failed_tasks),
+            extra={
+                "mission_id": mission_id,
+                "cancelled_count": len(results) - len(failed_tasks),
+                "failed_count": len(failed_tasks),
+            },
+        )
+
+    async def __aenter__(self) -> "BaseTaskManager":
+        """Enter async context manager.
+
+        Returns:
+            Self for use in async with statements
+        """
+        logger.debug("Entering %s context", self.__class__.__name__)
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: types.TracebackType | None,
+    ) -> None:
+        """Exit async context manager and clean up resources.
+
+        Args:
+            exc_type: Exception type if an exception occurred
+            exc_val: Exception value if an exception occurred
+            exc_tb: Exception traceback if an exception occurred
+        """
+        logger.debug(
+            "Exiting %s context, exception: %s",
+            self.__class__.__name__,
+            exc_type,
+            extra={"exc_type": exc_type, "exc_val": exc_val},
+        )
+        # Shutdown with default mission_id for context manager usage
+        await self.shutdown(mission_id="context_manager_cleanup")
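
The methods above make up the whole caller-facing lifecycle: the abstract create_task, pause/resume/status signalling through the task's SurrealDB session, graceful-then-forced cancellation, and shutdown. The following sketch (illustrative only, not part of the package) drives that lifecycle from application code, assuming a concrete manager such as the LocalTaskManager added in the next file and an already-constructed BaseModule; the module's run() coroutine and the string identifiers are placeholders.

from digitalkin.core.task_manager.base_task_manager import BaseTaskManager
from digitalkin.modules._base_module import BaseModule


async def drive_lifecycle(manager: BaseTaskManager, module: BaseModule) -> None:
    # Register and start the task: a duplicate task_id raises ValueError,
    # exceeding max_concurrent_tasks raises RuntimeError.
    await manager.create_task(
        task_id="task-1",
        mission_id="mission-1",
        module=module,
        coro=module.run(),  # placeholder coroutine for the module's work
    )

    # Control signals are written to the "signals" table of the task's
    # SurrealDB session; each call returns False if the task is unknown.
    await manager.pause_task(task_id="task-1", mission_id="mission-1")
    await manager.resume_task(task_id="task-1", mission_id="mission-1")
    await manager.get_task_status(task_id="task-1", mission_id="mission-1")

    # Waits up to default_timeout for the task to finish, then force-cancels;
    # session and task bookkeeping are cleaned up in both cases.
    await manager.cancel_task(task_id="task-1", mission_id="mission-1")
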
digitalkin/core/task_manager/local_task_manager.py (new file)
@@ -0,0 +1,108 @@
+"""Local task manager for single-process execution."""
+
+import datetime
+from collections.abc import Coroutine
+from typing import Any
+
+from digitalkin.core.task_manager.base_task_manager import BaseTaskManager
+from digitalkin.core.task_manager.task_executor import TaskExecutor
+from digitalkin.logger import logger
+from digitalkin.modules._base_module import BaseModule
+
+
+class LocalTaskManager(BaseTaskManager):
+    """Task manager for local execution in the same process.
+
+    Executes tasks locally using TaskExecutor with the supervisor pattern.
+    Suitable for single-server deployments and development.
+    """
+
+    _executor: TaskExecutor
+
+    def __init__(
+        self,
+        default_timeout: float = 10.0,
+        max_concurrent_tasks: int = 100,
+    ) -> None:
+        """Initialize local task manager.
+
+        Args:
+            default_timeout: Default timeout for task operations in seconds
+            max_concurrent_tasks: Maximum number of concurrent tasks
+        """
+        super().__init__(default_timeout, max_concurrent_tasks)
+        self._executor = TaskExecutor()
+
+    async def create_task(
+        self,
+        task_id: str,
+        mission_id: str,
+        module: BaseModule,
+        coro: Coroutine[Any, Any, None],
+        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
+        connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
+    ) -> None:
+        """Create and execute a task locally using TaskExecutor.
+
+        Args:
+            task_id: Unique identifier for the task
+            mission_id: Mission identifier
+            module: Module instance to execute
+            coro: Coroutine to execute
+            heartbeat_interval: Interval between heartbeats
+            connection_timeout: Connection timeout for SurrealDB
+
+        Raises:
+            ValueError: If task_id duplicated
+            RuntimeError: If task overload
+        """
+        # Validation
+        await self._validate_task_creation(task_id, mission_id, coro)
+
+        logger.info(
+            "Creating local task: '%s'",
+            task_id,
+            extra={
+                "mission_id": mission_id,
+                "task_id": task_id,
+                "heartbeat_interval": heartbeat_interval,
+                "connection_timeout": connection_timeout,
+            },
+        )
+
+        try:
+            # Create session
+            channel, session = await self._create_session(
+                task_id, mission_id, module, heartbeat_interval, connection_timeout
+            )
+
+            # Execute task using TaskExecutor
+            supervisor_task = await self._executor.execute_task(
+                task_id,
+                mission_id,
+                coro,
+                session,
+                channel,
+            )
+            self.tasks[task_id] = supervisor_task
+
+            logger.info(
+                "Local task created and started: '%s'",
+                task_id,
+                extra={
+                    "mission_id": mission_id,
+                    "task_id": task_id,
+                    "total_tasks": len(self.tasks),
+                },
+            )
+
+        except Exception as e:
+            logger.error(
+                "Failed to create local task: '%s'",
+                task_id,
+                extra={"mission_id": mission_id, "task_id": task_id, "error": str(e)},
+                exc_info=True,
+            )
+            # Cleanup on failure
+            await self._cleanup_task(task_id, mission_id=mission_id)
+            raise
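
LocalTaskManager inherits the async context manager protocol from BaseTaskManager, so cleanup does not need explicit shutdown handling. A minimal sketch expanding the docstring example, again with the module and its run() coroutine as placeholders rather than APIs defined in this diff:

from digitalkin.core.task_manager.local_task_manager import LocalTaskManager
from digitalkin.modules._base_module import BaseModule


async def main(module: BaseModule) -> None:
    async with LocalTaskManager(default_timeout=10.0, max_concurrent_tasks=10) as manager:
        await manager.create_task(
            task_id="example-task",
            mission_id="example-mission",
            module=module,
            coro=module.run(),  # placeholder coroutine for the module's work
        )
        # task_count and running_tasks are provided by BaseTaskManager.
        print(manager.task_count, manager.running_tasks)
    # On exit, __aexit__ calls shutdown(mission_id="context_manager_cleanup"),
    # cancelling anything still running and cleaning up remaining sessions.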