digitalkin 0.3.0rc1__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/core/common/__init__.py +9 -0
  3. digitalkin/core/common/factories.py +156 -0
  4. digitalkin/core/job_manager/base_job_manager.py +128 -28
  5. digitalkin/core/job_manager/single_job_manager.py +80 -25
  6. digitalkin/core/job_manager/taskiq_broker.py +114 -19
  7. digitalkin/core/job_manager/taskiq_job_manager.py +291 -39
  8. digitalkin/core/task_manager/base_task_manager.py +539 -0
  9. digitalkin/core/task_manager/local_task_manager.py +108 -0
  10. digitalkin/core/task_manager/remote_task_manager.py +87 -0
  11. digitalkin/core/task_manager/surrealdb_repository.py +43 -4
  12. digitalkin/core/task_manager/task_executor.py +249 -0
  13. digitalkin/core/task_manager/task_session.py +107 -19
  14. digitalkin/grpc_servers/module_server.py +2 -2
  15. digitalkin/grpc_servers/module_servicer.py +21 -12
  16. digitalkin/grpc_servers/registry_server.py +1 -1
  17. digitalkin/grpc_servers/registry_servicer.py +4 -4
  18. digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
  19. digitalkin/models/core/task_monitor.py +17 -0
  20. digitalkin/models/grpc_servers/models.py +4 -4
  21. digitalkin/models/module/module_context.py +5 -0
  22. digitalkin/models/module/module_types.py +304 -16
  23. digitalkin/modules/_base_module.py +66 -28
  24. digitalkin/services/cost/grpc_cost.py +8 -41
  25. digitalkin/services/filesystem/grpc_filesystem.py +9 -38
  26. digitalkin/services/services_config.py +11 -0
  27. digitalkin/services/services_models.py +3 -1
  28. digitalkin/services/setup/default_setup.py +5 -6
  29. digitalkin/services/setup/grpc_setup.py +51 -14
  30. digitalkin/services/storage/grpc_storage.py +2 -2
  31. digitalkin/services/user_profile/__init__.py +12 -0
  32. digitalkin/services/user_profile/default_user_profile.py +55 -0
  33. digitalkin/services/user_profile/grpc_user_profile.py +69 -0
  34. digitalkin/services/user_profile/user_profile_strategy.py +40 -0
  35. digitalkin/utils/__init__.py +28 -0
  36. digitalkin/utils/dynamic_schema.py +483 -0
  37. {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/METADATA +9 -29
  38. {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/RECORD +42 -30
  39. modules/dynamic_setup_module.py +362 -0
  40. digitalkin/core/task_manager/task_manager.py +0 -439
  41. {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/WHEEL +0 -0
  42. {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/licenses/LICENSE +0 -0
  43. {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/top_level.txt +0 -0
digitalkin/__version__.py CHANGED
@@ -5,4 +5,4 @@ from importlib.metadata import PackageNotFoundError, version
 try:
     __version__ = version("digitalkin")
 except PackageNotFoundError:
-    __version__ = "0.3.0-rc1"
+    __version__ = "0.3.1"
digitalkin/core/common/__init__.py ADDED
@@ -0,0 +1,9 @@
+"""Common utilities for the core module."""
+
+from digitalkin.core.common.factories import ConnectionFactory, ModuleFactory, QueueFactory
+
+__all__ = [
+    "ConnectionFactory",
+    "ModuleFactory",
+    "QueueFactory",
+]
digitalkin/core/common/factories.py ADDED
@@ -0,0 +1,156 @@
+"""Common factory functions for reducing code duplication in core module."""
+
+import asyncio
+import datetime
+
+from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
+from digitalkin.logger import logger
+from digitalkin.modules._base_module import BaseModule
+
+
+class ConnectionFactory:
+    """Factory for creating SurrealDB connections with consistent configuration."""
+
+    @staticmethod
+    async def create_surreal_connection(
+        database: str = "task_manager",
+        timeout: datetime.timedelta = datetime.timedelta(seconds=5),
+        *,
+        auto_init: bool = True,
+    ) -> SurrealDBConnection:
+        """Create and optionally initialize a SurrealDB connection.
+
+        This factory method centralizes the creation of SurrealDB connections
+        to ensure consistent configuration across the codebase.
+
+        Args:
+            database: Database name to connect to
+            timeout: Connection timeout
+            auto_init: Whether to automatically initialize the connection
+
+        Returns:
+            Initialized or uninitialized SurrealDBConnection instance
+
+        Example:
+            # Create and auto-initialize
+            conn = await ConnectionFactory.create_surreal_connection("taskiq_worker")
+
+            # Create without initialization
+            conn = await ConnectionFactory.create_surreal_connection(auto_init=False)
+            await conn.init_surreal_instance()
+        """
+        logger.debug(
+            "Creating SurrealDB connection for database: %s, timeout: %s",
+            database,
+            timeout,
+            extra={"database": database, "timeout": str(timeout)},
+        )
+
+        connection: SurrealDBConnection = SurrealDBConnection(database, timeout)
+
+        if auto_init:
+            await connection.init_surreal_instance()
+            logger.debug("SurrealDB connection initialized for database: %s", database)
+
+        return connection
+
+
+class ModuleFactory:
+    """Factory for creating module instances with consistent configuration."""
+
+    @staticmethod
+    def create_module_instance(
+        module_class: type[BaseModule],
+        job_id: str,
+        mission_id: str,
+        setup_id: str,
+        setup_version_id: str,
+    ) -> BaseModule:
+        """Create a module instance with standard parameters.
+
+        This factory method centralizes module instantiation to ensure
+        consistent parameter passing across the codebase.
+
+        Args:
+            module_class: The module class to instantiate
+            job_id: Unique job identifier
+            mission_id: Mission identifier
+            setup_id: Setup identifier
+            setup_version_id: Setup version identifier
+
+        Returns:
+            Instantiated module
+
+        Raises:
+            ValueError: If job_id or mission_id is empty
+
+        Example:
+            module = ModuleFactory.create_module_instance(
+                MyModule,
+                job_id="job_123",
+                mission_id="mission:test",
+                setup_id="setup:config",
+                setup_version_id="v1.0",
+            )
+        """
+        # Validate parameters
+        if not job_id:
+            msg = "job_id cannot be empty"
+            raise ValueError(msg)
+        if not mission_id:
+            msg = "mission_id cannot be empty"
+            raise ValueError(msg)
+
+        logger.debug(
+            "Creating module instance: %s for job: %s",
+            module_class.__name__,
+            job_id,
+            extra={
+                "module_class": module_class.__name__,
+                "job_id": job_id,
+                "mission_id": mission_id,
+                "setup_id": setup_id,
+                "setup_version_id": setup_version_id,
+            },
+        )
+
+        return module_class(
+            job_id=job_id,
+            mission_id=mission_id,
+            setup_id=setup_id,
+            setup_version_id=setup_version_id,
+        )
+
+
+class QueueFactory:
+    """Factory for creating asyncio queues with consistent configuration."""
+
+    # Default max queue size to prevent unbounded memory growth
+    DEFAULT_MAX_QUEUE_SIZE = 1000
+
+    @staticmethod
+    def create_bounded_queue(maxsize: int = DEFAULT_MAX_QUEUE_SIZE) -> asyncio.Queue:
+        """Create a bounded asyncio queue with standard configuration.
+
+        Args:
+            maxsize: Maximum queue size (default 1000, 0 means unlimited)
+
+        Returns:
+            Bounded asyncio.Queue instance
+
+        Raises:
+            ValueError: If maxsize is negative
+
+        Example:
+            queue = QueueFactory.create_bounded_queue()
+            # or with custom size
+            queue = QueueFactory.create_bounded_queue(maxsize=500)
+            # unlimited queue
+            queue = QueueFactory.create_bounded_queue(maxsize=0)
+        """
+        if maxsize < 0:
+            msg = "maxsize must be >= 0"
+            raise ValueError(msg)
+
+        logger.debug("Creating bounded queue with maxsize: %d", maxsize, extra={"maxsize": maxsize})
+        return asyncio.Queue(maxsize=maxsize)
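Taken together, these three factories centralize construction patterns that the job managers below now call into. A minimal usage sketch, assuming only the APIs shown in this diff (MyModule is a hypothetical BaseModule subclass, as in the docstring examples above):

    import asyncio

    from digitalkin.core.common import ConnectionFactory, ModuleFactory, QueueFactory

    async def main() -> None:
        # Bounded queue: guards against unbounded memory growth (default maxsize 1000).
        queue = QueueFactory.create_bounded_queue(maxsize=500)

        # Connection: initialized eagerly because auto_init defaults to True.
        conn = await ConnectionFactory.create_surreal_connection("task_manager")

        # Module: raises ValueError when job_id or mission_id is empty.
        module = ModuleFactory.create_module_instance(
            MyModule,  # hypothetical BaseModule subclass
            job_id="job_123",
            mission_id="mission:test",
            setup_id="setup:config",
            setup_version_id="v1.0",
        )

    asyncio.run(main())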
digitalkin/core/job_manager/base_job_manager.py CHANGED
@@ -5,7 +5,8 @@ from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine
 from contextlib import asynccontextmanager
 from typing import Any, Generic
 
-from digitalkin.core.task_manager.task_manager import TaskManager
+from digitalkin.core.task_manager.base_task_manager import BaseTaskManager
+from digitalkin.core.task_manager.task_session import TaskSession
 from digitalkin.models.core.task_monitor import TaskStatus
 from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
 from digitalkin.models.module.module import ModuleCodeModel
@@ -14,9 +15,115 @@ from digitalkin.services.services_config import ServicesConfig
 from digitalkin.services.services_models import ServicesMode
 
 
-class BaseJobManager(abc.ABC, TaskManager, Generic[InputModelT, SetupModelT, OutputModelT]):
-    """Abstract base class for managing background module jobs."""
+class BaseJobManager(abc.ABC, Generic[InputModelT, OutputModelT, SetupModelT]):
+    """Abstract base class for managing background module jobs.
 
+    Uses composition to delegate task lifecycle management to a TaskManager.
+    """
+
+    module_class: type[BaseModule]
+    services_mode: ServicesMode
+    _task_manager: BaseTaskManager
+
+    def __init__(
+        self,
+        module_class: type[BaseModule],
+        services_mode: ServicesMode,
+        task_manager: BaseTaskManager,
+    ) -> None:
+        """Initialize the job manager.
+
+        Args:
+            module_class: The class of the module to be managed.
+            services_mode: The mode of operation for the services (e.g., ASYNC or SYNC).
+            task_manager: The task manager instance to use for task lifecycle management.
+        """
+        self.module_class = module_class
+        self.services_mode = services_mode
+        self._task_manager = task_manager
+
+        services_config = ServicesConfig(
+            services_config_strategies=self.module_class.services_config_strategies,
+            services_config_params=self.module_class.services_config_params,
+            mode=services_mode,
+        )
+        setattr(self.module_class, "services_config", services_config)
+
+    # Properties to expose task manager attributes
+    @property
+    def tasks_sessions(self) -> dict[str, TaskSession]:
+        """Get task sessions from the task manager."""
+        return self._task_manager.tasks_sessions
+
+    @property
+    def tasks(self) -> dict[str, Any]:
+        """Get tasks from the task manager."""
+        return self._task_manager.tasks
+
+    # Delegate task lifecycle methods to task manager
+    async def create_task(
+        self,
+        task_id: str,
+        mission_id: str,
+        module: BaseModule,
+        coro: Coroutine[Any, Any, None],
+        **kwargs: Any,  # noqa: ANN401
+    ) -> None:
+        """Create a task using the task manager.
+
+        Args:
+            task_id: Unique identifier for the task
+            mission_id: Mission identifier
+            module: Module instance
+            coro: Coroutine to execute
+            **kwargs: Additional arguments for task creation
+        """
+        await self._task_manager.create_task(task_id, mission_id, module, coro, **kwargs)
+
+    async def clean_session(self, task_id: str, mission_id: str) -> bool:
+        """Clean a task's session.
+
+        Args:
+            task_id: Unique identifier for the task.
+            mission_id: Mission identifier.
+
+        Returns:
+            bool: True if the task was successfully cancelled, False otherwise.
+        """
+        return await self._task_manager.clean_session(task_id, mission_id)
+
+    async def cancel_task(self, task_id: str, mission_id: str, timeout: float | None = None) -> bool:
+        """Cancel a task.
+
+        Args:
+            task_id: Unique identifier for the task.
+            mission_id: Mission identifier.
+            timeout: Optional timeout in seconds to wait for the cancellation to complete.
+
+        Returns:
+            bool: True if the task was successfully cancelled, False otherwise.
+        """
+        return await self._task_manager.cancel_task(task_id, mission_id, timeout)
+
+    async def send_signal(self, task_id: str, mission_id: str, signal_type: str, payload: dict) -> bool:
+        """Send signal to a task.
+
+        Args:
+            task_id: Unique identifier for the task.
+            mission_id: Mission identifier.
+            signal_type: Type of signal to send.
+            payload: Payload data for the signal.
+
+        Returns:
+            bool: True if the signal was successfully sent, False otherwise.
+        """
+        return await self._task_manager.send_signal(task_id, mission_id, signal_type, payload)
+
+    async def shutdown(self, mission_id: str, timeout: float = 30.0) -> None:
+        """Shutdown all tasks."""
+        await self._task_manager.shutdown(mission_id, timeout)
+
+    @abc.abstractmethod
     async def start(self) -> None:
         """Start the job manager.
 
@@ -52,29 +159,6 @@ class BaseJobManager(abc.ABC, TaskManager, Generic[InputModelT, SetupModelT, OutputModelT]):
 
         return callback_wrapper
 
-    def __init__(
-        self,
-        module_class: type[BaseModule],
-        services_mode: ServicesMode,
-        **kwargs,  # noqa: ANN003
-    ) -> None:
-        """Initialize the job manager.
-
-        Args:
-            module_class: The class of the module to be managed.
-            services_mode: The mode of operation for the services (e.g., ASYNC or SYNC).
-            **kwargs: Additional keyword arguments for the job manager.
-        """
-        self.module_class = module_class
-
-        services_config = ServicesConfig(
-            services_config_strategies=self.module_class.services_config_strategies,
-            services_config_params=self.module_class.services_config_params,
-            mode=services_mode,
-        )
-        setattr(self.module_class, "services_config", services_config)
-        super().__init__(**kwargs)
-
     @abc.abstractmethod  # type: ignore
     @asynccontextmanager  # type: ignore
     async def generate_stream_consumer(self, job_id: str) -> AsyncIterator[AsyncGenerator[dict[str, Any], None]]:
@@ -110,7 +194,7 @@ class BaseJobManager(abc.ABC, TaskManager, Generic[InputModelT, SetupModelT, OutputModelT]):
         """
 
     @abc.abstractmethod
-    async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT:
+    async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT | ModuleCodeModel:
        """Generate a stream consumer for a module's output data.
 
        This method creates an asynchronous generator that streams output data
@@ -121,7 +205,7 @@ class BaseJobManager(abc.ABC, TaskManager, Generic[InputModelT, SetupModelT, OutputModelT]):
            job_id: The unique identifier of the job.
 
        Returns:
-           SetupModelT: the SetupModelT object fully processed.
+           SetupModelT | ModuleCodeModel: the SetupModelT object fully processed, or an error code.
        """
 
    @abc.abstractmethod
@@ -172,6 +256,22 @@ class BaseJobManager(abc.ABC, TaskManager, Generic[InputModelT, SetupModelT, OutputModelT]):
            ModuleStatu: The status of the job.
        """
 
+    @abc.abstractmethod
+    async def wait_for_completion(self, job_id: str) -> None:
+        """Wait for a task to complete.
+
+        This method blocks until the specified job has reached a terminal state.
+        The implementation varies by job manager type:
+        - SingleJobManager: Awaits the asyncio.Task directly
+        - TaskiqJobManager: Polls task status from SurrealDB
+
+        Args:
+            job_id: The unique identifier of the job to wait for.
+
+        Raises:
+            KeyError: If the job_id is not found.
+        """
+
    @abc.abstractmethod
    async def stop_all_modules(self) -> None:
        """Stop all currently running module jobs.
digitalkin/core/job_manager/single_job_manager.py CHANGED
@@ -5,12 +5,13 @@ import datetime
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from contextlib import asynccontextmanager
-from typing import Any, Generic
+from typing import Any
 
 import grpc
 
+from digitalkin.core.common import ConnectionFactory, ModuleFactory
 from digitalkin.core.job_manager.base_job_manager import BaseJobManager
-from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
+from digitalkin.core.task_manager.local_task_manager import LocalTaskManager
 from digitalkin.core.task_manager.task_session import TaskSession
 from digitalkin.logger import logger
 from digitalkin.models.core.task_monitor import TaskStatus
@@ -20,7 +21,7 @@ from digitalkin.modules._base_module import BaseModule
 from digitalkin.services.services_models import ServicesMode
 
 
-class SingleJobManager(BaseJobManager, Generic[InputModelT, OutputModelT, SetupModelT]):
+class SingleJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
     """Manages a single instance of a module job.
 
     This class ensures that only one instance of a module job is active at a time.
@@ -30,21 +31,29 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, OutputModelT, SetupModelT]):
 
     async def start(self) -> None:
         """Start manager."""
-        self.channel: SurrealDBConnection = SurrealDBConnection("task_manager", datetime.timedelta(seconds=5))
-        await self.channel.init_surreal_instance()
+        self.channel = await ConnectionFactory.create_surreal_connection("task_manager", datetime.timedelta(seconds=5))
 
     def __init__(
         self,
         module_class: type[BaseModule],
        services_mode: ServicesMode,
+        default_timeout: float = 10.0,
+        max_concurrent_tasks: int = 100,
     ) -> None:
        """Initialize the job manager.
 
        Args:
            module_class: The class of the module to be managed.
            services_mode: The mode of operation for the services (e.g., ASYNC or SYNC).
+            default_timeout: Default timeout for task operations
+            max_concurrent_tasks: Maximum number of concurrent tasks
        """
-        super().__init__(module_class, services_mode)
+        # Create local task manager for same-process execution
+        task_manager = LocalTaskManager(default_timeout, max_concurrent_tasks)
+
+        # Initialize base job manager with task manager
+        super().__init__(module_class, services_mode, task_manager)
+
        self._lock = asyncio.Lock()
 
    async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT | ModuleCodeModel:
@@ -68,7 +77,14 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, OutputModelT, SetupModelT]):
 
        logger.debug("Module %s found: %s", job_id, session.module)
        try:
-            return await session.queue.get()
+            # Add timeout to prevent indefinite blocking
+            return await asyncio.wait_for(session.queue.get(), timeout=30.0)
+        except asyncio.TimeoutError:
+            logger.error("Timeout waiting for config setup response from module %s", job_id)
+            return ModuleCodeModel(
+                code=str(grpc.StatusCode.DEADLINE_EXCEEDED),
+                message=f"Module {job_id} did not respond within 30 seconds",
+            )
        finally:
            logger.info(f"{job_id=}: {session.queue.empty()}")
 
@@ -98,7 +114,7 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, OutputModelT, SetupModelT]):
        """
        job_id = str(uuid.uuid4())
        # TODO: Ensure the job_id is unique.
-        module = self.module_class(job_id, mission_id=mission_id, setup_id=setup_id, setup_version_id=setup_version_id)
+        module = ModuleFactory.create_module_instance(self.module_class, job_id, mission_id, setup_id, setup_version_id)
        self.tasks_sessions[job_id] = TaskSession(job_id, mission_id, self.channel, module)
 
        try:
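A note on the timeout guard added above in generate_config_setup_module_response: the underlying pattern is plain stdlib asyncio, not anything digitalkin-specific. A minimal sketch (get_with_timeout is an illustrative helper, not part of the package):

    import asyncio
    from typing import Any

    async def get_with_timeout(queue: asyncio.Queue, timeout: float = 30.0) -> Any | None:
        """Return the next queued item, or None if nothing arrives in time."""
        try:
            return await asyncio.wait_for(queue.get(), timeout=timeout)
        except asyncio.TimeoutError:
            # wait_for cancels the pending get() rather than leaving it running;
            # the caller maps None to an error response (DEADLINE_EXCEEDED above).
            return None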
@@ -161,26 +177,45 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, OutputModelT, SetupModelT]):
        logger.debug("Session: %s with Module %s", job_id, session.module)
 
        async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
-            """Stream output data from the module.
+            """Stream output data from the module with simple blocking pattern.
+
+            This implementation uses a simple one-item-at-a-time pattern optimized
+            for local execution where we have direct access to session status:
+            1. Block waiting for each item
+            2. Check termination conditions after each item
+            3. Clean shutdown when task completes
+
+            This pattern provides:
+            - Immediate termination when task completes
+            - Direct session status monitoring
+            - Simple, predictable behavior for local tasks
 
            Yields:
                dict: Output data generated by the module.
            """
            while True:
-                # if queue is empty but producer not finished yet, block on get()
+                # Block for next item - if queue is empty but producer not finished yet
                msg = await session.queue.get()
                try:
                    yield msg
                finally:
+                    # Always mark task as done, even if consumer raises exception
                    session.queue.task_done()
 
-                # If the producer marked finished and no more items, break soon:
+                # Check termination conditions after each message
+                # This allows immediate shutdown when the task completes
                if (
                    session.is_cancelled.is_set()
                    or (session.status is TaskStatus.COMPLETED and session.queue.empty())
                    or session.status is TaskStatus.FAILED
                ):
-                    # and session.queue.empty():
+                    logger.debug(
+                        "Stream ending for job %s: cancelled=%s, status=%s, queue_empty=%s",
+                        job_id,
+                        session.is_cancelled.is_set(),
+                        session.status,
+                        session.queue.empty(),
+                    )
                    break
 
        yield _stream()
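The rewritten _stream docstring describes the consumer loop precisely. Reduced to its essentials, the pattern looks like the sketch below, where drain and the done event are hypothetical stand-ins for TaskSession's queue and cancelled/COMPLETED/FAILED checks:

    import asyncio
    from collections.abc import AsyncGenerator

    async def drain(queue: asyncio.Queue, done: asyncio.Event) -> AsyncGenerator[dict, None]:
        """Yield items one at a time until the producer is done and the queue is empty."""
        while True:
            msg = await queue.get()
            try:
                yield msg
            finally:
                queue.task_done()  # runs even if the consumer raises
            if done.is_set() and queue.empty():
                break

One caveat the sketch shares with the real loop: termination is only checked after each received item, so if the producer finishes without enqueuing a final message, the consumer stays blocked on get(). That trade-off is what keeps the local pattern simple and predictable.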
@@ -212,12 +247,7 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, OutputModelT, SetupModelT]):
            Exception: If the module fails to start.
        """
        job_id = str(uuid.uuid4())
-        module = self.module_class(
-            job_id,
-            mission_id=mission_id,
-            setup_id=setup_id,
-            setup_version_id=setup_version_id,
-        )
+        module = ModuleFactory.create_module_instance(self.module_class, job_id, mission_id, setup_id, setup_version_id)
        callback = await self.job_specific_callback(self.add_to_queue, job_id)
 
        await self.create_task(
@@ -251,9 +281,7 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, OutputModelT, SetupModelT]):
            return False
        try:
            await session.module.stop()
-
-            if job_id in self.tasks:
-                await self.cancel_task(job_id, session.mission_id)
+            await self.cancel_task(job_id, session.mission_id)
            logger.debug(f"session {job_id} ({session.module.name}) stopped successfully")
        except Exception as e:
            logger.error(f"Error while stopping module {job_id}: {e}")
@@ -273,15 +301,42 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, OutputModelT, SetupModelT]):
        session = self.tasks_sessions.get(job_id, None)
        return session.status if session is not None else TaskStatus.FAILED
 
+    async def wait_for_completion(self, job_id: str) -> None:
+        """Wait for a task to complete by awaiting its asyncio.Task.
+
+        Args:
+            job_id: The unique identifier of the job to wait for.
+
+        Raises:
+            KeyError: If the job_id is not found in tasks.
+        """
+        if job_id not in self._task_manager.tasks:
+            msg = f"Job {job_id} not found"
+            raise KeyError(msg)
+        await self._task_manager.tasks[job_id]
+
    async def stop_all_modules(self) -> None:
        """Stop all currently running module jobs.
 
-        This method ensures that all active jobs are gracefully terminated.
+        This method ensures that all active jobs are gracefully terminated
+        and closes the SurrealDB connection.
        """
+        # Snapshot job IDs while holding lock
        async with self._lock:
-            stop_tasks = [self.stop_module(job_id) for job_id in list(self.tasks_sessions.keys())]
-            if stop_tasks:
-                await asyncio.gather(*stop_tasks, return_exceptions=True)
+            job_ids = list(self.tasks_sessions.keys())
+
+        # Release lock before calling stop_module (which has its own lock)
+        if job_ids:
+            stop_tasks = [self.stop_module(job_id) for job_id in job_ids]
+            await asyncio.gather(*stop_tasks, return_exceptions=True)
+
+        # Close SurrealDB connection after stopping all modules
+        if hasattr(self, "channel"):
+            try:
+                await self.channel.close()
+                logger.info("SingleJobManager: SurrealDB connection closed")
+            except Exception as e:
+                logger.warning("Failed to close SurrealDB connection: %s", e)
 
    async def list_modules(self) -> dict[str, dict[str, Any]]:
        """List all modules along with their statuses.