digitalkin 0.3.0rc2__py3-none-any.whl → 0.3.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/core/common/__init__.py +9 -0
  3. digitalkin/core/common/factories.py +156 -0
  4. digitalkin/core/job_manager/base_job_manager.py +128 -28
  5. digitalkin/core/job_manager/single_job_manager.py +80 -25
  6. digitalkin/core/job_manager/taskiq_broker.py +114 -19
  7. digitalkin/core/job_manager/taskiq_job_manager.py +292 -39
  8. digitalkin/core/task_manager/base_task_manager.py +464 -0
  9. digitalkin/core/task_manager/local_task_manager.py +108 -0
  10. digitalkin/core/task_manager/remote_task_manager.py +87 -0
  11. digitalkin/core/task_manager/surrealdb_repository.py +43 -4
  12. digitalkin/core/task_manager/task_executor.py +173 -0
  13. digitalkin/core/task_manager/task_session.py +34 -12
  14. digitalkin/grpc_servers/module_server.py +2 -2
  15. digitalkin/grpc_servers/module_servicer.py +4 -3
  16. digitalkin/grpc_servers/registry_server.py +1 -1
  17. digitalkin/grpc_servers/registry_servicer.py +4 -4
  18. digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
  19. digitalkin/models/grpc_servers/models.py +4 -4
  20. digitalkin/services/cost/grpc_cost.py +8 -41
  21. digitalkin/services/filesystem/grpc_filesystem.py +9 -38
  22. digitalkin/services/setup/default_setup.py +5 -6
  23. digitalkin/services/setup/grpc_setup.py +51 -14
  24. digitalkin/services/storage/grpc_storage.py +2 -2
  25. digitalkin/services/user_profile/__init__.py +1 -0
  26. digitalkin/services/user_profile/default_user_profile.py +55 -0
  27. digitalkin/services/user_profile/grpc_user_profile.py +69 -0
  28. digitalkin/services/user_profile/user_profile_strategy.py +40 -0
  29. {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/METADATA +7 -7
  30. {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/RECORD +33 -23
  31. digitalkin/core/task_manager/task_manager.py +0 -442
  32. {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/WHEEL +0 -0
  33. {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/licenses/LICENSE +0 -0
  34. {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/top_level.txt +0 -0
digitalkin/core/job_manager/taskiq_broker.py

@@ -1,10 +1,12 @@
  """Taskiq broker & RSTREAM producer for the job manager."""

  import asyncio
+ import datetime
  import json
  import logging
  import os
  import pickle  # noqa: S403
+ from typing import Any

  from rstream import Producer
  from rstream.exceptions import PreconditionFailed
@@ -14,7 +16,10 @@ from taskiq.compat import model_validate
  from taskiq.message import BrokerMessage
  from taskiq_aio_pika import AioPikaBroker

+ from digitalkin.core.common import ConnectionFactory, ModuleFactory
  from digitalkin.core.job_manager.base_job_manager import BaseJobManager
+ from digitalkin.core.task_manager.task_executor import TaskExecutor
+ from digitalkin.core.task_manager.task_session import TaskSession
  from digitalkin.logger import logger
  from digitalkin.models.core.job_manager_models import StreamCodeModel
  from digitalkin.models.module.module_types import OutputModelT
@@ -118,6 +123,24 @@ RSTREAM_PRODUCER = define_producer()
  TASKIQ_BROKER = define_broker()


+ async def cleanup_global_resources() -> None:
+     """Clean up global resources (producer and broker connections).
+
+     This should be called during shutdown to prevent connection leaks.
+     """
+     try:
+         await RSTREAM_PRODUCER.close()
+         logger.info("RStream producer closed successfully")
+     except Exception as e:
+         logger.warning("Failed to close RStream producer: %s", e)
+
+     try:
+         await TASKIQ_BROKER.shutdown()
+         logger.info("Taskiq broker shut down successfully")
+     except Exception as e:
+         logger.warning("Failed to shutdown Taskiq broker: %s", e)
+
+
  async def send_message_to_stream(job_id: str, output_data: OutputModelT) -> None:  # type: ignore
      """Callback define to add a message frame to the Rstream.

@@ -152,27 +175,70 @@ async def run_start_module(
          setup_data: dict,
          context: Allow TaskIQ context access
      """
-     logger.warning("%s", services_mode)
+     logger.info("Starting module with services_mode: %s", services_mode)
      services_config = ServicesConfig(
          services_config_strategies=module_class.services_config_strategies,
          services_config_params=module_class.services_config_params,
          mode=services_mode,
      )
      setattr(module_class, "services_config", services_config)
-     logger.warning("%s | %s", services_config, module_class.services_config)
+     logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)
+     module_class.discover()

      job_id = context.message.task_id
      callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
-     module = module_class(job_id, mission_id=mission_id, setup_id=setup_id, setup_version_id=setup_version_id)
-
-     await module.start(
-         input_data,
-         setup_data,
-         callback,
-         # ensure that the callback is called when the task is done + allow asyncio to run
-         # TODO: should define a BaseModel for stream code / error
-         done_callback=lambda _: asyncio.create_task(callback(StreamCodeModel(code="__END_OF_STREAM__"))),
-     )
+     module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)
+
+     channel = None
+     try:
+         # Create TaskExecutor and supporting components for worker execution
+         executor = TaskExecutor()
+         # SurrealDB env vars are expected to be set in env.
+         channel = await ConnectionFactory.create_surreal_connection("taskiq_worker", datetime.timedelta(seconds=5))
+         session = TaskSession(job_id, mission_id, channel, module, datetime.timedelta(seconds=2))
+
+         # Execute the task using TaskExecutor
+         # Create a proper done callback that handles errors
+         async def send_end_of_stream(_: Any) -> None:  # noqa: ANN401
+             try:
+                 await callback(StreamCodeModel(code="__END_OF_STREAM__"))
+             except Exception as e:
+                 logger.error("Error sending end of stream: %s", e, exc_info=True)
+
+         # Reconstruct Pydantic models from dicts for type safety
+         try:
+             input_model = module_class.create_input_model(input_data)
+             setup_model = module_class.create_setup_model(setup_data)
+         except Exception as e:
+             logger.error("Failed to reconstruct models for job %s: %s", job_id, e, exc_info=True)
+             raise
+
+         supervisor_task = await executor.execute_task(
+             task_id=job_id,
+             mission_id=mission_id,
+             coro=module.start(
+                 input_model,
+                 setup_model,
+                 callback,
+                 done_callback=lambda result: asyncio.ensure_future(send_end_of_stream(result)),
+             ),
+             session=session,
+             channel=channel,
+         )
+
+         # Wait for the supervisor task to complete
+         await supervisor_task
+         logger.info("Module task %s completed", job_id)
+     except Exception:
+         logger.exception("Error running module %s", job_id)
+         raise
+     finally:
+         # Cleanup channel
+         if channel is not None:
+             try:
+                 await channel.close()
+             except Exception:
+                 logger.exception("Error closing channel for job %s", job_id)


  @TASKIQ_BROKER.task
@@ -196,20 +262,49 @@ async def run_config_module(
          config_setup_data: dict,
          context: Allow TaskIQ context access
      """
-     logger.warning("%s", services_mode)
+     logger.info("Starting config module with services_mode: %s", services_mode)
      services_config = ServicesConfig(
          services_config_strategies=module_class.services_config_strategies,
          services_config_params=module_class.services_config_params,
          mode=services_mode,
      )
      setattr(module_class, "services_config", services_config)
-     logger.warning("%s | %s", services_config, module_class.services_config)
+     logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)

      job_id = context.message.task_id
      callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
-     module = module_class(job_id, mission_id=mission_id, setup_id=setup_id, setup_version_id=setup_version_id)
+     module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)

-     await module.start_config_setup(
-         module_class.create_config_setup_model(config_setup_data),
-         callback,
-     )
+     # Override environment variables temporarily to use manager's SurrealDB
+     channel = None
+     try:
+         # Create TaskExecutor and supporting components for worker execution
+         executor = TaskExecutor()
+         # SurrealDB env vars are expected to be set in env.
+         channel = await ConnectionFactory.create_surreal_connection("taskiq_worker", datetime.timedelta(seconds=5))
+         session = TaskSession(job_id, mission_id, channel, module, datetime.timedelta(seconds=2))
+
+         # Create and run the config setup task with TaskExecutor
+         setup_model = module_class.create_config_setup_model(config_setup_data)
+
+         supervisor_task = await executor.execute_task(
+             task_id=job_id,
+             mission_id=mission_id,
+             coro=module.start_config_setup(setup_model, callback),
+             session=session,
+             channel=channel,
+         )
+
+         # Wait for the supervisor task to complete
+         await supervisor_task
+         logger.info("Config module task %s completed", job_id)
+     except Exception:
+         logger.exception("Error running config module %s", job_id)
+         raise
+     finally:
+         # Cleanup channel
+         if channel is not None:
+             try:
+                 await channel.close()
+             except Exception:
+                 logger.exception("Error closing channel for job %s", job_id)
digitalkin/core/job_manager/taskiq_job_manager.py

@@ -9,19 +9,22 @@ except ImportError:

  import asyncio
  import contextlib
+ import datetime
  import json
  import os
  from collections.abc import AsyncGenerator, AsyncIterator
  from contextlib import asynccontextmanager
- from typing import TYPE_CHECKING, Any, Generic
+ from typing import TYPE_CHECKING, Any

  from rstream import Consumer, ConsumerOffsetSpecification, MessageContext, OffsetType

+ from digitalkin.core.common import ConnectionFactory, QueueFactory
  from digitalkin.core.job_manager.base_job_manager import BaseJobManager
- from digitalkin.core.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER
+ from digitalkin.core.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER, cleanup_global_resources
+ from digitalkin.core.task_manager.remote_task_manager import RemoteTaskManager
  from digitalkin.logger import logger
  from digitalkin.models.core.task_monitor import TaskStatus
- from digitalkin.models.module import InputModelT, SetupModelT
+ from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
  from digitalkin.modules._base_module import BaseModule
  from digitalkin.services.services_models import ServicesMode

@@ -29,7 +32,7 @@ if TYPE_CHECKING:
      from taskiq.task import AsyncTaskiqTask


- class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
+ class TaskiqJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
      """Taskiq job manager for running modules in Taskiq tasks."""

      services_mode: ServicesMode
@@ -62,6 +65,13 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
          if queue:
              await queue.put(data.get("output_data"))

+     async def start(self) -> None:
+         """Start the TaskiqJobManager and initialize SurrealDB connection."""
+         await self._start()
+         self.channel = await ConnectionFactory.create_surreal_connection(
+             database="taskiq_job_manager", timeout=datetime.timedelta(seconds=5)
+         )
+
      async def _start(self) -> None:
          await TASKIQ_BROKER.startup()

@@ -82,12 +92,34 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
              callback=self._on_message,  # type: ignore
              offset_specification=start_spec,
          )
+
+         # Wrap the consumer task with error handling
+         async def run_consumer_with_error_handling() -> None:
+             try:
+                 await self.stream_consumer.run()
+             except asyncio.CancelledError:
+                 logger.debug("Stream consumer task cancelled")
+                 raise
+             except Exception as e:
+                 logger.error("Stream consumer task failed: %s", e, exc_info=True, extra={"error": str(e)})
+                 # Re-raise to ensure the error is not silently ignored
+                 raise
+
          self.stream_consumer_task = asyncio.create_task(
-             self.stream_consumer.run(),
+             run_consumer_with_error_handling(),
              name="stream_consumer_task",
          )

      async def _stop(self) -> None:
+         """Stop the TaskiqJobManager and clean up all resources."""
+         # Close SurrealDB connection
+         if hasattr(self, "channel"):
+             try:
+                 await self.channel.close()
+                 logger.info("TaskiqJobManager: SurrealDB connection closed")
+             except Exception as e:
+                 logger.warning("Failed to close SurrealDB connection: %s", e)
+
          # Signal the consumer to stop
          await self.stream_consumer.close()
          # Cancel the background task
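
Note: the consumer task is now wrapped so that failures in the background task are logged and re-raised instead of disappearing inside an unawaited asyncio.Task. A generic sketch of the same wrap-and-create_task pattern, independent of this class (names are illustrative):

    import asyncio
    import logging
    from collections.abc import Callable, Coroutine
    from typing import Any

    logger = logging.getLogger(__name__)


    def spawn_supervised(coro_factory: Callable[[], Coroutine[Any, Any, None]], name: str) -> asyncio.Task:
        """Create a named background task whose failures are logged rather than silently dropped."""

        async def runner() -> None:
            try:
                await coro_factory()
            except asyncio.CancelledError:
                logger.debug("%s cancelled", name)
                raise
            except Exception:
                logger.exception("%s failed", name)
                raise

        return asyncio.create_task(runner(), name=name)

Used as spawn_supervised(consumer.run, "stream_consumer_task"), this mirrors the run_consumer_with_error_handling wrapper above.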
@@ -95,18 +127,40 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
          with contextlib.suppress(asyncio.CancelledError):
              await self.stream_consumer_task

+         # Clean up job queues
+         self.job_queues.clear()
+         logger.info("TaskiqJobManager: Cleared %d job queues", len(self.job_queues))
+
+         # Call global cleanup for producer and broker
+         await cleanup_global_resources()
+
      def __init__(
          self,
          module_class: type[BaseModule],
          services_mode: ServicesMode,
+         default_timeout: float = 10.0,
+         max_concurrent_tasks: int = 100,
+         stream_timeout: float = 15.0,
      ) -> None:
-         """Initialize the Taskiq job manager."""
-         super().__init__(module_class, services_mode)
+         """Initialize the Taskiq job manager.
+
+         Args:
+             module_class: The class of the module to be managed
+             services_mode: The mode of operation for the services
+             default_timeout: Default timeout for task operations
+             max_concurrent_tasks: Maximum number of concurrent tasks
+             stream_timeout: Timeout for stream consumer operations (default: 15.0s for distributed systems)
+         """
+         # Create remote task manager for distributed execution
+         task_manager = RemoteTaskManager(default_timeout, max_concurrent_tasks)
+
+         # Initialize base job manager with task manager
+         super().__init__(module_class, services_mode, task_manager)

          logger.warning("TaskiqJobManager initialized with app: %s", TASKIQ_BROKER)
-         self.services_mode = services_mode
          self.job_queues: dict[str, asyncio.Queue] = {}
          self.max_queue_size = 1000
+         self.stream_timeout = stream_timeout

      async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT:
          """Generate a stream consumer for a module's output data.
@@ -120,12 +174,20 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):

          Returns:
              SetupModelT: the SetupModelT object fully processed.
+
+         Raises:
+             asyncio.TimeoutError: If waiting for the setup response times out.
          """
-         queue: asyncio.Queue = asyncio.Queue(maxsize=self.max_queue_size)
+         queue = QueueFactory.create_bounded_queue(maxsize=self.max_queue_size)
          self.job_queues[job_id] = queue

          try:
-             item = await queue.get()
+             # Add timeout to prevent indefinite blocking
+             item = await asyncio.wait_for(queue.get(), timeout=30.0)
+         except asyncio.TimeoutError:
+             logger.error("Timeout waiting for config setup response for job %s", job_id)
+             raise
+         else:
              queue.task_done()
              return item
          finally:
@@ -157,7 +219,7 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
              TypeError: If the function is called with bad data type.
              ValueError: If the module fails to start.
          """
-         task = TASKIQ_BROKER.find_task("digitalkin.core.taskiq_broker:run_config_module")
+         task = TASKIQ_BROKER.find_task("digitalkin.core.job_manager.taskiq_broker:run_config_module")

          if task is None:
              msg = "Task not found"
@@ -167,6 +229,7 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
              msg = "config_setup_data must be a valid model with model_dump method"
              raise TypeError(msg)

+         # Submit task to Taskiq
          running_task: AsyncTaskiqTask[Any] = await task.kiq(
              mission_id,
              setup_id,
@@ -177,6 +240,27 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
          )

          job_id = running_task.task_id
+
+         # Create module instance for metadata
+         module = self.module_class(
+             job_id,
+             mission_id=mission_id,
+             setup_id=setup_id,
+             setup_version_id=setup_version_id,
+         )
+
+         # Register task in TaskManager (remote mode)
+         async def _dummy_coro() -> None:
+             """Dummy coroutine - actual execution happens in worker."""
+
+         await self.create_task(
+             job_id,
+             mission_id,
+             module,
+             _dummy_coro(),
+         )
+
+         logger.info("Registered config task: %s, waiting for initial result", job_id)
          result = await running_task.wait_result(timeout=10)
          logger.info("Job %s with data %s", job_id, result)
          return job_id
@@ -191,28 +275,76 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
          Yields:
              messages: The stream messages from the associated module.
          """
-         queue: asyncio.Queue = asyncio.Queue(maxsize=self.max_queue_size)
+         queue = QueueFactory.create_bounded_queue(maxsize=self.max_queue_size)
          self.job_queues[job_id] = queue

          async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
-             """Generate the stream allowing flowless communication.
+             """Generate the stream with batch-drain optimization.
+
+             This implementation uses a micro-batching pattern optimized for distributed
+             message streams from RabbitMQ:
+             1. Block waiting for the first item (with timeout for termination checks)
+             2. Drain all immediately available items without blocking (micro-batch)
+             3. Yield control back to event loop
+
+             This pattern provides:
+             - Better throughput for bursty message streams
+             - Reduced gRPC streaming overhead
+             - Lower latency when multiple messages arrive simultaneously

              Yields:
                  dict: generated object from the module
              """
              while True:
-                 item = await queue.get()
-                 queue.task_done()
-                 yield item
-
-                 while True:
-                     try:
-                         item = queue.get_nowait()
-                     except asyncio.QueueEmpty:
-                         break
+                 try:
+                     # Block for first item with timeout to allow termination checks
+                     # Configurable timeout (default 15s) to account for distributed system latencies
+                     item = await asyncio.wait_for(queue.get(), timeout=self.stream_timeout)
                      queue.task_done()
                      yield item

+                     # Drain all immediately available items (micro-batch optimization)
+                     # This reduces latency when messages arrive in bursts from RabbitMQ
+                     batch_count = 0
+                     max_batch_size = 100  # Safety limit to prevent memory spikes
+                     while batch_count < max_batch_size:
+                         try:
+                             item = queue.get_nowait()
+                             queue.task_done()
+                             yield item
+                             batch_count += 1
+                         except asyncio.QueueEmpty:  # noqa: PERF203
+                             # No more items immediately available, break to next blocking wait
+                             break
+
+                 except asyncio.TimeoutError:
+                     logger.warning("Stream consumer timeout for job %s, checking if job is still active", job_id)
+
+                     # Check if job is registered
+                     if job_id not in self.tasks_sessions:
+                         logger.info("Job %s no longer registered, ending stream", job_id)
+                         break
+
+                     # Check job status to detect cancelled/failed jobs
+                     status = await self.get_module_status(job_id)
+
+                     if status in {TaskStatus.CANCELLED, TaskStatus.FAILED}:
+                         logger.info("Job %s has terminal status %s, draining queue and ending stream", job_id, status)
+
+                         # Drain remaining queue items before stopping
+                         while not queue.empty():
+                             try:
+                                 item = queue.get_nowait()
+                                 queue.task_done()
+                                 yield item
+                             except asyncio.QueueEmpty:  # noqa: PERF203
+                                 break
+
+                         break
+
+                     # Continue waiting for active/completed jobs
+                     continue
+
          try:
              yield _stream()
          finally:
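
Note: the rewritten _stream() docstring above describes a batch-drain (micro-batch) loop: block for the first item with a timeout, then drain whatever is already queued before giving control back. A stripped-down sketch of the same pattern on a plain asyncio.Queue, outside the class (timeout and batch-limit values are illustrative):

    import asyncio
    from collections.abc import AsyncGenerator
    from typing import Any


    async def batch_drain(queue: asyncio.Queue, timeout: float = 15.0, max_batch: int = 100) -> AsyncGenerator[Any, None]:
        """Yield items in micro-batches: one blocking get, then drain what is already queued."""
        while True:
            # 1. Block for the first item; a TimeoutError propagates to the caller here,
            #    whereas the real _stream() catches it to check the job's status.
            item = await asyncio.wait_for(queue.get(), timeout=timeout)
            queue.task_done()
            yield item
            # 2. Drain items that are already available, up to a safety limit.
            drained = 0
            while drained < max_batch:
                try:
                    item = queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                queue.task_done()
                yield item
                drained += 1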
@@ -241,12 +373,13 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
          Raises:
              ValueError: If the task is not found.
          """
-         task = TASKIQ_BROKER.find_task("digitalkin.core.taskiq_broker:run_start_module")
+         task = TASKIQ_BROKER.find_task("digitalkin.core.job_manager.taskiq_broker:run_start_module")

          if task is None:
              msg = "Task not found"
              raise ValueError(msg)

+         # Submit task to Taskiq
          running_task: AsyncTaskiqTask[Any] = await task.kiq(
              mission_id,
              setup_id,
@@ -257,33 +390,153 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
              setup_data.model_dump(),
          )
          job_id = running_task.task_id
+
+         # Create module instance for metadata
+         module = self.module_class(
+             job_id,
+             mission_id=mission_id,
+             setup_id=setup_id,
+             setup_version_id=setup_version_id,
+         )
+
+         # Register task in TaskManager (remote mode)
+         # Dummy coroutine will be closed by TaskManager since execution_mode="remote"
+         async def _dummy_coro() -> None:
+             """Dummy coroutine - actual execution happens in worker."""
+
+         await self.create_task(
+             job_id,
+             mission_id,
+             module,
+             _dummy_coro(),  # Will be closed immediately by TaskManager in remote mode
+         )
+
+         logger.info("Registered remote task: %s, waiting for initial result", job_id)
          result = await running_task.wait_result(timeout=10)
          logger.debug("Job %s with data %s", job_id, result)
          return job_id

+     async def get_module_status(self, job_id: str) -> TaskStatus:
+         """Query a module status from SurrealDB.
+
+         Args:
+             job_id: The unique identifier of the job.
+
+         Returns:
+             TaskStatus: The status of the module task.
+         """
+         if job_id not in self.tasks_sessions:
+             logger.warning("Job %s not found in registry", job_id)
+             return TaskStatus.FAILED
+
+         # Safety check: if channel not initialized (start() wasn't called), return FAILED
+         if not hasattr(self, "channel") or self.channel is None:
+             logger.warning("Job %s status check failed - channel not initialized", job_id)
+             return TaskStatus.FAILED
+
+         try:
+             # Query the tasks table for the task status
+             task_record = await self.channel.select_by_task_id("tasks", job_id)
+             if task_record and "status" in task_record:
+                 status_str = task_record["status"]
+                 return TaskStatus(status_str) if isinstance(status_str, str) else status_str
+             # If no record found in tasks, check heartbeats to see if task exists
+             heartbeat_record = await self.channel.select_by_task_id("heartbeats", job_id)
+             if heartbeat_record:
+                 return TaskStatus.RUNNING
+             # No task or heartbeat record found - task may still be initializing
+             logger.debug("No task or heartbeat record found for job %s - task may still be initializing", job_id)
+         except Exception:
+             logger.exception("Error getting status for job %s", job_id)
+             return TaskStatus.FAILED
+         else:
+             return TaskStatus.FAILED
+
+     async def wait_for_completion(self, job_id: str) -> None:
+         """Wait for a task to complete by polling its status from SurrealDB.
+
+         This method polls the task status until it reaches a terminal state.
+         Uses a 0.5 second polling interval to balance responsiveness and resource usage.
+
+         Args:
+             job_id: The unique identifier of the job to wait for.
+
+         Raises:
+             KeyError: If the job_id is not found in tasks_sessions.
+         """
+         if job_id not in self.tasks_sessions:
+             msg = f"Job {job_id} not found"
+             raise KeyError(msg)
+
+         # Poll task status until terminal state
+         terminal_states = {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED}
+         while True:
+             status = await self.get_module_status(job_id)
+             if status in terminal_states:
+                 logger.debug("Job %s reached terminal state: %s", job_id, status)
+                 break
+             await asyncio.sleep(0.5)  # Poll interval
+
      async def stop_module(self, job_id: str) -> bool:
-         """Revoke (terminate) the Taskiq task with id.
+         """Stop a running module using TaskManager.

          Args:
              job_id: The Taskiq task id to stop.

-         Raises:
-             bool: True if the task was successfully revoked, False otherwise.
+         Returns:
+             bool: True if the signal was successfully sent, False otherwise.
          """
-         msg = "stop_module not implemented in TaskiqJobManager"
-         raise NotImplementedError(msg)
+         if job_id not in self.tasks_sessions:
+             logger.warning("Job %s not found in registry", job_id)
+             return False

-     async def stop_all_modules(self) -> None:
-         """Stop all running modules."""
-         msg = "stop_all_modules not implemented in TaskiqJobManager"
-         raise NotImplementedError(msg)
+         try:
+             session = self.tasks_sessions[job_id]
+             # Use TaskManager's cancel_task method which handles signal sending
+             await self.cancel_task(job_id, session.mission_id)
+             logger.info("Cancel signal sent for job %s via TaskManager", job_id)

-     async def get_module_status(self, job_id: str) -> TaskStatus:
-         """Query a module status."""
-         msg = "get_module_status not implemented in TaskiqJobManager"
-         raise NotImplementedError(msg)
+             # Clean up queue after cancellation
+             self.job_queues.pop(job_id, None)
+             logger.debug("Cleaned up queue for job %s", job_id)
+         except Exception:
+             logger.exception("Error stopping job %s", job_id)
+             return False
+         return True
+
+     async def stop_all_modules(self) -> None:
+         """Stop all running modules tracked in the registry."""
+         stop_tasks = [self.stop_module(job_id) for job_id in list(self.tasks_sessions.keys())]
+         if stop_tasks:
+             results = await asyncio.gather(*stop_tasks, return_exceptions=True)
+             logger.info("Stopped %d modules, results: %s", len(results), results)

      async def list_modules(self) -> dict[str, dict[str, Any]]:
-         """List all modules."""
-         msg = "list_modules not implemented in TaskiqJobManager"
-         raise NotImplementedError(msg)
+         """List all modules tracked in the registry with their statuses.
+
+         Returns:
+             dict[str, dict[str, Any]]: A dictionary containing information about all tracked modules.
+         """
+         modules_info: dict[str, dict[str, Any]] = {}
+
+         for job_id in self.tasks_sessions:
+             try:
+                 status = await self.get_module_status(job_id)
+                 task_record = await self.channel.select_by_task_id("tasks", job_id)
+
+                 modules_info[job_id] = {
+                     "name": self.module_class.__name__,
+                     "status": status,
+                     "class": self.module_class.__name__,
+                     "mission_id": task_record.get("mission_id") if task_record else "unknown",
+                 }
+             except Exception:  # noqa: PERF203
+                 logger.exception("Error getting info for job %s", job_id)
+                 modules_info[job_id] = {
+                     "name": self.module_class.__name__,
+                     "status": TaskStatus.FAILED,
+                     "class": self.module_class.__name__,
+                     "error": "Failed to retrieve status",
+                 }
+
+         return modules_info
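
Note: taken together, the manager now has an explicit lifecycle: start() opens the SurrealDB channel on top of the broker startup, status and completion are read back from SurrealDB, and stop_module()/stop_all_modules() cancel through the TaskManager. A rough usage sketch, assuming the manager was built with TaskiqJobManager(module_class, services_mode, default_timeout=..., max_concurrent_tasks=..., stream_timeout=...), started with await manager.start(), and submitted job_id through its own submission method (which is not shown in the excerpted hunks):

    from digitalkin.core.job_manager.taskiq_job_manager import TaskiqJobManager


    async def supervise(manager: TaskiqJobManager, job_id: str) -> None:
        """Follow a job this manager already submitted, then release what it tracks."""
        # Polls SurrealDB every 0.5s until the task reaches COMPLETED / FAILED / CANCELLED.
        await manager.wait_for_completion(job_id)
        status = await manager.get_module_status(job_id)
        print(f"Job {job_id} finished with status {status}")
        # Cancel anything still tracked and clear the per-job queues.
        await manager.stop_all_modules()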