digitalkin 0.3.0rc1__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/core/common/__init__.py +9 -0
  3. digitalkin/core/common/factories.py +156 -0
  4. digitalkin/core/job_manager/base_job_manager.py +128 -28
  5. digitalkin/core/job_manager/single_job_manager.py +80 -25
  6. digitalkin/core/job_manager/taskiq_broker.py +114 -19
  7. digitalkin/core/job_manager/taskiq_job_manager.py +291 -39
  8. digitalkin/core/task_manager/base_task_manager.py +539 -0
  9. digitalkin/core/task_manager/local_task_manager.py +108 -0
  10. digitalkin/core/task_manager/remote_task_manager.py +87 -0
  11. digitalkin/core/task_manager/surrealdb_repository.py +43 -4
  12. digitalkin/core/task_manager/task_executor.py +249 -0
  13. digitalkin/core/task_manager/task_session.py +107 -19
  14. digitalkin/grpc_servers/module_server.py +2 -2
  15. digitalkin/grpc_servers/module_servicer.py +21 -12
  16. digitalkin/grpc_servers/registry_server.py +1 -1
  17. digitalkin/grpc_servers/registry_servicer.py +4 -4
  18. digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
  19. digitalkin/models/core/task_monitor.py +17 -0
  20. digitalkin/models/grpc_servers/models.py +4 -4
  21. digitalkin/models/module/module_context.py +5 -0
  22. digitalkin/models/module/module_types.py +304 -16
  23. digitalkin/modules/_base_module.py +66 -28
  24. digitalkin/services/cost/grpc_cost.py +8 -41
  25. digitalkin/services/filesystem/grpc_filesystem.py +9 -38
  26. digitalkin/services/services_config.py +11 -0
  27. digitalkin/services/services_models.py +3 -1
  28. digitalkin/services/setup/default_setup.py +5 -6
  29. digitalkin/services/setup/grpc_setup.py +51 -14
  30. digitalkin/services/storage/grpc_storage.py +2 -2
  31. digitalkin/services/user_profile/__init__.py +12 -0
  32. digitalkin/services/user_profile/default_user_profile.py +55 -0
  33. digitalkin/services/user_profile/grpc_user_profile.py +69 -0
  34. digitalkin/services/user_profile/user_profile_strategy.py +40 -0
  35. digitalkin/utils/__init__.py +28 -0
  36. digitalkin/utils/dynamic_schema.py +483 -0
  37. {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/METADATA +9 -29
  38. {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/RECORD +42 -30
  39. modules/dynamic_setup_module.py +362 -0
  40. digitalkin/core/task_manager/task_manager.py +0 -439
  41. {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/WHEEL +0 -0
  42. {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/licenses/LICENSE +0 -0
  43. {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/top_level.txt +0 -0
digitalkin/core/job_manager/taskiq_broker.py

@@ -1,10 +1,12 @@
  """Taskiq broker & RSTREAM producer for the job manager."""

  import asyncio
+ import datetime
  import json
  import logging
  import os
  import pickle # noqa: S403
+ from typing import Any

  from rstream import Producer
  from rstream.exceptions import PreconditionFailed
@@ -14,7 +16,10 @@ from taskiq.compat import model_validate
  from taskiq.message import BrokerMessage
  from taskiq_aio_pika import AioPikaBroker

+ from digitalkin.core.common import ConnectionFactory, ModuleFactory
  from digitalkin.core.job_manager.base_job_manager import BaseJobManager
+ from digitalkin.core.task_manager.task_executor import TaskExecutor
+ from digitalkin.core.task_manager.task_session import TaskSession
  from digitalkin.logger import logger
  from digitalkin.models.core.job_manager_models import StreamCodeModel
  from digitalkin.models.module.module_types import OutputModelT
@@ -118,6 +123,24 @@ RSTREAM_PRODUCER = define_producer()
  TASKIQ_BROKER = define_broker()


+ async def cleanup_global_resources() -> None:
+     """Clean up global resources (producer and broker connections).
+
+     This should be called during shutdown to prevent connection leaks.
+     """
+     try:
+         await RSTREAM_PRODUCER.close()
+         logger.info("RStream producer closed successfully")
+     except Exception as e:
+         logger.warning("Failed to close RStream producer: %s", e)
+
+     try:
+         await TASKIQ_BROKER.shutdown()
+         logger.info("Taskiq broker shut down successfully")
+     except Exception as e:
+         logger.warning("Failed to shutdown Taskiq broker: %s", e)
+
+
  async def send_message_to_stream(job_id: str, output_data: OutputModelT) -> None: # type: ignore
      """Callback define to add a message frame to the Rstream.

@@ -152,27 +175,70 @@ async def run_start_module(
          setup_data: dict,
          context: Allow TaskIQ context access
      """
-     logger.warning("%s", services_mode)
+     logger.info("Starting module with services_mode: %s", services_mode)
      services_config = ServicesConfig(
          services_config_strategies=module_class.services_config_strategies,
          services_config_params=module_class.services_config_params,
          mode=services_mode,
      )
      setattr(module_class, "services_config", services_config)
-     logger.warning("%s | %s", services_config, module_class.services_config)
+     logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)
+     module_class.discover()

      job_id = context.message.task_id
      callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
-     module = module_class(job_id, mission_id=mission_id, setup_id=setup_id, setup_version_id=setup_version_id)
-
-     await module.start(
-         input_data,
-         setup_data,
-         callback,
-         # ensure that the callback is called when the task is done + allow asyncio to run
-         # TODO: should define a BaseModel for stream code / error
-         done_callback=lambda _: asyncio.create_task(callback(StreamCodeModel(code="__END_OF_STREAM__"))),
-     )
+     module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)
+
+     channel = None
+     try:
+         # Create TaskExecutor and supporting components for worker execution
+         executor = TaskExecutor()
+         # SurrealDB env vars are expected to be set in env.
+         channel = await ConnectionFactory.create_surreal_connection("taskiq_worker", datetime.timedelta(seconds=5))
+         session = TaskSession(job_id, mission_id, channel, module, datetime.timedelta(seconds=2))
+
+         # Execute the task using TaskExecutor
+         # Create a proper done callback that handles errors
+         async def send_end_of_stream(_: Any) -> None: # noqa: ANN401
+             try:
+                 await callback(StreamCodeModel(code="__END_OF_STREAM__"))
+             except Exception as e:
+                 logger.error("Error sending end of stream: %s", e, exc_info=True)
+
+         # Reconstruct Pydantic models from dicts for type safety
+         try:
+             input_model = module_class.create_input_model(input_data)
+             setup_model = await module_class.create_setup_model(setup_data)
+         except Exception as e:
+             logger.error("Failed to reconstruct models for job %s: %s", job_id, e, exc_info=True)
+             raise
+
+         supervisor_task = await executor.execute_task(
+             task_id=job_id,
+             mission_id=mission_id,
+             coro=module.start(
+                 input_model,
+                 setup_model,
+                 callback,
+                 done_callback=lambda result: asyncio.ensure_future(send_end_of_stream(result)),
+             ),
+             session=session,
+             channel=channel,
+         )
+
+         # Wait for the supervisor task to complete
+         await supervisor_task
+         logger.info("Module task %s completed", job_id)
+     except Exception:
+         logger.exception("Error running module %s", job_id)
+         raise
+     finally:
+         # Cleanup channel
+         if channel is not None:
+             try:
+                 await channel.close()
+             except Exception:
+                 logger.exception("Error closing channel for job %s", job_id)


  @TASKIQ_BROKER.task
@@ -196,20 +262,49 @@ async def run_config_module(
          config_setup_data: dict,
          context: Allow TaskIQ context access
      """
-     logger.warning("%s", services_mode)
+     logger.info("Starting config module with services_mode: %s", services_mode)
      services_config = ServicesConfig(
          services_config_strategies=module_class.services_config_strategies,
          services_config_params=module_class.services_config_params,
          mode=services_mode,
      )
      setattr(module_class, "services_config", services_config)
-     logger.warning("%s | %s", services_config, module_class.services_config)
+     logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)

      job_id = context.message.task_id
      callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
-     module = module_class(job_id, mission_id=mission_id, setup_id=setup_id, setup_version_id=setup_version_id)
+     module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)

-     await module.start_config_setup(
-         module_class.create_config_setup_model(config_setup_data),
-         callback,
-     )
+     # Override environment variables temporarily to use manager's SurrealDB
+     channel = None
+     try:
+         # Create TaskExecutor and supporting components for worker execution
+         executor = TaskExecutor()
+         # SurrealDB env vars are expected to be set in env.
+         channel = await ConnectionFactory.create_surreal_connection("taskiq_worker", datetime.timedelta(seconds=5))
+         session = TaskSession(job_id, mission_id, channel, module, datetime.timedelta(seconds=2))
+
+         # Create and run the config setup task with TaskExecutor
+         setup_model = module_class.create_config_setup_model(config_setup_data)
+
+         supervisor_task = await executor.execute_task(
+             task_id=job_id,
+             mission_id=mission_id,
+             coro=module.start_config_setup(setup_model, callback),
+             session=session,
+             channel=channel,
+         )
+
+         # Wait for the supervisor task to complete
+         await supervisor_task
+         logger.info("Config module task %s completed", job_id)
+     except Exception:
+         logger.exception("Error running config module %s", job_id)
+         raise
+     finally:
+         # Cleanup channel
+         if channel is not None:
+             try:
+                 await channel.close()
+             except Exception:
+                 logger.exception("Error closing channel for job %s", job_id)
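For orientation, a minimal sketch (not part of the diff) of how a standalone worker process could release the new module-level resources on shutdown. Only `TASKIQ_BROKER.startup()` and `cleanup_global_resources()` come from the code above; the `main` entry point and the bare `asyncio.Event().wait()` serve loop are illustrative assumptions.

```python
import asyncio

from digitalkin.core.job_manager.taskiq_broker import TASKIQ_BROKER, cleanup_global_resources


async def main() -> None:
    # Hypothetical worker entry point: bring the broker up, serve until cancelled,
    # then close the RStream producer and shut the Taskiq broker down.
    await TASKIQ_BROKER.startup()
    try:
        await asyncio.Event().wait()  # stand-in for the real worker loop
    finally:
        await cleanup_global_resources()


if __name__ == "__main__":
    asyncio.run(main())
```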
digitalkin/core/job_manager/taskiq_job_manager.py

@@ -9,19 +9,22 @@ except ImportError:
  import asyncio
  import contextlib
+ import datetime
  import json
  import os
  from collections.abc import AsyncGenerator, AsyncIterator
  from contextlib import asynccontextmanager
- from typing import TYPE_CHECKING, Any, Generic
+ from typing import TYPE_CHECKING, Any

  from rstream import Consumer, ConsumerOffsetSpecification, MessageContext, OffsetType

+ from digitalkin.core.common import ConnectionFactory, QueueFactory
  from digitalkin.core.job_manager.base_job_manager import BaseJobManager
- from digitalkin.core.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER
+ from digitalkin.core.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER, cleanup_global_resources
+ from digitalkin.core.task_manager.remote_task_manager import RemoteTaskManager
  from digitalkin.logger import logger
  from digitalkin.models.core.task_monitor import TaskStatus
- from digitalkin.models.module import InputModelT, SetupModelT
+ from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
  from digitalkin.modules._base_module import BaseModule
  from digitalkin.services.services_models import ServicesMode

@@ -29,7 +32,7 @@ if TYPE_CHECKING:
      from taskiq.task import AsyncTaskiqTask


- class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
+ class TaskiqJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
      """Taskiq job manager for running modules in Taskiq tasks."""

      services_mode: ServicesMode
@@ -62,6 +65,13 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
          if queue:
              await queue.put(data.get("output_data"))

+     async def start(self) -> None:
+         """Start the TaskiqJobManager and initialize SurrealDB connection."""
+         await self._start()
+         self.channel = await ConnectionFactory.create_surreal_connection(
+             database="taskiq_job_manager", timeout=datetime.timedelta(seconds=5)
+         )
+
      async def _start(self) -> None:
          await TASKIQ_BROKER.startup()

@@ -82,12 +92,34 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
              callback=self._on_message, # type: ignore
              offset_specification=start_spec,
          )
+
+         # Wrap the consumer task with error handling
+         async def run_consumer_with_error_handling() -> None:
+             try:
+                 await self.stream_consumer.run()
+             except asyncio.CancelledError:
+                 logger.debug("Stream consumer task cancelled")
+                 raise
+             except Exception as e:
+                 logger.error("Stream consumer task failed: %s", e, exc_info=True, extra={"error": str(e)})
+                 # Re-raise to ensure the error is not silently ignored
+                 raise
+
          self.stream_consumer_task = asyncio.create_task(
-             self.stream_consumer.run(),
+             run_consumer_with_error_handling(),
              name="stream_consumer_task",
          )

      async def _stop(self) -> None:
+         """Stop the TaskiqJobManager and clean up all resources."""
+         # Close SurrealDB connection
+         if hasattr(self, "channel"):
+             try:
+                 await self.channel.close()
+                 logger.info("TaskiqJobManager: SurrealDB connection closed")
+             except Exception as e:
+                 logger.warning("Failed to close SurrealDB connection: %s", e)
+
          # Signal the consumer to stop
          await self.stream_consumer.close()
          # Cancel the background task
@@ -95,18 +127,40 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
          with contextlib.suppress(asyncio.CancelledError):
              await self.stream_consumer_task

+         # Clean up job queues
+         self.job_queues.clear()
+         logger.info("TaskiqJobManager: Cleared %d job queues", len(self.job_queues))
+
+         # Call global cleanup for producer and broker
+         await cleanup_global_resources()
+
      def __init__(
          self,
          module_class: type[BaseModule],
          services_mode: ServicesMode,
+         default_timeout: float = 10.0,
+         max_concurrent_tasks: int = 100,
+         stream_timeout: float = 30.0,
      ) -> None:
-         """Initialize the Taskiq job manager."""
-         super().__init__(module_class, services_mode)
+         """Initialize the Taskiq job manager.
+
+         Args:
+             module_class: The class of the module to be managed
+             services_mode: The mode of operation for the services
+             default_timeout: Default timeout for task operations
+             max_concurrent_tasks: Maximum number of concurrent tasks
+             stream_timeout: Timeout for stream consumer operations (default: 30.0s for distributed systems)
+         """
+         # Create remote task manager for distributed execution
+         task_manager = RemoteTaskManager(default_timeout, max_concurrent_tasks)
+
+         # Initialize base job manager with task manager
+         super().__init__(module_class, services_mode, task_manager)

          logger.warning("TaskiqJobManager initialized with app: %s", TASKIQ_BROKER)
-         self.services_mode = services_mode
          self.job_queues: dict[str, asyncio.Queue] = {}
          self.max_queue_size = 1000
+         self.stream_timeout = stream_timeout

      async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT:
          """Generate a stream consumer for a module's output data.
@@ -120,12 +174,20 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):

          Returns:
              SetupModelT: the SetupModelT object fully processed.
+
+         Raises:
+             asyncio.TimeoutError: If waiting for the setup response times out.
          """
-         queue: asyncio.Queue = asyncio.Queue(maxsize=self.max_queue_size)
+         queue = QueueFactory.create_bounded_queue(maxsize=self.max_queue_size)
          self.job_queues[job_id] = queue

          try:
-             item = await queue.get()
+             # Add timeout to prevent indefinite blocking
+             item = await asyncio.wait_for(queue.get(), timeout=30.0)
+         except asyncio.TimeoutError:
+             logger.error("Timeout waiting for config setup response for job %s", job_id)
+             raise
+         else:
              queue.task_done()
              return item
          finally:
@@ -157,7 +219,7 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
              TypeError: If the function is called with bad data type.
              ValueError: If the module fails to start.
          """
-         task = TASKIQ_BROKER.find_task("digitalkin.core.taskiq_broker:run_config_module")
+         task = TASKIQ_BROKER.find_task("digitalkin.core.job_manager.taskiq_broker:run_config_module")

          if task is None:
              msg = "Task not found"
@@ -167,6 +229,7 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
              msg = "config_setup_data must be a valid model with model_dump method"
              raise TypeError(msg)

+         # Submit task to Taskiq
          running_task: AsyncTaskiqTask[Any] = await task.kiq(
              mission_id,
              setup_id,
@@ -177,6 +240,27 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
          )

          job_id = running_task.task_id
+
+         # Create module instance for metadata
+         module = self.module_class(
+             job_id,
+             mission_id=mission_id,
+             setup_id=setup_id,
+             setup_version_id=setup_version_id,
+         )
+
+         # Register task in TaskManager (remote mode)
+         async def _dummy_coro() -> None:
+             """Dummy coroutine - actual execution happens in worker."""
+
+         await self.create_task(
+             job_id,
+             mission_id,
+             module,
+             _dummy_coro(),
+         )
+
+         logger.info("Registered config task: %s, waiting for initial result", job_id)
          result = await running_task.wait_result(timeout=10)
          logger.info("Job %s with data %s", job_id, result)
          return job_id
@@ -191,28 +275,75 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
          Yields:
              messages: The stream messages from the associated module.
          """
-         queue: asyncio.Queue = asyncio.Queue(maxsize=self.max_queue_size)
+         queue = QueueFactory.create_bounded_queue(maxsize=self.max_queue_size)
          self.job_queues[job_id] = queue

          async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
-             """Generate the stream allowing flowless communication.
+             """Generate the stream with batch-drain optimization.
+
+             This implementation uses a micro-batching pattern optimized for distributed
+             message streams from RabbitMQ:
+             1. Block waiting for the first item (with timeout for termination checks)
+             2. Drain all immediately available items without blocking (micro-batch)
+             3. Yield control back to event loop
+
+             This pattern provides:
+             - Better throughput for bursty message streams
+             - Reduced gRPC streaming overhead
+             - Lower latency when multiple messages arrive simultaneously

              Yields:
                  dict: generated object from the module
              """
              while True:
-                 item = await queue.get()
-                 queue.task_done()
-                 yield item
-
-                 while True:
-                     try:
-                         item = queue.get_nowait()
-                     except asyncio.QueueEmpty:
-                         break
+                 try:
+                     # Block for first item with timeout to allow termination checks
+                     item = await asyncio.wait_for(queue.get(), timeout=self.stream_timeout)
                      queue.task_done()
                      yield item

+                     # Drain all immediately available items (micro-batch optimization)
+                     # This reduces latency when messages arrive in bursts from RabbitMQ
+                     batch_count = 0
+                     max_batch_size = 100 # Safety limit to prevent memory spikes
+                     while batch_count < max_batch_size:
+                         try:
+                             item = queue.get_nowait()
+                             queue.task_done()
+                             yield item
+                             batch_count += 1
+                         except asyncio.QueueEmpty: # noqa: PERF203
+                             # No more items immediately available, break to next blocking wait
+                             break
+
+                 except asyncio.TimeoutError:
+                     logger.warning("Stream consumer timeout for job %s, checking if job is still active", job_id)
+
+                     # Check if job is registered
+                     if job_id not in self.tasks_sessions:
+                         logger.info("Job %s no longer registered, ending stream", job_id)
+                         break
+
+                     # Check job status to detect cancelled/failed jobs
+                     status = await self.get_module_status(job_id)
+
+                     if status in {TaskStatus.CANCELLED, TaskStatus.FAILED}:
+                         logger.info("Job %s has terminal status %s, draining queue and ending stream", job_id, status)
+
+                         # Drain remaining queue items before stopping
+                         while not queue.empty():
+                             try:
+                                 item = queue.get_nowait()
+                                 queue.task_done()
+                                 yield item
+                             except asyncio.QueueEmpty: # noqa: PERF203
+                                 break
+
+                         break
+
+                     # Continue waiting for active/completed jobs
+                     continue
+
      try:
          yield _stream()
      finally:
@@ -241,12 +372,13 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
          Raises:
              ValueError: If the task is not found.
          """
-         task = TASKIQ_BROKER.find_task("digitalkin.core.taskiq_broker:run_start_module")
+         task = TASKIQ_BROKER.find_task("digitalkin.core.job_manager.taskiq_broker:run_start_module")

          if task is None:
              msg = "Task not found"
              raise ValueError(msg)

+         # Submit task to Taskiq
          running_task: AsyncTaskiqTask[Any] = await task.kiq(
              mission_id,
              setup_id,
@@ -257,33 +389,153 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
              setup_data.model_dump(),
          )
          job_id = running_task.task_id
+
+         # Create module instance for metadata
+         module = self.module_class(
+             job_id,
+             mission_id=mission_id,
+             setup_id=setup_id,
+             setup_version_id=setup_version_id,
+         )
+
+         # Register task in TaskManager (remote mode)
+         # Dummy coroutine will be closed by TaskManager since execution_mode="remote"
+         async def _dummy_coro() -> None:
+             """Dummy coroutine - actual execution happens in worker."""
+
+         await self.create_task(
+             job_id,
+             mission_id,
+             module,
+             _dummy_coro(), # Will be closed immediately by TaskManager in remote mode
+         )
+
+         logger.info("Registered remote task: %s, waiting for initial result", job_id)
          result = await running_task.wait_result(timeout=10)
          logger.debug("Job %s with data %s", job_id, result)
          return job_id

+     async def get_module_status(self, job_id: str) -> TaskStatus:
+         """Query a module status from SurrealDB.
+
+         Args:
+             job_id: The unique identifier of the job.
+
+         Returns:
+             TaskStatus: The status of the module task.
+         """
+         if job_id not in self.tasks_sessions:
+             logger.warning("Job %s not found in registry", job_id)
+             return TaskStatus.FAILED
+
+         # Safety check: if channel not initialized (start() wasn't called), return FAILED
+         if not hasattr(self, "channel") or self.channel is None:
+             logger.warning("Job %s status check failed - channel not initialized", job_id)
+             return TaskStatus.FAILED
+
+         try:
+             # Query the tasks table for the task status
+             task_record = await self.channel.select_by_task_id("tasks", job_id)
+             if task_record and "status" in task_record:
+                 status_str = task_record["status"]
+                 return TaskStatus(status_str) if isinstance(status_str, str) else status_str
+             # If no record found in tasks, check heartbeats to see if task exists
+             heartbeat_record = await self.channel.select_by_task_id("heartbeats", job_id)
+             if heartbeat_record:
+                 return TaskStatus.RUNNING
+             # No task or heartbeat record found - task may still be initializing
+             logger.debug("No task or heartbeat record found for job %s - task may still be initializing", job_id)
+         except Exception:
+             logger.exception("Error getting status for job %s", job_id)
+             return TaskStatus.FAILED
+         else:
+             return TaskStatus.FAILED
+
+     async def wait_for_completion(self, job_id: str) -> None:
+         """Wait for a task to complete by polling its status from SurrealDB.
+
+         This method polls the task status until it reaches a terminal state.
+         Uses a 0.5 second polling interval to balance responsiveness and resource usage.
+
+         Args:
+             job_id: The unique identifier of the job to wait for.
+
+         Raises:
+             KeyError: If the job_id is not found in tasks_sessions.
+         """
+         if job_id not in self.tasks_sessions:
+             msg = f"Job {job_id} not found"
+             raise KeyError(msg)
+
+         # Poll task status until terminal state
+         terminal_states = {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED}
+         while True:
+             status = await self.get_module_status(job_id)
+             if status in terminal_states:
+                 logger.debug("Job %s reached terminal state: %s", job_id, status)
+                 break
+             await asyncio.sleep(0.5) # Poll interval
+
      async def stop_module(self, job_id: str) -> bool:
-         """Revoke (terminate) the Taskiq task with id.
+         """Stop a running module using TaskManager.

          Args:
              job_id: The Taskiq task id to stop.

-         Raises:
-             bool: True if the task was successfully revoked, False otherwise.
+         Returns:
+             bool: True if the signal was successfully sent, False otherwise.
          """
-         msg = "stop_module not implemented in TaskiqJobManager"
-         raise NotImplementedError(msg)
+         if job_id not in self.tasks_sessions:
+             logger.warning("Job %s not found in registry", job_id)
+             return False

-     async def stop_all_modules(self) -> None:
-         """Stop all running modules."""
-         msg = "stop_all_modules not implemented in TaskiqJobManager"
-         raise NotImplementedError(msg)
+         try:
+             session = self.tasks_sessions[job_id]
+             # Use TaskManager's cancel_task method which handles signal sending
+             await self.cancel_task(job_id, session.mission_id)
+             logger.info("Cancel signal sent for job %s via TaskManager", job_id)

-     async def get_module_status(self, job_id: str) -> TaskStatus:
-         """Query a module status."""
-         msg = "get_module_status not implemented in TaskiqJobManager"
-         raise NotImplementedError(msg)
+             # Clean up queue after cancellation
+             self.job_queues.pop(job_id, None)
+             logger.debug("Cleaned up queue for job %s", job_id)
+         except Exception:
+             logger.exception("Error stopping job %s", job_id)
+             return False
+         return True
+
+     async def stop_all_modules(self) -> None:
+         """Stop all running modules tracked in the registry."""
+         stop_tasks = [self.stop_module(job_id) for job_id in list(self.tasks_sessions.keys())]
+         if stop_tasks:
+             results = await asyncio.gather(*stop_tasks, return_exceptions=True)
+             logger.info("Stopped %d modules, results: %s", len(results), results)

      async def list_modules(self) -> dict[str, dict[str, Any]]:
-         """List all modules."""
-         msg = "list_modules not implemented in TaskiqJobManager"
-         raise NotImplementedError(msg)
+         """List all modules tracked in the registry with their statuses.
+
+         Returns:
+             dict[str, dict[str, Any]]: A dictionary containing information about all tracked modules.
+         """
+         modules_info: dict[str, dict[str, Any]] = {}
+
+         for job_id in self.tasks_sessions:
+             try:
+                 status = await self.get_module_status(job_id)
+                 task_record = await self.channel.select_by_task_id("tasks", job_id)
+
+                 modules_info[job_id] = {
+                     "name": self.module_class.__name__,
+                     "status": status,
+                     "class": self.module_class.__name__,
+                     "mission_id": task_record.get("mission_id") if task_record else "unknown",
+                 }
+             except Exception: # noqa: PERF203
+                 logger.exception("Error getting info for job %s", job_id)
+                 modules_info[job_id] = {
+                     "name": self.module_class.__name__,
+                     "status": TaskStatus.FAILED,
+                     "class": self.module_class.__name__,
+                     "error": "Failed to retrieve status",
+                 }
+
+         return modules_info
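To show how the new manager-side lifecycle fits together, a minimal monitoring sketch based only on the methods added above. The `job_id` is assumed to come from whichever submit call the surrounding gRPC servicer makes, and since the diff only shows the private `_stop()`, the sketch calls it directly; a public wrapper may exist in the base class.

```python
from digitalkin.core.job_manager.taskiq_job_manager import TaskiqJobManager
from digitalkin.models.core.task_monitor import TaskStatus


async def monitor_and_shutdown(manager: TaskiqJobManager, job_id: str) -> None:
    # Assumes manager.start() was already called so the SurrealDB channel exists;
    # without it, get_module_status falls back to FAILED.
    if await manager.get_module_status(job_id) is TaskStatus.RUNNING:
        # Polls the SurrealDB-backed status every 0.5s until a terminal state.
        await manager.wait_for_completion(job_id)

    print(await manager.list_modules())

    # Cancel anything still registered, then release consumer, producer, broker and DB resources.
    await manager.stop_all_modules()
    await manager._stop()  # noqa: SLF001 - assuming no public stop() wrapper is exposed
```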