digitalkin 0.2.26__py3-none-any.whl → 0.3.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/grpc_servers/module_server.py +27 -44
  3. digitalkin/grpc_servers/module_servicer.py +27 -22
  4. digitalkin/grpc_servers/utils/models.py +1 -1
  5. digitalkin/logger.py +1 -9
  6. digitalkin/mixins/__init__.py +19 -0
  7. digitalkin/mixins/base_mixin.py +10 -0
  8. digitalkin/mixins/callback_mixin.py +24 -0
  9. digitalkin/mixins/chat_history_mixin.py +108 -0
  10. digitalkin/mixins/cost_mixin.py +76 -0
  11. digitalkin/mixins/file_history_mixin.py +99 -0
  12. digitalkin/mixins/filesystem_mixin.py +47 -0
  13. digitalkin/mixins/logger_mixin.py +59 -0
  14. digitalkin/mixins/storage_mixin.py +79 -0
  15. digitalkin/models/module/__init__.py +2 -0
  16. digitalkin/models/module/module.py +9 -1
  17. digitalkin/models/module/module_context.py +90 -6
  18. digitalkin/models/module/module_types.py +6 -6
  19. digitalkin/models/module/task_monitor.py +51 -0
  20. digitalkin/models/services/__init__.py +9 -0
  21. digitalkin/models/services/storage.py +39 -5
  22. digitalkin/modules/_base_module.py +47 -68
  23. digitalkin/modules/job_manager/base_job_manager.py +12 -8
  24. digitalkin/modules/job_manager/single_job_manager.py +84 -78
  25. digitalkin/modules/job_manager/surrealdb_repository.py +228 -0
  26. digitalkin/modules/job_manager/task_manager.py +389 -0
  27. digitalkin/modules/job_manager/task_session.py +275 -0
  28. digitalkin/modules/job_manager/taskiq_job_manager.py +2 -2
  29. digitalkin/modules/tool_module.py +10 -2
  30. digitalkin/modules/trigger_handler.py +7 -6
  31. digitalkin/services/cost/__init__.py +9 -2
  32. digitalkin/services/filesystem/default_filesystem.py +0 -2
  33. digitalkin/services/storage/grpc_storage.py +1 -1
  34. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dev1.dist-info}/METADATA +20 -19
  35. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dev1.dist-info}/RECORD +38 -25
  36. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dev1.dist-info}/WHEEL +0 -0
  37. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dev1.dist-info}/licenses/LICENSE +0 -0
  38. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dev1.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,11 @@
1
1
  """BaseModule is the abstract base for all modules in the DigitalKin SDK."""
2
2
 
3
3
  import asyncio
4
- import contextlib
5
4
  import json
6
5
  from abc import ABC, abstractmethod
7
6
  from collections.abc import Callable, Coroutine
8
7
  from typing import Any, ClassVar, Generic
9
8
 
10
- from pydantic import BaseModel
11
-
12
9
  from digitalkin.logger import logger
