digitalkin 0.3.0rc0__py3-none-any.whl → 0.3.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/core/__init__.py +1 -0
- digitalkin/core/job_manager/__init__.py +1 -0
- digitalkin/{modules → core}/job_manager/base_job_manager.py +5 -3
- digitalkin/{modules → core}/job_manager/single_job_manager.py +9 -10
- digitalkin/{modules → core}/job_manager/taskiq_broker.py +2 -3
- digitalkin/{modules → core}/job_manager/taskiq_job_manager.py +5 -6
- digitalkin/core/task_manager/__init__.py +1 -0
- digitalkin/{modules/job_manager → core/task_manager}/surrealdb_repository.py +0 -1
- digitalkin/{modules/job_manager → core/task_manager}/task_manager.py +102 -49
- digitalkin/{modules/job_manager → core/task_manager}/task_session.py +71 -18
- digitalkin/grpc_servers/__init__.py +1 -19
- digitalkin/grpc_servers/_base_server.py +2 -2
- digitalkin/grpc_servers/module_server.py +2 -2
- digitalkin/grpc_servers/module_servicer.py +3 -3
- digitalkin/grpc_servers/registry_server.py +1 -1
- digitalkin/grpc_servers/utils/__init__.py +1 -0
- digitalkin/grpc_servers/utils/exceptions.py +0 -8
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +1 -1
- digitalkin/mixins/chat_history_mixin.py +2 -0
- digitalkin/mixins/file_history_mixin.py +14 -20
- digitalkin/mixins/filesystem_mixin.py +1 -2
- digitalkin/mixins/logger_mixin.py +4 -12
- digitalkin/models/core/__init__.py +1 -0
- digitalkin/{modules/job_manager → models/core}/job_manager_models.py +3 -3
- digitalkin/models/{module → core}/task_monitor.py +7 -5
- digitalkin/models/grpc_servers/__init__.py +1 -0
- digitalkin/{grpc_servers/utils → models/grpc_servers}/models.py +4 -4
- digitalkin/models/module/module_context.py +33 -1
- digitalkin/models/module/module_types.py +5 -1
- digitalkin/models/services/cost.py +1 -0
- digitalkin/modules/_base_module.py +16 -80
- digitalkin/services/cost/grpc_cost.py +1 -1
- digitalkin/services/filesystem/grpc_filesystem.py +1 -1
- digitalkin/services/setup/grpc_setup.py +1 -1
- digitalkin/services/storage/grpc_storage.py +1 -1
- digitalkin/utils/arg_parser.py +1 -1
- digitalkin/utils/development_mode_action.py +2 -2
- digitalkin/utils/package_discover.py +1 -2
- {digitalkin-0.3.0rc0.dist-info → digitalkin-0.3.0rc2.dist-info}/METADATA +5 -25
- {digitalkin-0.3.0rc0.dist-info → digitalkin-0.3.0rc2.dist-info}/RECORD +45 -40
- digitalkin/grpc_servers/utils/factory.py +0 -180
- /digitalkin/{grpc_servers/utils → models/grpc_servers}/types.py +0 -0
- {digitalkin-0.3.0rc0.dist-info → digitalkin-0.3.0rc2.dist-info}/WHEEL +0 -0
- {digitalkin-0.3.0rc0.dist-info → digitalkin-0.3.0rc2.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.3.0rc0.dist-info → digitalkin-0.3.0rc2.dist-info}/top_level.txt +0 -0
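The bulk of this release is a package reorganisation: the job-manager code moves from digitalkin.modules.job_manager into digitalkin.core.job_manager, the task manager and its SurrealDB plumbing move into a new digitalkin.core.task_manager package, and the task-monitor models move from digitalkin.models.module to digitalkin.models.core. A minimal sketch of what that means for downstream imports, using only paths that appear in the hunks below (whether these modules are intended to be imported directly by users is an assumption):

```python
# Import paths before (0.3.0rc0) and after (0.3.0rc2), per the renames listed above.
# Old locations, removed in this release:
#   from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
#   from digitalkin.modules.job_manager.task_manager import TaskManager
#   from digitalkin.modules.job_manager.task_session import TaskSession
#   from digitalkin.models.module.task_monitor import TaskStatus

# New locations, as imported by the updated modules themselves:
from digitalkin.core.job_manager.base_job_manager import BaseJobManager
from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
from digitalkin.core.task_manager.task_manager import TaskManager
from digitalkin.core.task_manager.task_session import TaskSession
from digitalkin.models.core.task_monitor import TaskStatus
```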
digitalkin/__version__.py
CHANGED

digitalkin/core/__init__.py
@@ -0,0 +1 @@
+"""Core of Digitlakin defining the task management and sub-modules."""

digitalkin/core/job_manager/__init__.py
@@ -0,0 +1 @@
+"""Job Manager logic."""

digitalkin/{modules → core}/job_manager/base_job_manager.py
@@ -5,15 +5,16 @@ from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine
 from contextlib import asynccontextmanager
 from typing import Any, Generic
 
+from digitalkin.core.task_manager.task_manager import TaskManager
+from digitalkin.models.core.task_monitor import TaskStatus
 from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
 from digitalkin.models.module.module import ModuleCodeModel
-from digitalkin.models.module.task_monitor import TaskStatus
 from digitalkin.modules._base_module import BaseModule
 from digitalkin.services.services_config import ServicesConfig
 from digitalkin.services.services_models import ServicesMode
 
 
-class BaseJobManager(abc.ABC, Generic[InputModelT, SetupModelT, OutputModelT]):
+class BaseJobManager(abc.ABC, TaskManager, Generic[InputModelT, SetupModelT, OutputModelT]):
     """Abstract base class for managing background module jobs."""
 
     async def start(self) -> None:
@@ -25,7 +26,8 @@ class BaseJobManager(abc.ABC, Generic[InputModelT, SetupModelT, OutputModelT]):
 
     @staticmethod
     async def job_specific_callback(
-        callback: Callable[[str, OutputModelT | ModuleCodeModel], Coroutine[Any, Any, None]],
+        callback: Callable[[str, OutputModelT | ModuleCodeModel], Coroutine[Any, Any, None]],
+        job_id: str,
     ) -> Callable[[OutputModelT | ModuleCodeModel], Coroutine[Any, Any, None]]:
         """Generate a job-specific callback function.
 
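Two things change in BaseJobManager: it now also inherits from TaskManager, so every job manager carries the task-lifecycle API directly, and job_specific_callback takes the job_id it should bind. The diff does not show the method body, but the signature implies a partial-application pattern along these lines (a sketch only; every name other than the two parameters above is illustrative):

```python
from collections.abc import Callable, Coroutine
from typing import Any


def bind_job_callback(
    callback: Callable[[str, Any], Coroutine[Any, Any, None]],
    job_id: str,
) -> Callable[[Any], Coroutine[Any, Any, None]]:
    """Illustrative stand-in for job_specific_callback: fix job_id, keep only the output argument."""

    async def job_callback(output: Any) -> None:
        # Forward the output to the shared two-argument callback, tagged with this job's id.
        await callback(job_id, output)

    return job_callback


async def _print_output(job_id: str, output: Any) -> None:
    # Hypothetical shared callback used only for this example.
    print(f"[{job_id}] {output}")


# Usage: per_job = bind_job_callback(_print_output, "job-123"); await per_job({"chunk": 1})
```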
digitalkin/{modules → core}/job_manager/single_job_manager.py
@@ -9,19 +9,18 @@ from typing import Any, Generic
 
 import grpc
 
+from digitalkin.core.job_manager.base_job_manager import BaseJobManager
+from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
+from digitalkin.core.task_manager.task_session import TaskSession
 from digitalkin.logger import logger
+from digitalkin.models.core.task_monitor import TaskStatus
 from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
 from digitalkin.models.module.module import ModuleCodeModel
-from digitalkin.models.module.task_monitor import TaskStatus
 from digitalkin.modules._base_module import BaseModule
-from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
-from digitalkin.modules.job_manager.surrealdb_repository import SurrealDBConnection
-from digitalkin.modules.job_manager.task_manager import TaskManager
-from digitalkin.modules.job_manager.task_session import TaskSession
 from digitalkin.services.services_models import ServicesMode
 
 
-class SingleJobManager(BaseJobManager,
+class SingleJobManager(BaseJobManager, Generic[InputModelT, OutputModelT, SetupModelT]):
     """Manages a single instance of a module job.
 
     This class ensures that only one instance of a module job is active at a time.
@@ -31,7 +30,7 @@ class SingleJobManager(BaseJobManager, TaskManager, Generic[InputModelT, OutputM
 
     async def start(self) -> None:
         """Start manager."""
-        self.channel = SurrealDBConnection("task_manager", datetime.timedelta(seconds=5))
+        self.channel: SurrealDBConnection = SurrealDBConnection("task_manager", datetime.timedelta(seconds=5))
         await self.channel.init_surreal_instance()
 
     def __init__(
@@ -87,7 +86,6 @@ class SingleJobManager(BaseJobManager, TaskManager, Generic[InputModelT, OutputM
 
         Args:
            config_setup_data: The input data required to start the job.
-           setup_data: The setup configuration for the module.
            mission_id: The mission ID associated with the job.
            setup_id: The setup ID associated with the module.
            setup_version_id: The setup ID.
@@ -101,7 +99,7 @@ class SingleJobManager(BaseJobManager, TaskManager, Generic[InputModelT, OutputM
         job_id = str(uuid.uuid4())
         # TODO: Ensure the job_id is unique.
         module = self.module_class(job_id, mission_id=mission_id, setup_id=setup_id, setup_version_id=setup_version_id)
-        self.tasks_sessions[job_id] = TaskSession(job_id, self.channel, module)
+        self.tasks_sessions[job_id] = TaskSession(job_id, mission_id, self.channel, module)
 
         try:
             await module.start_config_setup(
@@ -224,6 +222,7 @@ class SingleJobManager(BaseJobManager, TaskManager, Generic[InputModelT, OutputM
 
         await self.create_task(
             job_id,
+            mission_id,
             module,
             module.start(input_data, setup_data, callback, done_callback=None),
         )
@@ -254,7 +253,7 @@ class SingleJobManager(BaseJobManager, TaskManager, Generic[InputModelT, OutputM
             await session.module.stop()
 
             if job_id in self.tasks:
-                await self.cancel_task(job_id)
+                await self.cancel_task(job_id, session.mission_id)
             logger.debug(f"session {job_id} ({session.module.name}) stopped successfully")
         except Exception as e:
             logger.error(f"Error while stopping module {job_id}: {e}")

digitalkin/{modules → core}/job_manager/taskiq_broker.py
@@ -14,11 +14,11 @@ from taskiq.compat import model_validate
 from taskiq.message import BrokerMessage
 from taskiq_aio_pika import AioPikaBroker
 
+from digitalkin.core.job_manager.base_job_manager import BaseJobManager
 from digitalkin.logger import logger
+from digitalkin.models.core.job_manager_models import StreamCodeModel
 from digitalkin.models.module.module_types import OutputModelT
 from digitalkin.modules._base_module import BaseModule
-from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
-from digitalkin.modules.job_manager.job_manager_models import StreamCodeModel
 from digitalkin.services.services_config import ServicesConfig
 from digitalkin.services.services_models import ServicesMode
 
@@ -194,7 +194,6 @@ async def run_config_module(
        module_class: type[BaseModule],
        services_mode: ServicesMode,
        config_setup_data: dict,
-       setup_data: dict,
        context: Allow TaskIQ context access
    """
    logger.warning("%s", services_mode)

digitalkin/{modules → core}/job_manager/taskiq_job_manager.py
@@ -17,12 +17,12 @@ from typing import TYPE_CHECKING, Any, Generic
 
 from rstream import Consumer, ConsumerOffsetSpecification, MessageContext, OffsetType
 
+from digitalkin.core.job_manager.base_job_manager import BaseJobManager
+from digitalkin.core.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER
 from digitalkin.logger import logger
+from digitalkin.models.core.task_monitor import TaskStatus
 from digitalkin.models.module import InputModelT, SetupModelT
-from digitalkin.models.module.task_monitor import TaskStatus
 from digitalkin.modules._base_module import BaseModule
-from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
-from digitalkin.modules.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER
 from digitalkin.services.services_models import ServicesMode
 
 if TYPE_CHECKING:
@@ -146,7 +146,6 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
 
         Args:
             config_setup_data: The input data required to start the job.
-            setup_data: The setup configuration for the module.
             mission_id: The mission ID associated with the job.
             setup_id: The setup ID associated with the module.
             setup_version_id: The setup ID.
@@ -158,7 +157,7 @@
             TypeError: If the function is called with bad data type.
             ValueError: If the module fails to start.
         """
-        task = TASKIQ_BROKER.find_task("digitalkin.
+        task = TASKIQ_BROKER.find_task("digitalkin.core.taskiq_broker:run_config_module")
 
         if task is None:
             msg = "Task not found"
@@ -242,7 +241,7 @@
         Raises:
             ValueError: If the task is not found.
         """
-        task = TASKIQ_BROKER.find_task("digitalkin.
+        task = TASKIQ_BROKER.find_task("digitalkin.core.taskiq_broker:run_start_module")
 
         if task is None:
             msg = "Task not found"

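TaskiqJobManager now resolves its broker tasks under the relocated module path (digitalkin.core.taskiq_broker:...). A hedged sketch of how such a task would be looked up and enqueued with taskiq; the task name comes from the hunk above, while the keyword arguments mirror the documented Args and are an assumption not confirmed by this diff:

```python
from digitalkin.core.job_manager.taskiq_broker import TASKIQ_BROKER


async def enqueue_config_job(config_setup_data: dict, mission_id: str, setup_id: str, setup_version_id: str) -> None:
    # Resolve the registered broker task by its new dotted name (shown in the hunk above).
    task = TASKIQ_BROKER.find_task("digitalkin.core.taskiq_broker:run_config_module")
    if task is None:
        msg = "Task not found"
        raise ValueError(msg)
    # kiq() is taskiq's standard enqueue call; the kwargs here are assumptions.
    await task.kiq(
        config_setup_data=config_setup_data,
        mission_id=mission_id,
        setup_id=setup_id,
        setup_version_id=setup_version_id,
    )
```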
digitalkin/core/task_manager/__init__.py
@@ -0,0 +1 @@
+"""Base task manager logic."""

digitalkin/{modules/job_manager → core/task_manager}/task_manager.py
@@ -6,14 +6,18 @@ import datetime
 from collections.abc import Coroutine
 from typing import Any
 
+from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
+from digitalkin.core.task_manager.task_session import TaskSession
 from digitalkin.logger import logger
-from digitalkin.models.
+from digitalkin.models.core.task_monitor import SignalMessage, SignalType, TaskStatus
 from digitalkin.modules._base_module import BaseModule
-from digitalkin.modules.job_manager.task_session import SurrealDBConnection, TaskSession
 
 
 class TaskManager:
-    """Task manager with comprehensive lifecycle management.
+    """Task manager with comprehensive lifecycle management.
+
+    Handle the tasks creation, execution, monitoring, signaling, and cancellation.
+    """
 
     tasks: dict[str, asyncio.Task]
     tasks_sessions: dict[str, TaskSession]
@@ -22,8 +26,8 @@ class TaskManager:
     max_concurrent_tasks: int
     _shutdown_event: asyncio.Event
 
-    def __init__(self, default_timeout: float = 10.0, max_concurrent_tasks: int =
-        """."""
+    def __init__(self, default_timeout: float = 10.0, max_concurrent_tasks: int = 1000) -> None:
+        """Defining task manager properties."""
         self.tasks = {}
         self.tasks_sessions = {}
         self.default_timeout = default_timeout
@@ -34,29 +38,42 @@ class TaskManager:
             "TaskManager initialized with max_concurrent_tasks: %d, default_timeout: %.1f",
             max_concurrent_tasks,
             default_timeout,
-            extra={
+            extra={
+                "max_concurrent_tasks": max_concurrent_tasks,
+                "default_timeout": default_timeout,
+            },
         )
 
     @property
     def task_count(self) -> int:
-        """."""
+        """Number of managed tasks."""
        return len(self.tasks_sessions)
 
     @property
     def running_tasks(self) -> set[str]:
-        """."""
+        """Get IDs of currently running tasks."""
        return {task_id for task_id, task in self.tasks.items() if not task.done()}
 
-    async def _cleanup_task(self, task_id: str) -> None:
-        """Clean up task resources.
-
+    async def _cleanup_task(self, task_id: str, mission_id: str) -> None:
+        """Clean up task resources.
+
+        Args:
+            task_id (str): The ID of the task to clean up.
+            mission_id (str): The ID of the mission associated with the task.
+        """
+        logger.debug(
+            "Cleaning up resources for task: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id}
+        )
         if task_id in self.tasks_sessions:
             await self.tasks_sessions[task_id].db.close()
         # Remove from collections
+        self.tasks.pop(task_id, None)
+        self.tasks_sessions.pop(task_id, None)
 
     async def _task_wrapper(  # noqa: C901, PLR0915
         self,
         task_id: str,
+        mission_id: str,
         coro: Coroutine[Any, Any, None],
         session: TaskSession,
     ) -> asyncio.Task[None]:
@@ -75,31 +92,33 @@
                     "tasks",
                     SignalMessage(
                         task_id=task_id,
+                        mission_id=mission_id,
                         status=session.status,
                         action=SignalType.START,
                     ).model_dump(),
                 )
                 await session.listen_signals()
             except asyncio.CancelledError:
-                logger.debug("Signal listener cancelled", extra={"task_id": task_id})
+                logger.debug("Signal listener cancelled", extra={"mission_id": mission_id, "task_id": task_id})
             finally:
                 await self.channel.create(
                     "tasks",
                     SignalMessage(
                         task_id=task_id,
+                        mission_id=mission_id,
                         status=session.status,
                         action=SignalType.STOP,
                     ).model_dump(),
                 )
-                logger.info("Signal listener ended", extra={"task_id": task_id})
+                logger.info("Signal listener ended", extra={"mission_id": mission_id, "task_id": task_id})
 
         async def heartbeat_wrapper() -> None:
            try:
                await session.generate_heartbeats()
            except asyncio.CancelledError:
-                logger.debug("Signal listener cancelled", extra={"task_id": task_id})
+                logger.debug("Signal listener cancelled", extra={"mission_id": mission_id, "task_id": task_id})
            finally:
-                logger.info("Heartbeat task ended", extra={"task_id": task_id})
+                logger.info("Heartbeat task ended", extra={"mission_id": mission_id, "task_id": task_id})
 
         async def supervisor() -> None:
            session.started_at = datetime.datetime.now(datetime.timezone.utc)
@@ -153,6 +172,7 @@
     async def create_task(
         self,
         task_id: str,
+        mission_id: str,
         module: BaseModule,
         coro: Coroutine[Any, Any, None],
         heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
@@ -167,7 +187,11 @@
         if task_id in self.tasks:
             # close Coroutine during runtime
             coro.close()
-            logger.warning(
+            logger.warning(
+                "Task creation failed - task already exists: '%s'",
+                task_id,
+                extra={"mission_id": mission_id, "task_id": task_id},
+            )
             msg = f"Task {task_id} already exists"
             raise ValueError(msg)
 
@@ -177,6 +201,7 @@
                 "Task creation failed - max concurrent tasks reached: %d",
                 self.max_concurrent_tasks,
                 extra={
+                    "mission_id": mission_id,
                     "task_id": task_id,
                     "current_count": len(self.tasks),
                     "max_concurrent": self.max_concurrent_tasks,
@@ -189,6 +214,7 @@
             "Creating new task: '%s'",
             task_id,
             extra={
+                "mission_id": mission_id,
                 "task_id": task_id,
                 "heartbeat_interval": heartbeat_interval,
                 "connection_timeout": connection_timeout,
@@ -199,17 +225,26 @@
             # Initialize components
             channel: SurrealDBConnection = SurrealDBConnection("task_manager", connection_timeout)
             await channel.init_surreal_instance()
-            session = TaskSession(task_id, channel, module, heartbeat_interval)
+            session = TaskSession(task_id, mission_id, channel, module, heartbeat_interval)
 
             self.tasks_sessions[task_id] = session
 
             # Create wrapper task
-            self.tasks[task_id] = asyncio.create_task(
+            self.tasks[task_id] = asyncio.create_task(
+                self._task_wrapper(
+                    task_id,
+                    mission_id,
+                    coro,
+                    session,
+                ),
+                name=task_id,
+            )
 
             logger.info(
                 "Task created successfully: '%s'",
                 task_id,
                 extra={
+                    "mission_id": mission_id,
                     "task_id": task_id,
                     "total_tasks": len(self.tasks),
                 },
@@ -217,13 +252,16 @@
 
         except Exception as e:
             logger.error(
-                "Failed to create task: '%s'",
+                "Failed to create task: '%s'",
+                task_id,
+                extra={"mission_id": mission_id, "task_id": task_id, "error": str(e)},
+                exc_info=True,
             )
             # Cleanup on failure
-            await self._cleanup_task(task_id)
+            await self._cleanup_task(task_id, mission_id=mission_id)
             raise
 
-    async def send_signal(self, task_id: str, signal_type: str, payload: dict) -> bool:
+    async def send_signal(self, task_id: str, mission_id: str, signal_type: str, payload: dict) -> bool:
         """Send signal to a specific task.
 
         Returns:
@@ -233,7 +271,7 @@
             logger.warning(
                 "Cannot send signal - task not found: '%s'",
                 task_id,
-                extra={"task_id": task_id, "signal_type": signal_type},
+                extra={"mission_id": mission_id, "task_id": task_id, "signal_type": signal_type},
             )
             return False
 
@@ -241,20 +279,22 @@
             "Sending signal '%s' to task: '%s'",
             signal_type,
             task_id,
-            extra={"task_id": task_id, "signal_type": signal_type, "payload": payload},
+            extra={"mission_id": mission_id, "task_id": task_id, "signal_type": signal_type, "payload": payload},
         )
 
         await self.channel.update("tasks", signal_type, payload)
         return True
 
-    async def cancel_task(self, task_id: str, timeout: float | None = None) -> bool:
+    async def cancel_task(self, task_id: str, mission_id: str, timeout: float | None = None) -> bool:
         """Cancel a task with graceful shutdown and fallback.
 
         Returns:
             bool: True if the task was cancelled successfully, False otherwise.
         """
         if task_id not in self.tasks:
-            logger.warning(
+            logger.warning(
+                "Cannot cancel - task not found: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id}
+            )
             return True
 
         timeout = timeout or self.default_timeout
@@ -264,23 +304,25 @@
             "Initiating task cancellation: '%s', timeout: %.1fs",
             task_id,
             timeout,
-            extra={"task_id": task_id, "timeout": timeout},
+            extra={"mission_id": mission_id, "task_id": task_id, "timeout": timeout},
         )
 
         try:
             # Phase 1: Cooperative cancellation
-            # await self.send_signal(task_id, "cancel") # noqa: ERA001
+            # await self.send_signal(task_id, mission_id, "cancel") # noqa: ERA001
 
             # Wait for graceful shutdown
             await asyncio.wait_for(task, timeout=timeout)
 
-            logger.info(
+            logger.info(
+                "Task cancelled gracefully: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id}
+            )
 
         except asyncio.TimeoutError:
             logger.warning(
                 "Graceful cancellation timed out for task: '%s', forcing cancellation",
                 task_id,
-                extra={"task_id": task_id, "timeout": timeout},
+                extra={"mission_id": mission_id, "task_id": task_id, "timeout": timeout},
             )
 
             # Phase 2: Force cancellation
@@ -288,61 +330,66 @@
             with contextlib.suppress(asyncio.CancelledError):
                 await task
 
-            logger.warning("Task force-cancelled: '%s'", task_id, extra={"task_id": task_id})
+            logger.warning("Task force-cancelled: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id})
             return True
 
         except Exception as e:
             logger.error(
                 "Error during task cancellation: '%s'",
                 task_id,
-                extra={"task_id": task_id, "error": str(e)},
+                extra={"mission_id": mission_id, "task_id": task_id, "error": str(e)},
                 exc_info=True,
             )
             return False
         return True
 
-    async def clean_session(self, task_id: str) -> bool:
+    async def clean_session(self, task_id: str, mission_id: str) -> bool:
         """Clean up task session without cancelling the task.
 
         Returns:
             bool: True if the task was cleaned successfully, False otherwise.
         """
         if task_id not in self.tasks_sessions:
-            logger.warning(
+            logger.warning(
+                "Cannot clean session - task not found: '%s'",
+                task_id,
+                extra={"mission_id": mission_id, "task_id": task_id},
+            )
             return False
 
         await self.tasks_sessions[task_id].module.stop()
-        await self.cancel_task(task_id)
+        await self.cancel_task(task_id, mission_id)
 
-        logger.info("Cleaning up session for task: '%s'", task_id, extra={"task_id": task_id})
+        logger.info("Cleaning up session for task: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id})
         self.tasks_sessions.pop(task_id, None)
+        self.tasks.pop(task_id, None)
         return True
 
-    async def pause_task(self, task_id: str) -> bool:
+    async def pause_task(self, task_id: str, mission_id: str) -> bool:
         """Pause a running task.
 
         Returns:
             bool: True if the task was paused successfully, False otherwise.
         """
-        return await self.send_signal(task_id, "pause", {})
+        return await self.send_signal(task_id, mission_id, "pause", {})
 
-    async def resume_task(self, task_id: str) -> bool:
+    async def resume_task(self, task_id: str, mission_id: str) -> bool:
         """Resume a paused task.
 
         Returns:
             bool: True if the task was paused successfully, False otherwise.
         """
-        return await self.send_signal(task_id, "resume", {})
+        return await self.send_signal(task_id, mission_id, "resume", {})
 
-    async def get_task_status(self, task_id: str) -> bool:
+    async def get_task_status(self, task_id: str, mission_id: str) -> bool:
         """Request status from a task.
 
         Returns:
             bool: True if the task was paused successfully, False otherwise.
         """
-        return await self.send_signal(task_id, "status", {})
+        return await self.send_signal(task_id, mission_id, "status", {})
 
-    async def cancel_all_tasks(self, timeout: float | None = None) -> dict[str, bool]:
+    async def cancel_all_tasks(self, mission_id: str, timeout: float | None = None) -> dict[str, bool]:
         """Cancel all running tasks.
 
         Returns:
@@ -352,25 +399,27 @@
         task_ids = list(self.running_tasks)
 
         logger.info(
-            "Cancelling all tasks: %d tasks",
+            "Cancelling all tasks: %d tasks",
+            len(task_ids),
+            extra={"mission_id": mission_id, "task_count": len(task_ids), "timeout": timeout},
         )
 
         results = {}
         for task_id in task_ids:
-            results[task_id] = await self.cancel_task(task_id, timeout)
+            results[task_id] = await self.cancel_task(task_id, mission_id, timeout)
 
         return results
 
-    async def shutdown(self, timeout: float = 30.0) -> None:
+    async def shutdown(self, mission_id: str, timeout: float = 30.0) -> None:
         """Graceful shutdown of all tasks."""
         logger.info(
             "TaskManager shutdown initiated, timeout: %.1fs",
             timeout,
-            extra={"timeout": timeout, "active_tasks": len(self.running_tasks)},
+            extra={"mission_id": mission_id, "timeout": timeout, "active_tasks": len(self.running_tasks)},
         )
 
         self._shutdown_event.set()
-        results = await self.cancel_all_tasks(timeout)
+        results = await self.cancel_all_tasks(mission_id, timeout)
 
         failed_tasks = [task_id for task_id, success in results.items() if not success]
         if failed_tasks:
@@ -378,12 +427,16 @@
                 "Failed to cancel %d tasks during shutdown: %s",
                 len(failed_tasks),
                 failed_tasks,
-                extra={"failed_tasks": failed_tasks, "failed_count": len(failed_tasks)},
+                extra={"mission_id": mission_id, "failed_tasks": failed_tasks, "failed_count": len(failed_tasks)},
             )
 
         logger.info(
             "TaskManager shutdown completed, cancelled: %d, failed: %d",
             len(results) - len(failed_tasks),
             len(failed_tasks),
-            extra={
+            extra={
+                "mission_id": mission_id,
+                "cancelled_count": len(results) - len(failed_tasks),
+                "failed_count": len(failed_tasks),
+            },
         )