digitalkin 0.3.2.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. base_server/__init__.py +1 -0
  2. base_server/mock/__init__.py +5 -0
  3. base_server/mock/mock_pb2.py +39 -0
  4. base_server/mock/mock_pb2_grpc.py +102 -0
  5. base_server/server_async_insecure.py +125 -0
  6. base_server/server_async_secure.py +143 -0
  7. base_server/server_sync_insecure.py +103 -0
  8. base_server/server_sync_secure.py +122 -0
  9. digitalkin/__init__.py +8 -0
  10. digitalkin/__version__.py +8 -0
  11. digitalkin/core/__init__.py +1 -0
  12. digitalkin/core/common/__init__.py +9 -0
  13. digitalkin/core/common/factories.py +156 -0
  14. digitalkin/core/job_manager/__init__.py +1 -0
  15. digitalkin/core/job_manager/base_job_manager.py +288 -0
  16. digitalkin/core/job_manager/single_job_manager.py +354 -0
  17. digitalkin/core/job_manager/taskiq_broker.py +311 -0
  18. digitalkin/core/job_manager/taskiq_job_manager.py +541 -0
  19. digitalkin/core/task_manager/__init__.py +1 -0
  20. digitalkin/core/task_manager/base_task_manager.py +539 -0
  21. digitalkin/core/task_manager/local_task_manager.py +108 -0
  22. digitalkin/core/task_manager/remote_task_manager.py +87 -0
  23. digitalkin/core/task_manager/surrealdb_repository.py +266 -0
  24. digitalkin/core/task_manager/task_executor.py +249 -0
  25. digitalkin/core/task_manager/task_session.py +406 -0
  26. digitalkin/grpc_servers/__init__.py +1 -0
  27. digitalkin/grpc_servers/_base_server.py +486 -0
  28. digitalkin/grpc_servers/module_server.py +208 -0
  29. digitalkin/grpc_servers/module_servicer.py +516 -0
  30. digitalkin/grpc_servers/utils/__init__.py +1 -0
  31. digitalkin/grpc_servers/utils/exceptions.py +29 -0
  32. digitalkin/grpc_servers/utils/grpc_client_wrapper.py +88 -0
  33. digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
  34. digitalkin/grpc_servers/utils/utility_schema_extender.py +97 -0
  35. digitalkin/logger.py +157 -0
  36. digitalkin/mixins/__init__.py +19 -0
  37. digitalkin/mixins/base_mixin.py +10 -0
  38. digitalkin/mixins/callback_mixin.py +24 -0
  39. digitalkin/mixins/chat_history_mixin.py +110 -0
  40. digitalkin/mixins/cost_mixin.py +76 -0
  41. digitalkin/mixins/file_history_mixin.py +93 -0
  42. digitalkin/mixins/filesystem_mixin.py +46 -0
  43. digitalkin/mixins/logger_mixin.py +51 -0
  44. digitalkin/mixins/storage_mixin.py +79 -0
  45. digitalkin/models/__init__.py +8 -0
  46. digitalkin/models/core/__init__.py +1 -0
  47. digitalkin/models/core/job_manager_models.py +36 -0
  48. digitalkin/models/core/task_monitor.py +70 -0
  49. digitalkin/models/grpc_servers/__init__.py +1 -0
  50. digitalkin/models/grpc_servers/models.py +275 -0
  51. digitalkin/models/grpc_servers/types.py +24 -0
  52. digitalkin/models/module/__init__.py +25 -0
  53. digitalkin/models/module/module.py +40 -0
  54. digitalkin/models/module/module_context.py +149 -0
  55. digitalkin/models/module/module_types.py +393 -0
  56. digitalkin/models/module/utility.py +146 -0
  57. digitalkin/models/services/__init__.py +10 -0
  58. digitalkin/models/services/cost.py +54 -0
  59. digitalkin/models/services/registry.py +42 -0
  60. digitalkin/models/services/storage.py +44 -0
  61. digitalkin/modules/__init__.py +11 -0
  62. digitalkin/modules/_base_module.py +517 -0
  63. digitalkin/modules/archetype_module.py +23 -0
  64. digitalkin/modules/tool_module.py +23 -0
  65. digitalkin/modules/trigger_handler.py +48 -0
  66. digitalkin/modules/triggers/__init__.py +12 -0
  67. digitalkin/modules/triggers/healthcheck_ping_trigger.py +45 -0
  68. digitalkin/modules/triggers/healthcheck_services_trigger.py +63 -0
  69. digitalkin/modules/triggers/healthcheck_status_trigger.py +52 -0
  70. digitalkin/py.typed +0 -0
  71. digitalkin/services/__init__.py +30 -0
  72. digitalkin/services/agent/__init__.py +6 -0
  73. digitalkin/services/agent/agent_strategy.py +19 -0
  74. digitalkin/services/agent/default_agent.py +13 -0
  75. digitalkin/services/base_strategy.py +22 -0
  76. digitalkin/services/communication/__init__.py +7 -0
  77. digitalkin/services/communication/communication_strategy.py +76 -0
  78. digitalkin/services/communication/default_communication.py +101 -0
  79. digitalkin/services/communication/grpc_communication.py +223 -0
  80. digitalkin/services/cost/__init__.py +14 -0
  81. digitalkin/services/cost/cost_strategy.py +100 -0
  82. digitalkin/services/cost/default_cost.py +114 -0
  83. digitalkin/services/cost/grpc_cost.py +138 -0
  84. digitalkin/services/filesystem/__init__.py +7 -0
  85. digitalkin/services/filesystem/default_filesystem.py +417 -0
  86. digitalkin/services/filesystem/filesystem_strategy.py +252 -0
  87. digitalkin/services/filesystem/grpc_filesystem.py +317 -0
  88. digitalkin/services/identity/__init__.py +6 -0
  89. digitalkin/services/identity/default_identity.py +15 -0
  90. digitalkin/services/identity/identity_strategy.py +14 -0
  91. digitalkin/services/registry/__init__.py +27 -0
  92. digitalkin/services/registry/default_registry.py +141 -0
  93. digitalkin/services/registry/exceptions.py +47 -0
  94. digitalkin/services/registry/grpc_registry.py +306 -0
  95. digitalkin/services/registry/registry_models.py +43 -0
  96. digitalkin/services/registry/registry_strategy.py +98 -0
  97. digitalkin/services/services_config.py +200 -0
  98. digitalkin/services/services_models.py +65 -0
  99. digitalkin/services/setup/__init__.py +1 -0
  100. digitalkin/services/setup/default_setup.py +219 -0
  101. digitalkin/services/setup/grpc_setup.py +343 -0
  102. digitalkin/services/setup/setup_strategy.py +145 -0
  103. digitalkin/services/snapshot/__init__.py +6 -0
  104. digitalkin/services/snapshot/default_snapshot.py +39 -0
  105. digitalkin/services/snapshot/snapshot_strategy.py +30 -0
  106. digitalkin/services/storage/__init__.py +7 -0
  107. digitalkin/services/storage/default_storage.py +228 -0
  108. digitalkin/services/storage/grpc_storage.py +214 -0
  109. digitalkin/services/storage/storage_strategy.py +273 -0
  110. digitalkin/services/user_profile/__init__.py +12 -0
  111. digitalkin/services/user_profile/default_user_profile.py +55 -0
  112. digitalkin/services/user_profile/grpc_user_profile.py +69 -0
  113. digitalkin/services/user_profile/user_profile_strategy.py +40 -0
  114. digitalkin/utils/__init__.py +29 -0
  115. digitalkin/utils/arg_parser.py +92 -0
  116. digitalkin/utils/development_mode_action.py +51 -0
  117. digitalkin/utils/dynamic_schema.py +483 -0
  118. digitalkin/utils/llm_ready_schema.py +75 -0
  119. digitalkin/utils/package_discover.py +357 -0
  120. digitalkin-0.3.2.dev2.dist-info/METADATA +602 -0
  121. digitalkin-0.3.2.dev2.dist-info/RECORD +131 -0
  122. digitalkin-0.3.2.dev2.dist-info/WHEEL +5 -0
  123. digitalkin-0.3.2.dev2.dist-info/licenses/LICENSE +430 -0
  124. digitalkin-0.3.2.dev2.dist-info/top_level.txt +4 -0
  125. modules/__init__.py +0 -0
  126. modules/cpu_intensive_module.py +280 -0
  127. modules/dynamic_setup_module.py +338 -0
  128. modules/minimal_llm_module.py +347 -0
  129. modules/text_transform_module.py +203 -0
  130. services/filesystem_module.py +200 -0
  131. services/storage_module.py +206 -0
digitalkin/core/job_manager/taskiq_job_manager.py
@@ -0,0 +1,541 @@
+ """Taskiq job manager module."""
+
+ try:
+     import taskiq  # noqa: F401
+
+ except ImportError:
+     msg = "Install digitalkin[taskiq] to use this functionality\n$ uv pip install digitalkin[taskiq]."
+     raise ImportError(msg)
+
+ import asyncio
+ import contextlib
+ import datetime
+ import json
+ import os
+ from collections.abc import AsyncGenerator, AsyncIterator
+ from contextlib import asynccontextmanager
+ from typing import TYPE_CHECKING, Any
+
+ from rstream import Consumer, ConsumerOffsetSpecification, MessageContext, OffsetType
+
+ from digitalkin.core.common import ConnectionFactory, QueueFactory
+ from digitalkin.core.job_manager.base_job_manager import BaseJobManager
+ from digitalkin.core.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER, cleanup_global_resources
+ from digitalkin.core.task_manager.remote_task_manager import RemoteTaskManager
+ from digitalkin.logger import logger
+ from digitalkin.models.core.task_monitor import TaskStatus
+ from digitalkin.models.module.module_types import InputModelT, OutputModelT, SetupModelT
+ from digitalkin.modules._base_module import BaseModule
+ from digitalkin.services.services_models import ServicesMode
+
+ if TYPE_CHECKING:
+     from taskiq.task import AsyncTaskiqTask
+
+
+ class TaskiqJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
+     """Taskiq job manager for running modules in Taskiq tasks."""
+
+     services_mode: ServicesMode
+
+     @staticmethod
+     def _define_consumer() -> Consumer:
+         """Read the RabbitMQ connection parameters from the environment.
+
+         Returns:
+             Consumer
+         """
+         host: str = os.environ.get("RABBITMQ_RSTREAM_HOST", "localhost")
+         port: str = os.environ.get("RABBITMQ_RSTREAM_PORT", "5552")
+         username: str = os.environ.get("RABBITMQ_RSTREAM_USERNAME", "guest")
+         password: str = os.environ.get("RABBITMQ_RSTREAM_PASSWORD", "guest")
+
+         logger.info("Connection to RabbitMQ: %s:%s.", host, port)
+         return Consumer(host=host, port=int(port), username=username, password=password)
+
+     async def _on_message(self, message: bytes, message_context: MessageContext) -> None:  # noqa: ARG002
+         """Internal callback: parse JSON and route to the correct job queue."""
+         try:
+             data = json.loads(message.decode("utf-8"))
+         except json.JSONDecodeError:
+             return
+         job_id = data.get("job_id")
+         if not job_id:
+             return
+         queue = self.job_queues.get(job_id)
+         if queue:
+             await queue.put(data.get("output_data"))
+
+     async def start(self) -> None:
+         """Start the TaskiqJobManager and initialize the SurrealDB connection."""
+         await self._start()
+         self.channel = await ConnectionFactory.create_surreal_connection(
+             database="taskiq_job_manager", timeout=datetime.timedelta(seconds=5)
+         )
+
+     async def _start(self) -> None:
+         await TASKIQ_BROKER.startup()
+
+         self.stream_consumer = self._define_consumer()
+
+         await self.stream_consumer.create_stream(
+             STREAM,
+             exists_ok=True,
+             arguments={"max-length-bytes": STREAM_RETENTION},
+         )
+         await self.stream_consumer.start()
+
+         start_spec = ConsumerOffsetSpecification(OffsetType.LAST)
+         # The callback receives raw bytes instead of an AMQPMessage
+         await self.stream_consumer.subscribe(
+             stream=STREAM,
+             subscriber_name=f"""subscriber_{os.environ.get("SERVER_NAME", "module_servicer")}""",
+             callback=self._on_message,  # type: ignore
+             offset_specification=start_spec,
+         )
+
+         # Wrap the consumer task with error handling
+         async def run_consumer_with_error_handling() -> None:
+             try:
+                 await self.stream_consumer.run()
+             except asyncio.CancelledError:
+                 logger.debug("Stream consumer task cancelled")
+                 raise
+             except Exception as e:
+                 logger.error("Stream consumer task failed: %s", e, exc_info=True, extra={"error": str(e)})
+                 # Re-raise to ensure the error is not silently ignored
+                 raise
+
+         self.stream_consumer_task = asyncio.create_task(
+             run_consumer_with_error_handling(),
+             name="stream_consumer_task",
+         )
+
+     async def _stop(self) -> None:
+         """Stop the TaskiqJobManager and clean up all resources."""
+         # Close SurrealDB connection
+         if hasattr(self, "channel"):
+             try:
+                 await self.channel.close()
+                 logger.info("TaskiqJobManager: SurrealDB connection closed")
+             except Exception as e:
+                 logger.warning("Failed to close SurrealDB connection: %s", e)
+
+         # Signal the consumer to stop
+         await self.stream_consumer.close()
+         # Cancel the background task
+         self.stream_consumer_task.cancel()
+         with contextlib.suppress(asyncio.CancelledError):
+             await self.stream_consumer_task
+
+         # Clean up job queues (log the count before clearing, so it is not always zero)
+         logger.info("TaskiqJobManager: Clearing %d job queues", len(self.job_queues))
+         self.job_queues.clear()
+
+         # Call global cleanup for producer and broker
+         await cleanup_global_resources()
+
+     def __init__(
+         self,
+         module_class: type[BaseModule],
+         services_mode: ServicesMode,
+         default_timeout: float = 10.0,
+         max_concurrent_tasks: int = 100,
+         stream_timeout: float = 30.0,
+     ) -> None:
+         """Initialize the Taskiq job manager.
+
+         Args:
+             module_class: The class of the module to be managed
+             services_mode: The mode of operation for the services
+             default_timeout: Default timeout for task operations
+             max_concurrent_tasks: Maximum number of concurrent tasks
+             stream_timeout: Timeout for stream consumer operations (default: 30.0s for distributed systems)
+         """
+         # Create remote task manager for distributed execution
+         task_manager = RemoteTaskManager(default_timeout, max_concurrent_tasks)
+
+         # Initialize base job manager with task manager
+         super().__init__(module_class, services_mode, task_manager)
+
+         logger.warning("TaskiqJobManager initialized with app: %s", TASKIQ_BROKER)
+         self.job_queues: dict[str, asyncio.Queue] = {}
+         self.max_queue_size = 1000
+         self.stream_timeout = stream_timeout
+
+     async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT:
+         """Wait for a module's config setup response.
+
+         This method registers a bounded queue for the job and waits, with a
+         timeout, for the setup result that the stream consumer routes to that
+         queue.
+
+         Args:
+             job_id: The unique identifier of the job.
+
+         Returns:
+             SetupModelT: the fully processed SetupModelT object.
+
+         Raises:
+             asyncio.TimeoutError: If waiting for the setup response times out.
+         """
+         queue = QueueFactory.create_bounded_queue(maxsize=self.max_queue_size)
+         self.job_queues[job_id] = queue
+
+         try:
+             # Add timeout to prevent indefinite blocking
+             item = await asyncio.wait_for(queue.get(), timeout=30.0)
+         except asyncio.TimeoutError:
+             logger.error("Timeout waiting for config setup response for job %s", job_id)
+             raise
+         else:
+             queue.task_done()
+             return item
+         finally:
+             logger.info("generate_config_setup_module_response: job_id=%s, queue empty=%s", job_id, self.job_queues[job_id].empty())
+             self.job_queues.pop(job_id, None)
+
+     async def create_config_setup_instance_job(
+         self,
+         config_setup_data: SetupModelT,
+         mission_id: str,
+         setup_id: str,
+         setup_version_id: str,
+     ) -> str:
+         """Create and start a new module setup configuration job.
+
+         This method initializes a new module job, assigns it a unique job ID,
+         and starts the config setup in the background.
+
+         Args:
+             config_setup_data: The input data required to start the job.
+             mission_id: The mission ID associated with the job.
+             setup_id: The setup ID associated with the module.
+             setup_version_id: The setup version ID.
+
+         Returns:
+             str: The unique identifier (job ID) of the created job.
+
+         Raises:
+             TypeError: If the function is called with a bad data type.
+             ValueError: If the module fails to start.
+         """
+         task = TASKIQ_BROKER.find_task("digitalkin.core.job_manager.taskiq_broker:run_config_module")
+
+         if task is None:
+             msg = "Task not found"
+             raise ValueError(msg)
+
+         if config_setup_data is None:
+             msg = "config_setup_data must be a valid model with model_dump method"
+             raise TypeError(msg)
+
+         # Submit task to Taskiq
+         running_task: AsyncTaskiqTask[Any] = await task.kiq(
+             mission_id,
+             setup_id,
+             setup_version_id,
+             self.module_class,
+             self.services_mode,
+             config_setup_data.model_dump(),  # type: ignore
+         )
+
+         job_id = running_task.task_id
+
+         # Create module instance for metadata
+         module = self.module_class(
+             job_id,
+             mission_id=mission_id,
+             setup_id=setup_id,
+             setup_version_id=setup_version_id,
+         )
+
+         # Register task in TaskManager (remote mode)
+         async def _dummy_coro() -> None:
+             """Dummy coroutine - actual execution happens in worker."""
+
+         await self.create_task(
+             job_id,
+             mission_id,
+             module,
+             _dummy_coro(),
+         )
+
+         logger.info("Registered config task: %s, waiting for initial result", job_id)
+         result = await running_task.wait_result(timeout=10)
+         logger.info("Job %s with data %s", job_id, result)
+         return job_id
+
+     @asynccontextmanager  # type: ignore
+     async def generate_stream_consumer(self, job_id: str) -> AsyncIterator[AsyncGenerator[dict[str, Any], None]]:  # type: ignore
+         """Generate a stream consumer for the RStream stream.
+
+         Args:
+             job_id: The job ID to filter messages.
+
+         Yields:
+             messages: The stream messages from the associated module.
+         """
+         queue = QueueFactory.create_bounded_queue(maxsize=self.max_queue_size)
+         self.job_queues[job_id] = queue
+
+         async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
+             """Generate the stream with batch-drain optimization.
+
+             This implementation uses a micro-batching pattern optimized for distributed
+             message streams from RabbitMQ:
+             1. Block waiting for the first item (with a timeout for termination checks)
+             2. Drain all immediately available items without blocking (micro-batch)
+             3. Yield control back to the event loop
+
+             This pattern provides:
+             - Better throughput for bursty message streams
+             - Reduced gRPC streaming overhead
+             - Lower latency when multiple messages arrive simultaneously
+
+             Yields:
+                 dict: generated object from the module
+             """
+             while True:
+                 try:
+                     # Block for the first item with a timeout to allow termination checks
+                     item = await asyncio.wait_for(queue.get(), timeout=self.stream_timeout)
+                     queue.task_done()
+                     yield item
+
+                     # Drain all immediately available items (micro-batch optimization)
+                     # This reduces latency when messages arrive in bursts from RabbitMQ
+                     batch_count = 0
+                     max_batch_size = 100  # Safety limit to prevent memory spikes
+                     while batch_count < max_batch_size:
+                         try:
+                             item = queue.get_nowait()
+                             queue.task_done()
+                             yield item
+                             batch_count += 1
+                         except asyncio.QueueEmpty:  # noqa: PERF203
+                             # No more items immediately available, break to next blocking wait
+                             break
+
+                 except asyncio.TimeoutError:
+                     logger.warning("Stream consumer timeout for job %s, checking if job is still active", job_id)
+
+                     # Check if job is registered
+                     if job_id not in self.tasks_sessions:
+                         logger.info("Job %s no longer registered, ending stream", job_id)
+                         break
+
+                     # Check job status to detect cancelled/failed jobs
+                     status = await self.get_module_status(job_id)
+
+                     if status in {TaskStatus.CANCELLED, TaskStatus.FAILED}:
+                         logger.info("Job %s has terminal status %s, draining queue and ending stream", job_id, status)
+
+                         # Drain remaining queue items before stopping
+                         while not queue.empty():
+                             try:
+                                 item = queue.get_nowait()
+                                 queue.task_done()
+                                 yield item
+                             except asyncio.QueueEmpty:  # noqa: PERF203
+                                 break
+
+                         break
+
+                     # Continue waiting for active/completed jobs
+                     continue
+
+         try:
+             yield _stream()
+         finally:
+             self.job_queues.pop(job_id, None)
+
+     async def create_module_instance_job(
+         self,
+         input_data: InputModelT,
+         setup_data: SetupModelT,
+         mission_id: str,
+         setup_id: str,
+         setup_version_id: str,
+     ) -> str:
+         """Launch the module task in Taskiq and return the Taskiq task id as the job_id.
+
+         Args:
+             input_data: Input data for the module
+             setup_data: Setup data for the module
+             mission_id: Mission ID for the module
+             setup_id: The setup ID associated with the module.
+             setup_version_id: The setup version ID associated with the module.
+
+         Returns:
+             job_id: The Taskiq task id.
+
+         Raises:
+             ValueError: If the task is not found.
+         """
+         task = TASKIQ_BROKER.find_task("digitalkin.core.job_manager.taskiq_broker:run_start_module")
+
+         if task is None:
+             msg = "Task not found"
+             raise ValueError(msg)
+
+         # Submit task to Taskiq
+         running_task: AsyncTaskiqTask[Any] = await task.kiq(
+             mission_id,
+             setup_id,
+             setup_version_id,
+             self.module_class,
+             self.services_mode,
+             input_data.model_dump(),
+             setup_data.model_dump(),
+         )
+         job_id = running_task.task_id
+
+         # Create module instance for metadata
+         module = self.module_class(
+             job_id,
+             mission_id=mission_id,
+             setup_id=setup_id,
+             setup_version_id=setup_version_id,
+         )
+
+         # Register task in TaskManager (remote mode)
+         # Dummy coroutine will be closed by TaskManager since execution_mode="remote"
+         async def _dummy_coro() -> None:
+             """Dummy coroutine - actual execution happens in worker."""
+
+         await self.create_task(
+             job_id,
+             mission_id,
+             module,
+             _dummy_coro(),  # Will be closed immediately by TaskManager in remote mode
+         )
+
+         logger.info("Registered remote task: %s, waiting for initial result", job_id)
+         result = await running_task.wait_result(timeout=10)
+         logger.debug("Job %s with data %s", job_id, result)
+         return job_id
+
+     async def get_module_status(self, job_id: str) -> TaskStatus:
+         """Query a module's status from SurrealDB.
+
+         Args:
+             job_id: The unique identifier of the job.
+
+         Returns:
+             TaskStatus: The status of the module task.
+         """
+         if job_id not in self.tasks_sessions:
+             logger.warning("Job %s not found in registry", job_id)
+             return TaskStatus.FAILED
+
+         # Safety check: if the channel is not initialized (start() wasn't called), return FAILED
+         if not hasattr(self, "channel") or self.channel is None:
+             logger.warning("Job %s status check failed - channel not initialized", job_id)
+             return TaskStatus.FAILED
+
+         try:
+             # Query the tasks table for the task status
+             task_record = await self.channel.select_by_task_id("tasks", job_id)
+             if task_record and "status" in task_record:
+                 status_str = task_record["status"]
+                 return TaskStatus(status_str) if isinstance(status_str, str) else status_str
+             # If no record is found in tasks, check heartbeats to see if the task exists
+             heartbeat_record = await self.channel.select_by_task_id("heartbeats", job_id)
+             if heartbeat_record:
+                 return TaskStatus.RUNNING
+             # No task or heartbeat record found - the task may still be initializing
+             logger.debug("No task or heartbeat record found for job %s - task may still be initializing", job_id)
+         except Exception:
+             logger.exception("Error getting status for job %s", job_id)
+             return TaskStatus.FAILED
+         else:
+             return TaskStatus.FAILED
+
+     async def wait_for_completion(self, job_id: str) -> None:
+         """Wait for a task to complete by polling its status from SurrealDB.
+
+         This method polls the task status until it reaches a terminal state.
+         Uses a 0.5 second polling interval to balance responsiveness and resource usage.
+
+         Args:
+             job_id: The unique identifier of the job to wait for.
+
+         Raises:
+             KeyError: If the job_id is not found in tasks_sessions.
+         """
+         if job_id not in self.tasks_sessions:
+             msg = f"Job {job_id} not found"
+             raise KeyError(msg)
+
+         # Poll task status until terminal state
+         terminal_states = {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED}
+         while True:
+             status = await self.get_module_status(job_id)
+             if status in terminal_states:
+                 logger.debug("Job %s reached terminal state: %s", job_id, status)
+                 break
+             await asyncio.sleep(0.5)  # Poll interval
+
+     async def stop_module(self, job_id: str) -> bool:
+         """Stop a running module using TaskManager.
+
+         Args:
+             job_id: The Taskiq task id to stop.
+
+         Returns:
+             bool: True if the signal was successfully sent, False otherwise.
+         """
+         if job_id not in self.tasks_sessions:
+             logger.warning("Job %s not found in registry", job_id)
+             return False
+
+         try:
+             session = self.tasks_sessions[job_id]
+             # Use TaskManager's cancel_task method which handles signal sending
+             await self.cancel_task(job_id, session.mission_id)
+             logger.info("Cancel signal sent for job %s via TaskManager", job_id)
+
+             # Clean up queue after cancellation
+             self.job_queues.pop(job_id, None)
+             logger.debug("Cleaned up queue for job %s", job_id)
+         except Exception:
+             logger.exception("Error stopping job %s", job_id)
+             return False
+         return True
+
+     async def stop_all_modules(self) -> None:
+         """Stop all running modules tracked in the registry."""
+         stop_tasks = [self.stop_module(job_id) for job_id in list(self.tasks_sessions.keys())]
+         if stop_tasks:
+             results = await asyncio.gather(*stop_tasks, return_exceptions=True)
+             logger.info("Stopped %d modules, results: %s", len(results), results)
+
+     async def list_modules(self) -> dict[str, dict[str, Any]]:
+         """List all modules tracked in the registry with their statuses.
+
+         Returns:
+             dict[str, dict[str, Any]]: A dictionary containing information about all tracked modules.
+         """
+         modules_info: dict[str, dict[str, Any]] = {}
+
+         for job_id in self.tasks_sessions:
+             try:
+                 status = await self.get_module_status(job_id)
+                 task_record = await self.channel.select_by_task_id("tasks", job_id)
+
+                 modules_info[job_id] = {
+                     "name": self.module_class.__name__,
+                     "status": status,
+                     "class": self.module_class.__name__,
+                     "mission_id": task_record.get("mission_id") if task_record else "unknown",
+                 }
+             except Exception:  # noqa: PERF203
+                 logger.exception("Error getting info for job %s", job_id)
+                 modules_info[job_id] = {
+                     "name": self.module_class.__name__,
+                     "status": TaskStatus.FAILED,
+                     "class": self.module_class.__name__,
+                     "error": "Failed to retrieve status",
+                 }
+
+         return modules_info
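
The `_on_message` callback above routes each stream message to a per-job queue by its `job_id`, so the worker-side producer must publish a JSON envelope carrying that id. Below is a minimal sketch of the expected payload shape, inferred from the callback alone; the actual producer lives in taskiq_broker (not shown in this diff), so any field beyond `job_id` and `output_data` is an assumption, not the package's wire contract.

import json

# Envelope consumed by TaskiqJobManager._on_message: a UTF-8 JSON payload
# with the Taskiq task id and one output chunk for that job.
envelope = {
    "job_id": "0192d3c4-example-task-id",  # hypothetical Taskiq task id
    "output_data": {"text": "partial result"},  # routed to job_queues[job_id]
}
payload: bytes = json.dumps(envelope).encode("utf-8")
# A payload that is not valid JSON, or whose job_id is missing or has no
# registered queue, is silently dropped by _on_message.
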
digitalkin/core/task_manager/__init__.py
@@ -0,0 +1 @@
+ """Base task manager logic."""