13
10
  from digitalkin.models.module import (
14
11
  InputModelT,
@@ -17,28 +14,14 @@ from digitalkin.models.module import (
17
14
  SecretModelT,
18
15
  SetupModelT,
19
16
  )
17
+ from digitalkin.models.module.module import ModuleCodeModel
20
18
  from digitalkin.models.module.module_context import ModuleContext
21
19
  from digitalkin.modules.trigger_handler import TriggerHandler
22
- from digitalkin.services.agent.agent_strategy import AgentStrategy
23
- from digitalkin.services.cost.cost_strategy import CostStrategy
24
- from digitalkin.services.filesystem.filesystem_strategy import FilesystemStrategy
25
- from digitalkin.services.identity.identity_strategy import IdentityStrategy
26
- from digitalkin.services.registry.registry_strategy import RegistryStrategy
27
20
  from digitalkin.services.services_config import ServicesConfig, ServicesStrategy
28
- from digitalkin.services.snapshot.snapshot_strategy import SnapshotStrategy
29
- from digitalkin.services.storage.storage_strategy import StorageStrategy
30
21
  from digitalkin.utils.llm_ready_schema import llm_ready_schema
31
22
  from digitalkin.utils.package_discover import ModuleDiscoverer
32
23
 
33
24
 
34
- class ModuleCodeModel(BaseModel):
35
- """typed error/code model."""
36
-
37
- code: str
38
- message: str
39
- short_description: str
40
-
41
-
42
25
  class BaseModule( # noqa: PLR0904
43
26
  ABC,
44
27
  Generic[
@@ -67,33 +50,35 @@ class BaseModule( # noqa: PLR0904
67
50
  services_config_params: ClassVar[dict[str, dict[str, Any | None] | None]]
68
51
  services_config: ServicesConfig
69
52
 
70
- # services list
71
- agent: AgentStrategy
72
- cost: CostStrategy
73
- filesystem: FilesystemStrategy
74
- identity: IdentityStrategy
75
- registry: RegistryStrategy
76
- snapshot: SnapshotStrategy
77
- storage: StorageStrategy
78
-
79
53
  # runtime params
80
54
  job_id: str
81
55
  mission_id: str
82
56
  setup_id: str
83
57
  setup_version_id: str
84
- _status: ModuleStatus
85
- _task: asyncio.Task | None
86
58
 
87
- def _init_strategies(self) -> None:
88
- """Initialize the services configuration."""
89
- for service_name in self.services_config.valid_strategy_names():
90
- service = self.services_config.init_strategy(
59
+ def _init_strategies(self) -> dict[str, Any]:
60
+ """Initialize the services configuration.
61
+
62
+ Returns:
63
+ dict of services with name: Strategy
64
+ agent: AgentStrategy
65
+ cost: CostStrategy
66
+ filesystem: FilesystemStrategy
67
+ identity: IdentityStrategy
68
+ registry: RegistryStrategy
69
+ snapshot: SnapshotStrategy
70
+ storage: StorageStrategy
71
+ """
72
+ logger.debug("Service initialisation: %s", self.services_config_strategies.keys())
73
+ return {
74
+ service_name: self.services_config.init_strategy(
91
75
  service_name,
92
76
  self.mission_id,
93
77
  self.setup_id,
94
78
  self.setup_version_id,
95
79
  )
96
- setattr(self, service_name, service)
80
+ for service_name in self.services_config.valid_strategy_names()
81
+ }
97
82
 
98
83
  def __init__(
99
84
  self,
@@ -110,15 +95,16 @@ class BaseModule( # noqa: PLR0904
110
95
  # SetupVersion reference needed for the precise Kin scope as the cost
111
96
  self.setup_version_id: str = setup_version_id
112
97
  self._status = ModuleStatus.CREATED
113
- self._task: asyncio.Task | None = None
114
- # Initialize services configuration
115
- self._init_strategies()
116
98
 
117
99
  # Initialize minimum context
118
100
  self.context = ModuleContext(
119
- storage=self.storage,
120
- cost=self.cost,
121
- filesystem=self.filesystem,
101
+ # Initialize services configuration
102
+ **self._init_strategies(),
103
+ session={
104
+ "mission_id": mission_id,
105
+ "setup_version_id": setup_version_id,
106
+ "job_id": job_id,
107
+ },
122
108
  )
123
109
 
124
110
  @property
@@ -307,7 +293,11 @@ class BaseModule( # noqa: PLR0904
307
293
  """
308
294
  return cls.triggers_discoverer.register_trigger(handler_cls)
309
295
 
310
- async def run_config_setup(self, config_setup_data: SetupModelT) -> SetupModelT: # noqa: PLR6301
296
+ async def run_config_setup( # noqa: PLR6301
297
+ self,
298
+ context: ModuleContext, # noqa: ARG002
299
+ config_setup_data: SetupModelT,
300
+ ) -> SetupModelT:
311
301
  """Run config setup the module.
312
302
 
313
303
  The config setup is used to initialize the setup with configuration data.
@@ -323,7 +313,7 @@ class BaseModule( # noqa: PLR0904
323
313
  return config_setup_data
324
314
 
325
315
  @abstractmethod
326
- async def initialize(self, setup_data: SetupModelT) -> None:
316
+ async def initialize(self, context: ModuleContext, setup_data: SetupModelT) -> None:
327
317
  """Initialize the module."""
328
318
  raise NotImplementedError
329
319
 
@@ -331,7 +321,6 @@ class BaseModule( # noqa: PLR0904
331
321
  self,
332
322
  input_data: InputModelT,
333
323
  setup_data: SetupModelT,
334
- callback: Callable[[OutputModelT], Coroutine[Any, Any, None]],
335
324
  ) -> None:
336
325
  """Run the module with the given input and setup data.
337
326
 
@@ -360,7 +349,6 @@ class BaseModule( # noqa: PLR0904
360
349
  await handler_instance.handle(
361
350
  input_instance.root,
362
351
  setup_data,
363
- callback,
364
352
  self.context,
365
353
  )
366
354
 
@@ -373,7 +361,6 @@ class BaseModule( # noqa: PLR0904
373
361
  self,
374
362
  input_data: InputModelT,
375
363
  setup_data: SetupModelT,
376
- callback: Callable[[OutputModelT], Coroutine[Any, Any, None]],
377
364
  ) -> None:
378
365
  """Run the module lifecycle.
379
366
 
@@ -391,7 +378,7 @@ class BaseModule( # noqa: PLR0904
391
378
  "job_id": self.job_id,
392
379
  },
393
380
  )
394
- await self.run(input_data, setup_data, callback)
381
+ await self.run(input_data, setup_data)
395
382
  logger.info(
396
383
  "Module %s finished",
397
384
  self.name,
@@ -428,8 +415,6 @@ class BaseModule( # noqa: PLR0904
428
415
  )
429
416
  else:
430
417
  self._status = ModuleStatus.STOPPING
431
- finally:
432
- await self.stop()
433
418
 
434
419
  async def start(
435
420
  self,
@@ -440,15 +425,17 @@ class BaseModule( # noqa: PLR0904
440
425
  ) -> None:
441
426
  """Start the module."""
442
427
  try:
443
- logger.debug("Inititalize module")
444
- await self.initialize(setup_data=setup_data)
428
+ self.context.callbacks.logger = logger
429
+ self.context.callbacks.send_message = callback
430
+ logger.info(f"Inititalize module {self.job_id}")
431
+ await self.initialize(self.context, setup_data)
445
432
  except Exception as e:
446
433
  self._status = ModuleStatus.FAILED
447
434
  short_description = "Error initializing module"
448
435
  logger.exception("%s: %s", short_description, e)
449
436
  await callback(
450
437
  ModuleCodeModel(
451
- code=str(self._status),
438
+ code="Error",
452
439
  short_description=short_description,
453
440
  message=str(e),
454
441
  )
@@ -461,32 +448,22 @@ class BaseModule( # noqa: PLR0904
461
448
  try:
462
449
  logger.debug("Init the discovered input handlers.")
463
450
  self.triggers_discoverer.init_handlers(self.context)
464
- logger.debug("Run lifecycle")
465
- self._status = ModuleStatus.RUNNING
466
- self._task = asyncio.create_task(
467
- self._run_lifecycle(input_data, setup_data, callback),
468
- name="module_lifecycle",
469
- )
470
- if done_callback is not None:
471
- self._task.add_done_callback(done_callback)
451
+ logger.debug(f"Run lifecycle {self.job_id}")
452
+ await self._run_lifecycle(input_data, setup_data)
472
453
  except Exception:
473
454
  self._status = ModuleStatus.FAILED
474
455
  logger.exception("Error during module lifecyle")
456
+ finally:
457
+ await self.stop()
475
458
 
476
459
  async def stop(self) -> None:
477
460
  """Stop the module."""
478
- logger.info("Stopping module %s with status %s", self.name, self._status)
479
- if self._status not in {ModuleStatus.RUNNING, ModuleStatus.STOPPING}:
480
- return
481
-
461
+ logger.info("Stopping module %s | job_id=%s", self.name, self.job_id)
482
462
  try:
483
463
  self._status = ModuleStatus.STOPPING
484
- if self._task and not self._task.done():
485
- self._task.cancel()
486
- with contextlib.suppress(asyncio.CancelledError):
487
- await self._task
488
464
  logger.debug("Module %s stopped", self.name)
489
465
  await self.cleanup()
466
+ await self.context.callbacks.send_message(ModuleCodeModel(code="__END_OF_STREAM__"))
490
467
  self._status = ModuleStatus.STOPPED
491
468
  logger.debug("Module %s cleaned", self.name)
492
469
  except Exception:
@@ -510,13 +487,15 @@ class BaseModule( # noqa: PLR0904
510
487
  },
511
488
  )
512
489
  self._status = ModuleStatus.RUNNING
513
- content = await self.run_config_setup(config_setup_data)
490
+ self.context.callbacks.set_config_setup = callback
491
+ content = await self.run_config_setup(self.context, config_setup_data)
514
492
 
515
493
  wrapper = config_setup_data.model_dump()
516
494
  wrapper["content"] = content.model_dump()
517
495
  await callback(self.create_setup_model(wrapper))
518
496
  self._status = ModuleStatus.STOPPING
519
497
  except Exception:
498
+ logger.error("Error during module lifecyle")
520
499
  self._status = ModuleStatus.FAILED
521
500
  logger.exception(
522
501
  "Error during module lifecyle",
@@ -5,17 +5,18 @@ from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine
5
5
  from contextlib import asynccontextmanager
6
6
  from typing import Any, Generic
7
7
 
8
- from digitalkin.models import ModuleStatus
9
8
  from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
9
+ from digitalkin.models.module.module import ModuleCodeModel
10
+ from digitalkin.models.module.task_monitor import TaskStatus
10
11
  from digitalkin.modules._base_module import BaseModule
11
12
  from digitalkin.services.services_config import ServicesConfig
12
13
  from digitalkin.services.services_models import ServicesMode
13
14
 
14
15
 
15
- class BaseJobManager(abc.ABC, Generic[InputModelT, SetupModelT]):
16
+ class BaseJobManager(abc.ABC, Generic[InputModelT, SetupModelT, OutputModelT]):
16
17
  """Abstract base class for managing background module jobs."""
17
18
 
18
- async def _start(self) -> None:
19
+ async def start(self) -> None:
19
20
  """Start the job manager.
20
21
 
21
22
  This method initializes any necessary resources or configurations
@@ -24,8 +25,8 @@ class BaseJobManager(abc.ABC, Generic[InputModelT, SetupModelT]):
24
25
 
25
26
  @staticmethod
26
27
  async def job_specific_callback(
27
- callback: Callable[[str, OutputModelT], Coroutine[Any, Any, None]], job_id: str
28
- ) -> Callable[[OutputModelT], Coroutine[Any, Any, None]]:
28
+ callback: Callable[[str, OutputModelT | ModuleCodeModel], Coroutine[Any, Any, None]], job_id: str
29
+ ) -> Callable[[OutputModelT | ModuleCodeModel], Coroutine[Any, Any, None]]:
29
30
  """Generate a job-specific callback function.
30
31
 
31
32
  Args:
@@ -36,7 +37,7 @@ class BaseJobManager(abc.ABC, Generic[InputModelT, SetupModelT]):
36
37
  Callable: A wrapped callback function that includes the job ID.
37
38
  """
38
39
 
39
- def callback_wrapper(output_data: OutputModelT) -> Coroutine[Any, Any, None]:
40
+ def callback_wrapper(output_data: OutputModelT | ModuleCodeModel) -> Coroutine[Any, Any, None]:
40
41
  """Wrapper for the callback function.
41
42
 
42
43
  Args:
@@ -53,12 +54,14 @@ class BaseJobManager(abc.ABC, Generic[InputModelT, SetupModelT]):
53
54
  self,
54
55
  module_class: type[BaseModule],
55
56
  services_mode: ServicesMode,
57
+ **kwargs, # noqa: ANN003
56
58
  ) -> None:
57
59
  """Initialize the job manager.
58
60
 
59
61
  Args:
60
62
  module_class: The class of the module to be managed.
61
63
  services_mode: The mode of operation for the services (e.g., ASYNC or SYNC).
64
+ **kwargs: Additional keyword arguments for the job manager.
62
65
  """
63
66
  self.module_class = module_class
64
67
 
@@ -68,6 +71,7 @@ class BaseJobManager(abc.ABC, Generic[InputModelT, SetupModelT]):
68
71
  mode=services_mode,
69
72
  )
70
73
  setattr(self.module_class, "services_config", services_config)
74
+ super().__init__(**kwargs)
71
75
 
72
76
  @abc.abstractmethod # type: ignore
73
77
  @asynccontextmanager # type: ignore
@@ -156,14 +160,14 @@ class BaseJobManager(abc.ABC, Generic[InputModelT, SetupModelT]):
156
160
  """
157
161
 
158
162
  @abc.abstractmethod
159
- async def get_module_status(self, job_id: str) -> ModuleStatus | None:
163
+ async def get_module_status(self, job_id: str) -> TaskStatus:
160
164
  """Retrieve the status of a module job.
161
165
 
162
166
  Args:
163
167
  job_id: The unique identifier of the job.
164
168
 
165
169
  Returns:
166
- ModuleStatus | None: The status of the job, or None if the job does not exist.
170
+ ModuleStatu: The status of the job.
167
171
  """
168
172
 
169
173
  @abc.abstractmethod
@@ -1,6 +1,7 @@
1
1
  """Background module manager with single instance."""
2
2
 
3
3
  import asyncio
4
+ import datetime
4
5
  import uuid
5
6
  from collections.abc import AsyncGenerator, AsyncIterator
6
7
  from contextlib import asynccontextmanager
@@ -9,15 +10,18 @@ from typing import Any, Generic
9
10
  import grpc
10
11
 
11
12
  from digitalkin.logger import logger
12
- from digitalkin.models import ModuleStatus
13
13
  from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
14
+ from digitalkin.models.module.module import ModuleCodeModel
15
+ from digitalkin.models.module.task_monitor import TaskStatus
14
16
  from digitalkin.modules._base_module import BaseModule
15
17
  from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
16
- from digitalkin.modules.job_manager.job_manager_models import StreamCodeModel
18
+ from digitalkin.modules.job_manager.surrealdb_repository import SurrealDBConnection
19
+ from digitalkin.modules.job_manager.task_manager import TaskManager
20
+ from digitalkin.modules.job_manager.task_session import TaskSession
17
21
  from digitalkin.services.services_models import ServicesMode
18
22
 
19
23
 
20
- class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
24
+ class SingleJobManager(BaseJobManager, TaskManager, Generic[InputModelT, OutputModelT, SetupModelT]):
21
25
  """Manages a single instance of a module job.
22
26
 
23
27
  This class ensures that only one instance of a module job is active at a time.
@@ -25,8 +29,10 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
25
29
  to handle their output data.
26
30
  """
27
31
 
28
- modules: dict[str, BaseModule]
29
- queue: dict[str, asyncio.Queue]
32
+ async def start(self) -> None:
33
+ """Start manager."""
34
+ self.channel = SurrealDBConnection("task_manager", datetime.timedelta(seconds=5))
35
+ await self.channel.init_surreal_instance()
30
36
 
31
37
  def __init__(
32
38
  self,
@@ -40,12 +46,9 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
40
46
  services_mode: The mode of operation for the services (e.g., ASYNC or SYNC).
41
47
  """
42
48
  super().__init__(module_class, services_mode)
43
-
44
49
  self._lock = asyncio.Lock()
45
- self.modules: dict[str, BaseModule] = {}
46
- self.queues: dict[str, asyncio.Queue] = {}
47
50
 
48
- async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT:
51
+ async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT | ModuleCodeModel:
49
52
  """Generate a stream consumer for a module's output data.
50
53
 
51
54
  This method creates an asynchronous generator that streams output data
@@ -56,16 +59,19 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
56
59
  job_id: The unique identifier of the job.
57
60
 
58
61
  Returns:
59
- SetupModelT: the SetupModelT object fully processed.
62
+ SetupModelT | ModuleCodeModel: the SetupModelT object fully processed.
60
63
  """
61
- module = self.modules.get(job_id, None)
62
- logger.debug("Module %s found: %s", job_id, module)
64
+ if (session := self.tasks_sessions.get(job_id, None)) is None:
65
+ return ModuleCodeModel(
66
+ code=str(grpc.StatusCode.NOT_FOUND),
67
+ message=f"Module {job_id} not found",
68
+ )
63
69
 
70
+ logger.debug("Module %s found: %s", job_id, session.module)
64
71
  try:
65
- return await self.queues[job_id].get()
72
+ return await session.queue.get()
66
73
  finally:
67
- logger.info(f"{job_id=}: {self.queues[job_id].empty()}")
68
- del self.queues[job_id]
74
+ logger.info(f"{job_id=}: {session.queue.empty()}")
69
75
 
70
76
  async def create_config_setup_instance_job(
71
77
  self,
@@ -95,8 +101,7 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
95
101
  job_id = str(uuid.uuid4())
96
102
  # TODO: Ensure the job_id is unique.
97
103
  module = self.module_class(job_id, mission_id=mission_id, setup_id=setup_id, setup_version_id=setup_version_id)
98
- self.modules[job_id] = module
99
- self.queues[job_id] = asyncio.Queue()
104
+ self.tasks_sessions[job_id] = TaskSession(job_id, self.channel, module)
100
105
 
101
106
  try:
102
107
  await module.start_config_setup(
@@ -106,13 +111,13 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
106
111
  logger.debug("Module %s (%s) started successfully", job_id, module.name)
107
112
  except Exception:
108
113
  # Remove the module from the manager in case of an error.
109
- del self.modules[job_id]
114
+ del self.tasks_sessions[job_id]
110
115
  logger.exception("Failed to start module %s: %s", job_id)
111
116
  raise
112
117
  else:
113
118
  return job_id
114
119
 
115
- async def add_to_queue(self, job_id: str, output_data: OutputModelT) -> None: # type: ignore
120
+ async def add_to_queue(self, job_id: str, output_data: OutputModelT | ModuleCodeModel) -> None:
116
121
  """Add output data to the queue for a specific job.
117
122
 
118
123
  This method is used as a callback to handle output data generated by a module job.
@@ -121,7 +126,7 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
121
126
  job_id: The unique identifier of the job.
122
127
  output_data: The output data produced by the job.
123
128
  """
124
- await self.queues[job_id].put(output_data.model_dump())
129
+ await self.tasks_sessions[job_id].queue.put(output_data.model_dump())
125
130
 
126
131
  @asynccontextmanager # type: ignore
127
132
  async def generate_stream_consumer(self, job_id: str) -> AsyncIterator[AsyncGenerator[dict[str, Any], None]]: # type: ignore
@@ -137,39 +142,48 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
137
142
  Yields:
138
143
  AsyncGenerator: A stream of output data or error messages.
139
144
  """
140
- module = self.modules.get(job_id, None)
141
-
142
- logger.debug("Module %s found: %s", job_id, module)
145
+ if (session := self.tasks_sessions.get(job_id, None)) is None:
143
146
 
144
- async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
145
- """Stream output data from the module.
147
+ async def _error_gen() -> AsyncGenerator[dict[str, Any], None]: # noqa: RUF029
148
+ """Generate an error message for a non-existent module.
146
149
 
147
- Yields:
148
- dict: Output data generated by the module.
149
- """
150
- if module is None:
150
+ Yields:
151
+ AsyncGenerator: A generator yielding an error message.
152
+ """
151
153
  yield {
152
154
  "error": {
153
155
  "error_message": f"Module {job_id} not found",
154
156
  "code": grpc.StatusCode.NOT_FOUND,
155
157
  }
156
158
  }
157
- return
158
159
 
159
- try:
160
- while module.status == ModuleStatus.RUNNING or (
161
- not self.queues[job_id].empty()
162
- and module.status
163
- in {
164
- ModuleStatus.STOPPED,
165
- ModuleStatus.STOPPING,
166
- }
167
- ):
168
- logger.info(f"{job_id=}: {module.status=}")
169
- yield await self.queues[job_id].get()
160
+ yield _error_gen()
161
+ return
162
+
163
+ logger.debug("Session: %s with Module %s", job_id, session.module)
164
+
165
+ async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
166
+ """Stream output data from the module.
170
167
 
171
- finally:
172
- del self.queues[job_id]
168
+ Yields:
169
+ dict: Output data generated by the module.
170
+ """
171
+ while True:
172
+ # if queue is empty but producer not finished yet, block on get()
173
+ msg = await session.queue.get()
174
+ try:
175
+ yield msg
176
+ finally:
177
+ session.queue.task_done()
178
+
179
+ # If the producer marked finished and no more items, break soon:
180
+ if (
181
+ session.is_cancelled.is_set()
182
+ or (session.status is TaskStatus.COMPLETED and session.queue.empty())
183
+ or session.status is TaskStatus.FAILED
184
+ ):
185
+ # and session.queue.empty():
186
+ break
173
187
 
174
188
  yield _stream()
175
189
 
@@ -200,32 +214,21 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
200
214
  Exception: If the module fails to start.
201
215
  """
202
216
  job_id = str(uuid.uuid4())
203
- # TODO: Ensure the job_id is unique.
204
217
  module = self.module_class(
205
218
  job_id,
206
219
  mission_id=mission_id,
207
220
  setup_id=setup_id,
208
221
  setup_version_id=setup_version_id,
209
222
  )
210
- self.modules[job_id] = module
211
- self.queues[job_id] = asyncio.Queue()
212
223
  callback = await self.job_specific_callback(self.add_to_queue, job_id)
213
224
 
214
- try:
215
- await module.start(
216
- input_data,
217
- setup_data,
218
- callback,
219
- done_callback=lambda _: asyncio.create_task(callback(StreamCodeModel(code="__END_OF_STREAM__"))),
220
- )
221
- logger.debug("Module %s (%s) started successfully", job_id, module.name)
222
- except Exception:
223
- # Remove the module from the manager in case of an error.
224
- del self.modules[job_id]
225
- logger.exception("Failed to start module %s: %s", job_id)
226
- raise
227
- else:
228
- return job_id
225
+ await self.create_task(
226
+ job_id,
227
+ module,
228
+ module.start(input_data, setup_data, callback, done_callback=None),
229
+ )
230
+ logger.info("Managed task started: '%s'", job_id, extra={"task_id": job_id})
231
+ return job_id
229
232
 
230
233
  async def stop_module(self, job_id: str) -> bool:
231
234
  """Stop a running module job.
@@ -239,34 +242,37 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
239
242
  Raises:
240
243
  Exception: If an error occurs while stopping the module.
241
244
  """
245
+ logger.info(f"STOP required for {job_id=}")
246
+
242
247
  async with self._lock:
243
- module = self.modules.get(job_id)
244
- if not module:
245
- logger.warning(f"Module {job_id} not found")
248
+ session = self.tasks_sessions.get(job_id)
249
+
250
+ if not session:
251
+ logger.warning(f"session with id: {job_id} not found")
246
252
  return False
247
253
  try:
248
- await module.stop()
249
- # should maybe be added in finally
250
- del self.queues[job_id]
251
- del self.modules[job_id]
252
- logger.debug(f"Module {job_id} ({module.name}) stopped successfully")
254
+ await session.module.stop()
255
+
256
+ if job_id in self.tasks:
257
+ await self.cancel_task(job_id)
258
+ logger.debug(f"session {job_id} ({session.module.name}) stopped successfully")
253
259
  except Exception as e:
254
260
  logger.error(f"Error while stopping module {job_id}: {e}")
255
261
  raise
256
262
  else:
257
263
  return True
258
264
 
259
- async def get_module_status(self, job_id: str) -> ModuleStatus | None:
265
+ async def get_module_status(self, job_id: str) -> TaskStatus:
260
266
  """Retrieve the status of a module job.
261
267
 
262
268
  Args:
263
269
  job_id: The unique identifier of the job.
264
270
 
265
271
  Returns:
266
- ModuleStatus | None: The status of the module, or None if it does not exist.
272
+ ModuleStatus: The status of the module.
267
273
  """
268
- module = self.modules.get(job_id)
269
- return module.status if module else None
274
+ session = self.tasks_sessions.get(job_id, None)
275
+ return session.status if session is not None else TaskStatus.FAILED
270
276
 
271
277
  async def stop_all_modules(self) -> None:
272
278
  """Stop all currently running module jobs.
@@ -274,7 +280,7 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
274
280
  This method ensures that all active jobs are gracefully terminated.
275
281
  """
276
282
  async with self._lock:
277
- stop_tasks = [self.stop_module(job_id) for job_id in list(self.modules.keys())]
283
+ stop_tasks = [self.stop_module(job_id) for job_id in list(self.tasks_sessions.keys())]
278
284
  if stop_tasks:
279
285
  await asyncio.gather(*stop_tasks, return_exceptions=True)
280
286
 
@@ -286,9 +292,9 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
286
292
  """
287
293
  return {
288
294
  job_id: {
289
- "name": module.name,
290
- "status": module.status,
291
- "class": module.__class__.__name__,
295
+ "name": session.module.name,
296
+ "status": session.module.status,
297
+ "class": session.module.__class__.__name__,
292
298
  }
293
- for job_id, module in self.modules.items()
299
+ for job_id, session in self.tasks_sessions.items()
294
300
  }