digitalkin 0.2.23__py3-none-any.whl → 0.3.1.dev2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- digitalkin/__version__.py +1 -1
- digitalkin/core/__init__.py +1 -0
- digitalkin/core/common/__init__.py +9 -0
- digitalkin/core/common/factories.py +156 -0
- digitalkin/core/job_manager/__init__.py +1 -0
- digitalkin/{modules → core}/job_manager/base_job_manager.py +137 -31
- digitalkin/core/job_manager/single_job_manager.py +354 -0
- digitalkin/{modules → core}/job_manager/taskiq_broker.py +116 -22
- digitalkin/core/job_manager/taskiq_job_manager.py +541 -0
- digitalkin/core/task_manager/__init__.py +1 -0
- digitalkin/core/task_manager/base_task_manager.py +539 -0
- digitalkin/core/task_manager/local_task_manager.py +108 -0
- digitalkin/core/task_manager/remote_task_manager.py +87 -0
- digitalkin/core/task_manager/surrealdb_repository.py +266 -0
- digitalkin/core/task_manager/task_executor.py +249 -0
- digitalkin/core/task_manager/task_session.py +406 -0
- digitalkin/grpc_servers/__init__.py +1 -19
- digitalkin/grpc_servers/_base_server.py +3 -3
- digitalkin/grpc_servers/module_server.py +27 -43
- digitalkin/grpc_servers/module_servicer.py +51 -36
- digitalkin/grpc_servers/registry_server.py +2 -2
- digitalkin/grpc_servers/registry_servicer.py +4 -4
- digitalkin/grpc_servers/utils/__init__.py +1 -0
- digitalkin/grpc_servers/utils/exceptions.py +0 -8
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +4 -4
- digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
- digitalkin/logger.py +73 -24
- digitalkin/mixins/__init__.py +19 -0
- digitalkin/mixins/base_mixin.py +10 -0
- digitalkin/mixins/callback_mixin.py +24 -0
- digitalkin/mixins/chat_history_mixin.py +110 -0
- digitalkin/mixins/cost_mixin.py +76 -0
- digitalkin/mixins/file_history_mixin.py +93 -0
- digitalkin/mixins/filesystem_mixin.py +46 -0
- digitalkin/mixins/logger_mixin.py +51 -0
- digitalkin/mixins/storage_mixin.py +79 -0
- digitalkin/models/core/__init__.py +1 -0
- digitalkin/{modules/job_manager → models/core}/job_manager_models.py +3 -3
- digitalkin/models/core/task_monitor.py +70 -0
- digitalkin/models/grpc_servers/__init__.py +1 -0
- digitalkin/{grpc_servers/utils → models/grpc_servers}/models.py +5 -5
- digitalkin/models/module/__init__.py +2 -0
- digitalkin/models/module/module.py +9 -1
- digitalkin/models/module/module_context.py +122 -6
- digitalkin/models/module/module_types.py +307 -19
- digitalkin/models/services/__init__.py +9 -0
- digitalkin/models/services/cost.py +1 -0
- digitalkin/models/services/storage.py +39 -5
- digitalkin/modules/_base_module.py +123 -118
- digitalkin/modules/tool_module.py +10 -2
- digitalkin/modules/trigger_handler.py +7 -6
- digitalkin/services/cost/__init__.py +9 -2
- digitalkin/services/cost/grpc_cost.py +9 -42
- digitalkin/services/filesystem/default_filesystem.py +0 -2
- digitalkin/services/filesystem/grpc_filesystem.py +10 -39
- digitalkin/services/setup/default_setup.py +5 -6
- digitalkin/services/setup/grpc_setup.py +52 -15
- digitalkin/services/storage/grpc_storage.py +4 -4
- digitalkin/services/user_profile/__init__.py +1 -0
- digitalkin/services/user_profile/default_user_profile.py +55 -0
- digitalkin/services/user_profile/grpc_user_profile.py +69 -0
- digitalkin/services/user_profile/user_profile_strategy.py +40 -0
- digitalkin/utils/__init__.py +28 -0
- digitalkin/utils/arg_parser.py +1 -1
- digitalkin/utils/development_mode_action.py +2 -2
- digitalkin/utils/dynamic_schema.py +483 -0
- digitalkin/utils/package_discover.py +1 -2
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/METADATA +11 -30
- digitalkin-0.3.1.dev2.dist-info/RECORD +119 -0
- modules/dynamic_setup_module.py +362 -0
- digitalkin/grpc_servers/utils/factory.py +0 -180
- digitalkin/modules/job_manager/single_job_manager.py +0 -294
- digitalkin/modules/job_manager/taskiq_job_manager.py +0 -290
- digitalkin-0.2.23.dist-info/RECORD +0 -89
- /digitalkin/{grpc_servers/utils → models/grpc_servers}/types.py +0 -0
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/WHEEL +0 -0
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/top_level.txt +0 -0

--- /dev/null
+++ b/digitalkin/core/job_manager/taskiq_job_manager.py
@@ -0,0 +1,541 @@
+"""Taskiq job manager module."""
+
+try:
+    import taskiq  # noqa: F401
+
+except ImportError:
+    msg = "Install digitalkin[taskiq] to use this functionality\n$ uv pip install digitalkin[taskiq]."
+    raise ImportError(msg)
+
+import asyncio
+import contextlib
+import datetime
+import json
+import os
+from collections.abc import AsyncGenerator, AsyncIterator
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Any
+
+from rstream import Consumer, ConsumerOffsetSpecification, MessageContext, OffsetType
+
+from digitalkin.core.common import ConnectionFactory, QueueFactory
+from digitalkin.core.job_manager.base_job_manager import BaseJobManager
+from digitalkin.core.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER, cleanup_global_resources
+from digitalkin.core.task_manager.remote_task_manager import RemoteTaskManager
+from digitalkin.logger import logger
+from digitalkin.models.core.task_monitor import TaskStatus
+from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
+from digitalkin.modules._base_module import BaseModule
+from digitalkin.services.services_models import ServicesMode
+
+if TYPE_CHECKING:
+    from taskiq.task import AsyncTaskiqTask
+
+
+class TaskiqJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
+    """Taskiq job manager for running modules in Taskiq tasks."""
+
+    services_mode: ServicesMode
+
+    @staticmethod
+    def _define_consumer() -> Consumer:
+        """Get the RabbitMQ connection parameters from the environment.
+
+        Returns:
+            Consumer
+        """
+        host: str = os.environ.get("RABBITMQ_RSTREAM_HOST", "localhost")
+        port: str = os.environ.get("RABBITMQ_RSTREAM_PORT", "5552")
+        username: str = os.environ.get("RABBITMQ_RSTREAM_USERNAME", "guest")
+        password: str = os.environ.get("RABBITMQ_RSTREAM_PASSWORD", "guest")
+
+        logger.info("Connection to RabbitMQ: %s:%s.", host, port)
+        return Consumer(host=host, port=int(port), username=username, password=password)
+
+    async def _on_message(self, message: bytes, message_context: MessageContext) -> None:  # noqa: ARG002
+        """Internal callback: parse JSON and route to the correct job queue."""
+        try:
+            data = json.loads(message.decode("utf-8"))
+        except json.JSONDecodeError:
+            return
+        job_id = data.get("job_id")
+        if not job_id:
+            return
+        queue = self.job_queues.get(job_id)
+        if queue:
+            await queue.put(data.get("output_data"))
+
+    async def start(self) -> None:
+        """Start the TaskiqJobManager and initialize the SurrealDB connection."""
+        await self._start()
+        self.channel = await ConnectionFactory.create_surreal_connection(
+            database="taskiq_job_manager", timeout=datetime.timedelta(seconds=5)
+        )
+
+    async def _start(self) -> None:
+        await TASKIQ_BROKER.startup()
+
+        self.stream_consumer = self._define_consumer()
+
+        await self.stream_consumer.create_stream(
+            STREAM,
+            exists_ok=True,
+            arguments={"max-length-bytes": STREAM_RETENTION},
+        )
+        await self.stream_consumer.start()
+
+        start_spec = ConsumerOffsetSpecification(OffsetType.LAST)
+        # on_message receives bytes instead of AMQPMessage
+        await self.stream_consumer.subscribe(
+            stream=STREAM,
+            subscriber_name=f"""subscriber_{os.environ.get("SERVER_NAME", "module_servicer")}""",
+            callback=self._on_message,  # type: ignore
+            offset_specification=start_spec,
+        )
+
+        # Wrap the consumer task with error handling
+        async def run_consumer_with_error_handling() -> None:
+            try:
+                await self.stream_consumer.run()
+            except asyncio.CancelledError:
+                logger.debug("Stream consumer task cancelled")
+                raise
+            except Exception as e:
+                logger.error("Stream consumer task failed: %s", e, exc_info=True, extra={"error": str(e)})
+                # Re-raise to ensure the error is not silently ignored
+                raise
+
+        self.stream_consumer_task = asyncio.create_task(
+            run_consumer_with_error_handling(),
+            name="stream_consumer_task",
+        )
+
+    async def _stop(self) -> None:
+        """Stop the TaskiqJobManager and clean up all resources."""
+        # Close SurrealDB connection
+        if hasattr(self, "channel"):
+            try:
+                await self.channel.close()
+                logger.info("TaskiqJobManager: SurrealDB connection closed")
+            except Exception as e:
+                logger.warning("Failed to close SurrealDB connection: %s", e)
+
+        # Signal the consumer to stop
+        await self.stream_consumer.close()
+        # Cancel the background task
+        self.stream_consumer_task.cancel()
+        with contextlib.suppress(asyncio.CancelledError):
+            await self.stream_consumer_task
+
+        # Clean up job queues (log the count before clearing so it is accurate)
+        logger.info("TaskiqJobManager: Clearing %d job queues", len(self.job_queues))
+        self.job_queues.clear()
+
+        # Call global cleanup for producer and broker
+        await cleanup_global_resources()
+
+    def __init__(
+        self,
+        module_class: type[BaseModule],
+        services_mode: ServicesMode,
+        default_timeout: float = 10.0,
+        max_concurrent_tasks: int = 100,
+        stream_timeout: float = 30.0,
+    ) -> None:
+        """Initialize the Taskiq job manager.
+
+        Args:
+            module_class: The class of the module to be managed
+            services_mode: The mode of operation for the services
+            default_timeout: Default timeout for task operations
+            max_concurrent_tasks: Maximum number of concurrent tasks
+            stream_timeout: Timeout for stream consumer operations (default: 30.0s for distributed systems)
+        """
+        # Create remote task manager for distributed execution
+        task_manager = RemoteTaskManager(default_timeout, max_concurrent_tasks)
+
+        # Initialize base job manager with task manager
+        super().__init__(module_class, services_mode, task_manager)
+
+        logger.warning("TaskiqJobManager initialized with app: %s", TASKIQ_BROKER)
+        self.job_queues: dict[str, asyncio.Queue] = {}
+        self.max_queue_size = 1000
+        self.stream_timeout = stream_timeout
+
+    async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT:
+        """Wait for the processed config setup response of a module job.
+
+        This method registers a response queue for the job and waits, with a
+        timeout, for the worker to publish the fully processed setup model,
+        raising instead of blocking indefinitely.
+
+        Args:
+            job_id: The unique identifier of the job.
+
+        Returns:
+            SetupModelT: the SetupModelT object fully processed.
+
+        Raises:
+            asyncio.TimeoutError: If waiting for the setup response times out.
+        """
+        queue = QueueFactory.create_bounded_queue(maxsize=self.max_queue_size)
+        self.job_queues[job_id] = queue
+
+        try:
+            # Add timeout to prevent indefinite blocking
+            item = await asyncio.wait_for(queue.get(), timeout=30.0)
+        except asyncio.TimeoutError:
+            logger.error("Timeout waiting for config setup response for job %s", job_id)
+            raise
+        else:
+            queue.task_done()
+            return item
+        finally:
+            logger.info(f"generate_config_setup_module_response: {job_id=}: {self.job_queues[job_id].empty()}")
+            self.job_queues.pop(job_id, None)
+
+    async def create_config_setup_instance_job(
+        self,
+        config_setup_data: SetupModelT,
+        mission_id: str,
+        setup_id: str,
+        setup_version_id: str,
+    ) -> str:
+        """Create and start a new module setup configuration job.
+
+        This method initializes a new module job, assigns it a unique job ID,
+        and starts the config setup in the background.
+
+        Args:
+            config_setup_data: The input data required to start the job.
+            mission_id: The mission ID associated with the job.
+            setup_id: The setup ID associated with the module.
+            setup_version_id: The setup version ID.
+
+        Returns:
+            str: The unique identifier (job ID) of the created job.
+
+        Raises:
+            TypeError: If the function is called with a bad data type.
+            ValueError: If the module fails to start.
+        """
+        task = TASKIQ_BROKER.find_task("digitalkin.core.job_manager.taskiq_broker:run_config_module")
+
+        if task is None:
+            msg = "Task not found"
+            raise ValueError(msg)
+
+        if config_setup_data is None:
+            msg = "config_setup_data must be a valid model with a model_dump method"
+            raise TypeError(msg)
+
+        # Submit task to Taskiq
+        running_task: AsyncTaskiqTask[Any] = await task.kiq(
+            mission_id,
+            setup_id,
+            setup_version_id,
+            self.module_class,
+            self.services_mode,
+            config_setup_data.model_dump(),  # type: ignore
+        )
+
+        job_id = running_task.task_id
+
+        # Create module instance for metadata
+        module = self.module_class(
+            job_id,
+            mission_id=mission_id,
+            setup_id=setup_id,
+            setup_version_id=setup_version_id,
+        )
+
+        # Register task in TaskManager (remote mode)
+        async def _dummy_coro() -> None:
+            """Dummy coroutine - actual execution happens in worker."""
+
+        await self.create_task(
+            job_id,
+            mission_id,
+            module,
+            _dummy_coro(),
+        )
+
+        logger.info("Registered config task: %s, waiting for initial result", job_id)
+        result = await running_task.wait_result(timeout=10)
+        logger.info("Job %s with data %s", job_id, result)
+        return job_id
+
+    @asynccontextmanager  # type: ignore
+    async def generate_stream_consumer(self, job_id: str) -> AsyncIterator[AsyncGenerator[dict[str, Any], None]]:  # type: ignore
+        """Generate a stream consumer for the RStream stream.
+
+        Args:
+            job_id: The job ID to filter messages.
+
+        Yields:
+            messages: The stream messages from the associated module.
+        """
+        queue = QueueFactory.create_bounded_queue(maxsize=self.max_queue_size)
+        self.job_queues[job_id] = queue
+
+        async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
+            """Generate the stream with batch-drain optimization.
+
+            This implementation uses a micro-batching pattern optimized for distributed
+            message streams from RabbitMQ:
+            1. Block waiting for the first item (with timeout for termination checks)
+            2. Drain all immediately available items without blocking (micro-batch)
+            3. Yield control back to event loop
+
+            This pattern provides:
+            - Better throughput for bursty message streams
+            - Reduced gRPC streaming overhead
+            - Lower latency when multiple messages arrive simultaneously
+
+            Yields:
+                dict: generated object from the module
+            """
+            while True:
+                try:
+                    # Block for first item with timeout to allow termination checks
+                    item = await asyncio.wait_for(queue.get(), timeout=self.stream_timeout)
+                    queue.task_done()
+                    yield item
+
+                    # Drain all immediately available items (micro-batch optimization)
+                    # This reduces latency when messages arrive in bursts from RabbitMQ
+                    batch_count = 0
+                    max_batch_size = 100  # Safety limit to prevent memory spikes
+                    while batch_count < max_batch_size:
+                        try:
+                            item = queue.get_nowait()
+                            queue.task_done()
+                            yield item
+                            batch_count += 1
+                        except asyncio.QueueEmpty:  # noqa: PERF203
+                            # No more items immediately available, break to next blocking wait
+                            break
+
+                except asyncio.TimeoutError:
+                    logger.warning("Stream consumer timeout for job %s, checking if job is still active", job_id)
+
+                    # Check if job is registered
+                    if job_id not in self.tasks_sessions:
+                        logger.info("Job %s no longer registered, ending stream", job_id)
+                        break
+
+                    # Check job status to detect cancelled/failed jobs
+                    status = await self.get_module_status(job_id)
+
+                    if status in {TaskStatus.CANCELLED, TaskStatus.FAILED}:
+                        logger.info("Job %s has terminal status %s, draining queue and ending stream", job_id, status)
+
+                        # Drain remaining queue items before stopping
+                        while not queue.empty():
+                            try:
+                                item = queue.get_nowait()
+                                queue.task_done()
+                                yield item
+                            except asyncio.QueueEmpty:  # noqa: PERF203
+                                break
+
+                        break
+
+                    # Continue waiting for active/completed jobs
+                    continue
+
+        try:
+            yield _stream()
+        finally:
+            self.job_queues.pop(job_id, None)
+
+    async def create_module_instance_job(
+        self,
+        input_data: InputModelT,
+        setup_data: SetupModelT,
+        mission_id: str,
+        setup_id: str,
+        setup_version_id: str,
+    ) -> str:
+        """Launch the module task in Taskiq and return the Taskiq task id as job_id.
+
+        Args:
+            input_data: Input data for the module
+            setup_data: Setup data for the module
+            mission_id: Mission ID for the module
+            setup_id: The setup ID associated with the module.
+            setup_version_id: The setup version ID associated with the module.
+
+        Returns:
+            job_id: The Taskiq task id.
+
+        Raises:
+            ValueError: If the task is not found.
+        """
+        task = TASKIQ_BROKER.find_task("digitalkin.core.job_manager.taskiq_broker:run_start_module")
+
+        if task is None:
+            msg = "Task not found"
+            raise ValueError(msg)
+
+        # Submit task to Taskiq
+        running_task: AsyncTaskiqTask[Any] = await task.kiq(
+            mission_id,
+            setup_id,
+            setup_version_id,
+            self.module_class,
+            self.services_mode,
+            input_data.model_dump(),
+            setup_data.model_dump(),
+        )
+        job_id = running_task.task_id
+
+        # Create module instance for metadata
+        module = self.module_class(
+            job_id,
+            mission_id=mission_id,
+            setup_id=setup_id,
+            setup_version_id=setup_version_id,
+        )
+
+        # Register task in TaskManager (remote mode)
+        # Dummy coroutine will be closed by TaskManager since execution_mode="remote"
+        async def _dummy_coro() -> None:
+            """Dummy coroutine - actual execution happens in worker."""
+
+        await self.create_task(
+            job_id,
+            mission_id,
+            module,
+            _dummy_coro(),  # Will be closed immediately by TaskManager in remote mode
+        )
+
+        logger.info("Registered remote task: %s, waiting for initial result", job_id)
+        result = await running_task.wait_result(timeout=10)
+        logger.debug("Job %s with data %s", job_id, result)
+        return job_id
+
+    async def get_module_status(self, job_id: str) -> TaskStatus:
+        """Query a module status from SurrealDB.
+
+        Args:
+            job_id: The unique identifier of the job.
+
+        Returns:
+            TaskStatus: The status of the module task.
+        """
+        if job_id not in self.tasks_sessions:
+            logger.warning("Job %s not found in registry", job_id)
+            return TaskStatus.FAILED
+
+        # Safety check: if channel not initialized (start() wasn't called), return FAILED
+        if not hasattr(self, "channel") or self.channel is None:
+            logger.warning("Job %s status check failed - channel not initialized", job_id)
+            return TaskStatus.FAILED
+
+        try:
+            # Query the tasks table for the task status
+            task_record = await self.channel.select_by_task_id("tasks", job_id)
+            if task_record and "status" in task_record:
+                status_str = task_record["status"]
+                return TaskStatus(status_str) if isinstance(status_str, str) else status_str
+            # If no record found in tasks, check heartbeats to see if task exists
+            heartbeat_record = await self.channel.select_by_task_id("heartbeats", job_id)
+            if heartbeat_record:
+                return TaskStatus.RUNNING
+            # No task or heartbeat record found - task may still be initializing
+            logger.debug("No task or heartbeat record found for job %s - task may still be initializing", job_id)
+        except Exception:
+            logger.exception("Error getting status for job %s", job_id)
+            return TaskStatus.FAILED
+        else:
+            return TaskStatus.FAILED
+
+    async def wait_for_completion(self, job_id: str) -> None:
+        """Wait for a task to complete by polling its status from SurrealDB.
+
+        This method polls the task status until it reaches a terminal state.
+        Uses a 0.5 second polling interval to balance responsiveness and resource usage.
+
+        Args:
+            job_id: The unique identifier of the job to wait for.
+
+        Raises:
+            KeyError: If the job_id is not found in tasks_sessions.
+        """
+        if job_id not in self.tasks_sessions:
+            msg = f"Job {job_id} not found"
+            raise KeyError(msg)
+
+        # Poll task status until terminal state
+        terminal_states = {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED}
+        while True:
+            status = await self.get_module_status(job_id)
+            if status in terminal_states:
+                logger.debug("Job %s reached terminal state: %s", job_id, status)
+                break
+            await asyncio.sleep(0.5)  # Poll interval
+
+    async def stop_module(self, job_id: str) -> bool:
+        """Stop a running module using TaskManager.
+
+        Args:
+            job_id: The Taskiq task id to stop.
+
+        Returns:
+            bool: True if the signal was successfully sent, False otherwise.
+        """
+        if job_id not in self.tasks_sessions:
+            logger.warning("Job %s not found in registry", job_id)
+            return False
+
+        try:
+            session = self.tasks_sessions[job_id]
+            # Use TaskManager's cancel_task method which handles signal sending
+            await self.cancel_task(job_id, session.mission_id)
+            logger.info("Cancel signal sent for job %s via TaskManager", job_id)
+
+            # Clean up queue after cancellation
+            self.job_queues.pop(job_id, None)
+            logger.debug("Cleaned up queue for job %s", job_id)
+        except Exception:
+            logger.exception("Error stopping job %s", job_id)
+            return False
+        return True
+
+    async def stop_all_modules(self) -> None:
+        """Stop all running modules tracked in the registry."""
+        stop_tasks = [self.stop_module(job_id) for job_id in list(self.tasks_sessions.keys())]
+        if stop_tasks:
+            results = await asyncio.gather(*stop_tasks, return_exceptions=True)
+            logger.info("Stopped %d modules, results: %s", len(results), results)
+
+    async def list_modules(self) -> dict[str, dict[str, Any]]:
+        """List all modules tracked in the registry with their statuses.
+
+        Returns:
+            dict[str, dict[str, Any]]: A dictionary containing information about all tracked modules.
+        """
+        modules_info: dict[str, dict[str, Any]] = {}
+
+        for job_id in self.tasks_sessions:
+            try:
+                status = await self.get_module_status(job_id)
+                task_record = await self.channel.select_by_task_id("tasks", job_id)
+
+                modules_info[job_id] = {
+                    "name": self.module_class.__name__,
+                    "status": status,
+                    "class": self.module_class.__name__,
+                    "mission_id": task_record.get("mission_id") if task_record else "unknown",
+                }
+            except Exception:  # noqa: PERF203
+                logger.exception("Error getting info for job %s", job_id)
+                modules_info[job_id] = {
+                    "name": self.module_class.__name__,
+                    "status": TaskStatus.FAILED,
+                    "class": self.module_class.__name__,
+                    "error": "Failed to retrieve status",
+                }
+
+        return modules_info
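
The `_stream` helper in the hunk above documents its batch-drain loop in prose: block for the first item, drain whatever else is immediately available, then yield back to the event loop. The following self-contained sketch reproduces that pattern with a plain `asyncio.Queue`, outside digitalkin, so the mechanics can be run and inspected in isolation (the 100-item cap mirrors `max_batch_size` above):

```python
import asyncio
from collections.abc import AsyncGenerator
from typing import Any


async def batch_drain(queue: asyncio.Queue, timeout: float = 1.0) -> AsyncGenerator[Any, None]:
    """Block for one item, then drain the burst: the pattern from _stream()."""
    while True:
        try:
            # 1. Block for the first item; the timeout allows termination checks.
            item = await asyncio.wait_for(queue.get(), timeout=timeout)
        except asyncio.TimeoutError:
            return  # The real manager checks job status here instead of exiting.
        queue.task_done()
        yield item
        # 2. Drain everything immediately available (micro-batch), with a cap.
        for _ in range(100):
            try:
                item = queue.get_nowait()
            except asyncio.QueueEmpty:
                break  # 3. Queue momentarily empty: fall back to the blocking wait.
            queue.task_done()
            yield item


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(5):  # Simulate a burst arriving between consumer wake-ups.
        queue.put_nowait({"chunk": i})
    async for message in batch_drain(queue):
        print(message)


asyncio.run(main())
```

In the manager itself, the timeout branch additionally consults SurrealDB for terminal job status (cancelled/failed) before deciding whether to end the stream.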
--- /dev/null
+++ b/digitalkin/core/task_manager/__init__.py
@@ -0,0 +1 @@
+"""Base task manager logic."""
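
For context on the routing in `_on_message` above: each RStream payload is expected to be UTF-8 JSON carrying a `job_id` (used to select the per-job queue) and an `output_data` field (what actually gets enqueued). A minimal sketch of that contract; the producing side lives in `digitalkin.core.job_manager.taskiq_broker`, which this diff does not show in full:

```python
import json

# Shape of a payload _on_message will route; malformed JSON or a missing
# job_id is silently dropped by the handler.
payload: bytes = json.dumps({
    "job_id": "<taskiq-task-id>",  # Selects self.job_queues[job_id].
    "output_data": {"chunk": "partial module output"},  # Enqueued for streaming.
}).encode("utf-8")

# Mirror of the handler's decode step:
data = json.loads(payload.decode("utf-8"))
assert data["job_id"] and "output_data" in data
```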