digitalkin 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/grpc_servers/_base_server.py +15 -17
- digitalkin/grpc_servers/module_server.py +9 -10
- digitalkin/grpc_servers/module_servicer.py +108 -85
- digitalkin/grpc_servers/registry_server.py +3 -6
- digitalkin/grpc_servers/registry_servicer.py +18 -19
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +3 -5
- digitalkin/logger.py +45 -1
- digitalkin/models/module/module.py +1 -0
- digitalkin/modules/_base_module.py +47 -6
- digitalkin/modules/job_manager/base_job_manager.py +139 -0
- digitalkin/modules/job_manager/job_manager_models.py +44 -0
- digitalkin/modules/job_manager/single_job_manager.py +218 -0
- digitalkin/modules/job_manager/taskiq_broker.py +173 -0
- digitalkin/modules/job_manager/taskiq_job_manager.py +213 -0
- digitalkin/services/base_strategy.py +3 -1
- digitalkin/services/cost/cost_strategy.py +64 -16
- digitalkin/services/cost/default_cost.py +95 -12
- digitalkin/services/cost/grpc_cost.py +149 -60
- digitalkin/services/filesystem/default_filesystem.py +5 -6
- digitalkin/services/filesystem/filesystem_strategy.py +3 -2
- digitalkin/services/filesystem/grpc_filesystem.py +31 -26
- digitalkin/services/services_config.py +6 -5
- digitalkin/services/setup/__init__.py +1 -0
- digitalkin/services/setup/default_setup.py +10 -12
- digitalkin/services/setup/grpc_setup.py +8 -10
- digitalkin/services/storage/default_storage.py +13 -6
- digitalkin/services/storage/grpc_storage.py +25 -9
- digitalkin/services/storage/storage_strategy.py +3 -2
- digitalkin/utils/arg_parser.py +5 -48
- digitalkin/utils/development_mode_action.py +51 -0
- {digitalkin-0.2.11.dist-info → digitalkin-0.2.13.dist-info}/METADATA +43 -12
- {digitalkin-0.2.11.dist-info → digitalkin-0.2.13.dist-info}/RECORD +40 -33
- {digitalkin-0.2.11.dist-info → digitalkin-0.2.13.dist-info}/WHEEL +1 -1
- modules/cpu_intensive_module.py +271 -0
- modules/minimal_llm_module.py +200 -56
- modules/storage_module.py +5 -6
- modules/text_transform_module.py +1 -1
- digitalkin/modules/job_manager.py +0 -176
- {digitalkin-0.2.11.dist-info → digitalkin-0.2.13.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.2.11.dist-info → digitalkin-0.2.13.dist-info}/top_level.txt +0 -0
digitalkin/modules/job_manager/single_job_manager.py

@@ -0,0 +1,218 @@
+"""Background module manager with single instance."""
+
+import asyncio
+import uuid
+from collections.abc import AsyncGenerator, AsyncIterator
+from contextlib import asynccontextmanager
+from typing import Any, Generic
+
+import grpc
+
+from digitalkin.logger import logger
+from digitalkin.models import ModuleStatus
+from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
+from digitalkin.modules._base_module import BaseModule
+from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
+from digitalkin.services.services_models import ServicesMode
+
+
+class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
+    """Manages a single instance of a module job.
+
+    This class ensures that only one instance of a module job is active at a time.
+    It provides functionality to create, stop, and monitor module jobs, as well as
+    to handle their output data.
+    """
+
+    modules: dict[str, BaseModule]
+    queues: dict[str, asyncio.Queue]
+
+    def __init__(
+        self,
+        module_class: type[BaseModule],
+        services_mode: ServicesMode,
+    ) -> None:
+        """Initialize the job manager.
+
+        Args:
+            module_class: The class of the module to be managed.
+            services_mode: The mode of operation for the services (e.g., ASYNC or SYNC).
+        """
+        super().__init__(module_class, services_mode)
+
+        self._lock = asyncio.Lock()
+        self.modules: dict[str, BaseModule] = {}
+        self.queues: dict[str, asyncio.Queue] = {}
+
+    async def add_to_queue(self, job_id: str, output_data: OutputModelT) -> None:  # type: ignore
+        """Add output data to the queue for a specific job.
+
+        This method is used as a callback to handle output data generated by a module job.
+
+        Args:
+            job_id: The unique identifier of the job.
+            output_data: The output data produced by the job.
+        """
+        await self.queues[job_id].put(output_data.model_dump())
+
+    @asynccontextmanager  # type: ignore
+    async def generate_stream_consumer(self, job_id: str) -> AsyncIterator[AsyncGenerator[dict[str, Any], None]]:  # type: ignore
+        """Generate a stream consumer for a module's output data.
+
+        This method creates an asynchronous generator that streams output data
+        from a specific module job. If the module does not exist, it generates
+        an error message.
+
+        Args:
+            job_id: The unique identifier of the job.
+
+        Yields:
+            AsyncGenerator: A stream of output data or error messages.
+        """
+        module = self.modules.get(job_id, None)
+
+        logger.debug("Module %s found: %s", job_id, module)
+
+        async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
+            """Stream output data from the module.
+
+            Yields:
+                dict: Output data generated by the module.
+            """
+            if module is None:
+                yield {
+                    "error": {
+                        "error_message": f"Module {job_id} not found",
+                        "code": grpc.StatusCode.NOT_FOUND,
+                    }
+                }
+                return
+
+            try:
+                # Keep draining while the module runs, then flush whatever is
+                # left in the queue once it is stopping or stopped.
+                while module.status == ModuleStatus.RUNNING or (
+                    not self.queues[job_id].empty()
+                    and module.status
+                    in {
+                        ModuleStatus.STOPPED,
+                        ModuleStatus.STOPPING,
+                    }
+                ):
+                    logger.info(f"{job_id=}: {module.status=}")
+                    yield await self.queues[job_id].get()
+                    logger.info(f"{job_id=}: {module.status=} | {self.queues[job_id].empty()}")
+
+            finally:
+                del self.queues[job_id]
+
+        yield _stream()
+
+    async def create_job(
+        self,
+        input_data: InputModelT,
+        setup_data: SetupModelT,
+        mission_id: str,
+        setup_version_id: str,
+    ) -> str:
+        """Create and start a new module job.
+
+        This method initializes a new module job, assigns it a unique job ID,
+        and starts it in the background.
+
+        Args:
+            input_data: The input data required to start the job.
+            setup_data: The setup configuration for the module.
+            mission_id: The mission ID associated with the job.
+            setup_version_id: The setup version ID associated with the module.
+
+        Returns:
+            str: The unique identifier (job ID) of the created job.
+
+        Raises:
+            Exception: If the module fails to start.
+        """
+        job_id = str(uuid.uuid4())
+        # TODO: Ensure the job_id is unique.
+        module = self.module_class(job_id, mission_id=mission_id, setup_version_id=setup_version_id)
+        self.modules[job_id] = module
+        self.queues[job_id] = asyncio.Queue()
+
+        try:
+            await module.start(
+                input_data,
+                setup_data,
+                await self.job_specific_callback(self.add_to_queue, job_id),
+            )
+            logger.debug("Module %s (%s) started successfully", job_id, module.name)
+        except Exception:
+            # Remove the module from the manager in case of an error.
+            del self.modules[job_id]
+            logger.exception("Failed to start module %s", job_id)
+            raise
+        else:
+            return job_id
+
+    async def stop_module(self, job_id: str) -> bool:
+        """Stop a running module job.
+
+        Args:
+            job_id: The unique identifier of the job to stop.
+
+        Returns:
+            bool: True if the module was successfully stopped, False if it does not exist.
+
+        Raises:
+            Exception: If an error occurs while stopping the module.
+        """
+        async with self._lock:
+            module = self.modules.get(job_id)
+            if not module:
+                logger.warning(f"Module {job_id} not found")
+                return False
+            try:
+                await module.stop()
+                # TODO: this cleanup should perhaps move into a finally block.
+                del self.queues[job_id]
+                del self.modules[job_id]
+                logger.debug(f"Module {job_id} ({module.name}) stopped successfully")
+            except Exception as e:
+                logger.error(f"Error while stopping module {job_id}: {e}")
+                raise
+            else:
+                return True
+
+    async def get_module_status(self, job_id: str) -> ModuleStatus | None:
+        """Retrieve the status of a module job.
+
+        Args:
+            job_id: The unique identifier of the job.
+
+        Returns:
+            ModuleStatus | None: The status of the module, or None if it does not exist.
+        """
+        module = self.modules.get(job_id)
+        return module.status if module else None
+
+    async def stop_all_modules(self) -> None:
+        """Stop all currently running module jobs.
+
+        This method ensures that all active jobs are gracefully terminated.
+        """
+        # NOTE: no outer lock here: stop_module acquires self._lock itself,
+        # and asyncio.Lock is not reentrant, so holding it here would deadlock.
+        stop_tasks = [self.stop_module(job_id) for job_id in list(self.modules.keys())]
+        if stop_tasks:
+            await asyncio.gather(*stop_tasks, return_exceptions=True)
+
+    async def list_modules(self) -> dict[str, dict[str, Any]]:
+        """List all modules along with their statuses.
+
+        Returns:
+            dict[str, dict[str, Any]]: A dictionary containing information about all modules and their statuses.
+        """
+        return {
+            job_id: {
+                "name": module.name,
+                "status": module.status,
+                "class": module.__class__.__name__,
+            }
+            for job_id, module in self.modules.items()
+        }
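For orientation, a minimal consumption sketch for `SingleJobManager`. Only the manager API shown above (`create_job`, `generate_stream_consumer`) comes from the diff; the module class and the pydantic payloads are assumed to be supplied by the caller, the IDs are placeholders, and `ServicesMode.SYNC` is inferred from the constructor docstring.

```python
# Hypothetical usage sketch; module_class, input_data, and setup_data are
# placeholders the caller must provide (a BaseModule subclass and pydantic
# models), and the IDs below are made up.
from digitalkin.modules.job_manager.single_job_manager import SingleJobManager
from digitalkin.services.services_models import ServicesMode


async def run_once(module_class, input_data, setup_data) -> None:
    manager = SingleJobManager(module_class, ServicesMode.SYNC)  # SYNC per docstring
    job_id = await manager.create_job(
        input_data,
        setup_data,
        mission_id="mission-123",     # placeholder
        setup_version_id="setup-v1",  # placeholder
    )
    # The context manager yields an async generator and deletes the job's
    # queue once the module leaves RUNNING and the queue drains.
    async with manager.generate_stream_consumer(job_id) as stream:
        async for frame in stream:
            print(frame)
```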
digitalkin/modules/job_manager/taskiq_broker.py

@@ -0,0 +1,173 @@
+"""Taskiq broker & RStream producer for the job manager."""
+
+import asyncio
+import json
+import logging
+import os
+import pickle  # noqa: S403
+
+from rstream import Producer
+from rstream.exceptions import PreconditionFailed
+from taskiq import Context, TaskiqDepends, TaskiqMessage
+from taskiq.abc.formatter import TaskiqFormatter
+from taskiq.compat import model_validate
+from taskiq.message import BrokerMessage
+from taskiq_aio_pika import AioPikaBroker
+
+from digitalkin.logger import logger
+from digitalkin.models.module.module_types import OutputModelT
+from digitalkin.modules._base_module import BaseModule
+from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
+from digitalkin.modules.job_manager.job_manager_models import StreamCodeModel
+from digitalkin.services.services_config import ServicesConfig
+from digitalkin.services.services_models import ServicesMode
+
+logging.getLogger("taskiq").setLevel(logging.INFO)
+logging.getLogger("aiormq").setLevel(logging.INFO)
+logging.getLogger("aio_pika").setLevel(logging.INFO)
+logging.getLogger("rstream").setLevel(logging.INFO)
+
+
+class PickleFormatter(TaskiqFormatter):
+    """Formatter that pickles the whole TaskiqMessage.
+
+    This lets tasks receive arbitrary Python objects (classes, functions, etc.)
+    that the default JSON formatter could not serialize.
+    """
+
+    def dumps(self, message: TaskiqMessage) -> BrokerMessage:  # noqa: PLR6301
+        """Serialize a TaskiqMessage into a pickled BrokerMessage payload.
+
+        Args:
+            message: TaskIQ message.
+
+        Returns:
+            BrokerMessage with the mandatory information for TaskIQ.
+        """
+        payload: bytes = pickle.dumps(message)
+
+        return BrokerMessage(
+            task_id=message.task_id,
+            task_name=message.task_name,
+            message=payload,
+            labels=message.labels,
+        )
+
+    def loads(self, message: bytes) -> TaskiqMessage:  # noqa: PLR6301
+        """Recreate a TaskiqMessage from pickled bytes.
+
+        Args:
+            message: Broker message as bytes.
+
+        Returns:
+            The message in TaskIQ format.
+        """
+        payload = pickle.loads(message)  # noqa: S301
+        return model_validate(TaskiqMessage, payload)
+
+
+def define_producer() -> Producer:
+    """Read the RabbitMQ connection parameters from the environment.
+
+    Returns:
+        Producer
+    """
+    host: str = os.environ.get("RABBITMQ_RSTREAM_HOST", "localhost")
+    port: str = os.environ.get("RABBITMQ_RSTREAM_PORT", "5552")
+    username: str = os.environ.get("RABBITMQ_RSTREAM_USERNAME", "guest")
+    password: str = os.environ.get("RABBITMQ_RSTREAM_PASSWORD", "guest")
+
+    logger.info("Connection to RabbitMQ: %s:%s.", host, port)
+    return Producer(host=host, port=int(port), username=username, password=password)
+
+
+async def init_rstream() -> None:
+    """Initialize the stream shared by all tasks."""
+    try:
+        await RSTREAM_PRODUCER.create_stream(
+            STREAM,
+            exists_ok=True,
+            arguments={"max-length-bytes": STREAM_RETENTION},
+        )
+    except PreconditionFailed:
+        logger.warning("stream already exists")
+
+
+def define_broker() -> AioPikaBroker:
+    """Define the broker from environment parameters.
+
+    Returns:
+        Broker: connected to RabbitMQ and using the custom formatter.
+    """
+    host: str = os.environ.get("RABBITMQ_BROKER_HOST", "localhost")
+    port: str = os.environ.get("RABBITMQ_BROKER_PORT", "5672")
+    username: str = os.environ.get("RABBITMQ_BROKER_USERNAME", "guest")
+    password: str = os.environ.get("RABBITMQ_BROKER_PASSWORD", "guest")
+
+    broker = AioPikaBroker(
+        f"amqp://{username}:{password}@{host}:{port}",
+        startup=[init_rstream],
+    )
+    broker.formatter = PickleFormatter()
+    return broker
+
+
+STREAM = "taskiq_data"
+STREAM_RETENTION = 200_000
+RSTREAM_PRODUCER = define_producer()
+TASKIQ_BROKER = define_broker()
+
+
+async def send_message_to_stream(job_id: str, output_data: OutputModelT) -> None:  # type: ignore
+    """Callback used to append a message frame to the RStream stream.
+
+    Args:
+        job_id: id of the job that sent the message
+        output_data: message body as an OutputModelT, an error, or a stream code
+    """
+    body = json.dumps({"job_id": job_id, "output_data": output_data.model_dump()}).encode("utf-8")
+    await RSTREAM_PRODUCER.send(stream=STREAM, message=body)
+
+
+@TASKIQ_BROKER.task
+async def run_task(
+    mission_id: str,
+    setup_version_id: str,
+    module_class: type[BaseModule],
+    services_mode: ServicesMode,
+    input_data: dict,
+    setup_data: dict,
+    context: Context = TaskiqDepends(),
+) -> None:
+    """TaskIQ task allowing a module to compute in the background asynchronously.
+
+    Args:
+        mission_id: The mission ID associated with the job.
+        setup_version_id: The setup version ID associated with the module.
+        module_class: The module class to instantiate and run.
+        services_mode: The mode of operation for the services.
+        input_data: The input data required to start the job.
+        setup_data: The setup configuration for the module.
+        context: Allows TaskIQ context access.
+    """
+    logger.warning("%s", services_mode)
+    services_config = ServicesConfig(
+        services_config_strategies=module_class.services_config_strategies,
+        services_config_params=module_class.services_config_params,
+        mode=services_mode,
+    )
+    module_class.services_config = services_config
+    logger.warning("%s | %s", services_config, module_class.services_config)
+
+    job_id = context.message.task_id
+    callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
+    module = module_class(job_id, mission_id=mission_id, setup_version_id=setup_version_id)
+
+    await module.start(
+        input_data,
+        setup_data,
+        callback,
+        # ensure that the callback is called when the task is done + allow asyncio to run
+        # TODO: should define a BaseModel for stream code / error
+        done_callback=lambda _: asyncio.create_task(callback(StreamCodeModel(code="__END_OF_STREAM__"))),
+    )
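As a sanity check on the formatter above, a round-trip sketch: `dumps` pickles the whole `TaskiqMessage` and `loads` validates it back. The `TaskiqMessage` field set used here (task_id, task_name, labels, args, kwargs) is an assumption about the installed taskiq version, and importing `taskiq_broker` builds the module-level producer and broker (assumed not to open a connection until startup).

```python
# Round-trip sketch for PickleFormatter. The TaskiqMessage fields used here
# are assumed from taskiq's public model; adjust to the installed version.
from taskiq import TaskiqMessage

from digitalkin.modules.job_manager.taskiq_broker import PickleFormatter

formatter = PickleFormatter()
original = TaskiqMessage(
    task_id="42",
    task_name="digitalkin.modules.job_manager.taskiq_broker:run_task",
    labels={},
    args=[],
    kwargs={"mission_id": "mission-123"},  # placeholder payload
)

broker_msg = formatter.dumps(original)          # whole message pickled to bytes
restored = formatter.loads(broker_msg.message)  # unpickled and re-validated
assert restored.task_id == original.task_id
assert restored.kwargs == original.kwargs
```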
digitalkin/modules/job_manager/taskiq_job_manager.py

@@ -0,0 +1,213 @@
+"""Taskiq job manager module."""
+
+try:
+    import taskiq  # noqa: F401
+
+except ImportError:
+    msg = "Install digitalkin[taskiq] to use this functionality\n$ uv pip install digitalkin[taskiq]."
+    raise ImportError(msg) from None
+
+import asyncio
+import contextlib
+import json
+import os
+from collections.abc import AsyncGenerator, AsyncIterator
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Any, Generic
+
+from rstream import Consumer, ConsumerOffsetSpecification, MessageContext, OffsetType
+
+from digitalkin.logger import logger
+from digitalkin.models.module import InputModelT, SetupModelT
+from digitalkin.models.module.module import ModuleStatus
+from digitalkin.modules._base_module import BaseModule
+from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
+from digitalkin.modules.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER
+from digitalkin.services.services_models import ServicesMode
+
+if TYPE_CHECKING:
+    from taskiq.task import AsyncTaskiqTask
+
+
+class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
+    """Taskiq job manager for running modules in Taskiq tasks."""
+
+    services_mode: ServicesMode
+
+    @staticmethod
+    def _define_consumer() -> Consumer:
+        """Read the RabbitMQ connection parameters from the environment.
+
+        Returns:
+            Consumer
+        """
+        host: str = os.environ.get("RABBITMQ_RSTREAM_HOST", "localhost")
+        port: str = os.environ.get("RABBITMQ_RSTREAM_PORT", "5552")
+        username: str = os.environ.get("RABBITMQ_RSTREAM_USERNAME", "guest")
+        password: str = os.environ.get("RABBITMQ_RSTREAM_PASSWORD", "guest")
+
+        logger.info("Connection to RabbitMQ: %s:%s.", host, port)
+        return Consumer(host=host, port=int(port), username=username, password=password)
+
+    async def _on_message(self, message: bytes, message_context: MessageContext) -> None:  # noqa: ARG002
+        """Internal callback: parse JSON and route to the correct job queue."""
+        try:
+            data = json.loads(message.decode("utf-8"))
+        except json.JSONDecodeError:
+            return
+        job_id = data.get("job_id")
+        if not job_id:
+            return
+        queue = self.job_queues.get(job_id)
+        if queue:
+            await queue.put(data.get("output_data"))
+
+    async def _start(self) -> None:
+        """Start the broker and subscribe to the shared RStream stream."""
+        await TASKIQ_BROKER.startup()
+
+        self.stream_consumer = self._define_consumer()
+
+        await self.stream_consumer.create_stream(
+            STREAM,
+            exists_ok=True,
+            arguments={"max-length-bytes": STREAM_RETENTION},
+        )
+        await self.stream_consumer.start()
+
+        start_spec = ConsumerOffsetSpecification(OffsetType.LAST)
+        # on_message receives bytes instead of AMQPMessage
+        await self.stream_consumer.subscribe(
+            stream=STREAM,
+            subscriber_name=f"""subscriber_{os.environ.get("SERVER_NAME", "module_servicer")}""",
+            callback=self._on_message,  # type: ignore
+            offset_specification=start_spec,
+        )
+        self.stream_consumer_task = asyncio.create_task(
+            self.stream_consumer.run(),
+            name="stream_consumer_task",
+        )
+
+    async def _stop(self) -> None:
+        """Close the stream consumer and cancel its background task."""
+        # Signal the consumer to stop
+        await self.stream_consumer.close()
+        # Cancel the background task
+        self.stream_consumer_task.cancel()
+        with contextlib.suppress(asyncio.CancelledError):
+            await self.stream_consumer_task
+
+    @asynccontextmanager  # type: ignore
+    async def generate_stream_consumer(self, job_id: str) -> AsyncIterator[AsyncGenerator[dict[str, Any], None]]:  # type: ignore
+        """Generate a stream consumer for the RStream stream.
+
+        Args:
+            job_id: The job ID to filter messages.
+
+        Yields:
+            messages: The stream messages from the associated module.
+        """
+        queue: asyncio.Queue = asyncio.Queue(maxsize=self.max_queue_size)
+        self.job_queues[job_id] = queue
+
+        async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
+            """Yield queued output frames as they arrive.
+
+            Yields:
+                dict: generated object from the module
+            """
+            # Runs until the caller stops iterating; the queue is removed
+            # in the outer finally block.
+            while True:
+                item = await queue.get()
+                queue.task_done()
+                yield item
+
+        try:
+            yield _stream()
+        finally:
+            self.job_queues.pop(job_id, None)
+
+    def __init__(
+        self,
+        module_class: type[BaseModule],
+        services_mode: ServicesMode,
+    ) -> None:
+        """Initialize the Taskiq job manager."""
+        super().__init__(module_class, services_mode)
+
+        logger.warning("TaskiqJobManager initialized with app: %s", TASKIQ_BROKER)
+        self.services_mode = services_mode
+        self.job_queues: dict[str, asyncio.Queue] = {}
+        self.max_queue_size = 1000
+
+    async def create_job(
+        self,
+        input_data: InputModelT,
+        setup_data: SetupModelT,
+        mission_id: str,
+        setup_version_id: str,
+    ) -> str:
+        """Launch run_task in Taskiq and return the Taskiq task id as the job_id.
+
+        Args:
+            input_data: Input data for the module
+            setup_data: Setup data for the module
+            mission_id: Mission ID for the module
+            setup_version_id: The setup version ID associated with the module.
+
+        Returns:
+            job_id: The Taskiq task id.
+
+        Raises:
+            ValueError: If the task is not found.
+        """
+        task = TASKIQ_BROKER.find_task("digitalkin.modules.job_manager.taskiq_broker:run_task")
+
+        if task is None:
+            msg = "Task not found"
+            raise ValueError(msg)
+
+        running_task: AsyncTaskiqTask[Any] = await task.kiq(
+            mission_id,
+            setup_version_id,
+            self.module_class,
+            self.services_mode,
+            input_data.model_dump(),
+            setup_data.model_dump(),
+        )
+        job_id = running_task.task_id
+        result = await running_task.wait_result(timeout=10)
+        logger.debug("Job %s with data %s", job_id, result)
+        return job_id
+
+    async def stop_module(self, job_id: str) -> bool:
+        """Revoke (terminate) the Taskiq task with the given id.
+
+        Args:
+            job_id: The Taskiq task id to stop.
+
+        Raises:
+            NotImplementedError: This operation is not supported yet.
+        """
+        msg = "stop_module not implemented in TaskiqJobManager"
+        raise NotImplementedError(msg)
+
+    async def stop_all_modules(self) -> None:
+        """Stop all running modules."""
+        msg = "stop_all_modules not implemented in TaskiqJobManager"
+        raise NotImplementedError(msg)
+
+    async def get_module_status(self, job_id: str) -> ModuleStatus | None:
+        """Query a module status."""
+        msg = "get_module_status not implemented in TaskiqJobManager"
+        raise NotImplementedError(msg)
+
+    async def list_modules(self) -> dict[str, dict[str, Any]]:
+        """List all modules."""
+        msg = "list_modules not implemented in TaskiqJobManager"
+        raise NotImplementedError(msg)
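The producer/consumer contract between `send_message_to_stream` and `_on_message` is just a JSON frame keyed by `job_id`. Below is a self-contained sketch of that routing with no RabbitMQ involved; the plain dict stands in for `TaskiqJobManager.job_queues`.

```python
# Standalone sketch of the frame format and per-job routing used above.
import asyncio
import json


async def demo() -> None:
    # Stand-in for TaskiqJobManager.job_queues.
    job_queues: dict[str, asyncio.Queue] = {"job-1": asyncio.Queue()}

    # Producer side: what send_message_to_stream puts on the stream.
    frame = json.dumps({"job_id": "job-1", "output_data": {"text": "hi"}}).encode("utf-8")

    # Consumer side: what _on_message does with each frame.
    data = json.loads(frame.decode("utf-8"))
    queue = job_queues.get(data.get("job_id"))
    if queue:
        await queue.put(data.get("output_data"))

    print(await job_queues["job-1"].get())  # -> {'text': 'hi'}


asyncio.run(demo())
```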
digitalkin/services/base_strategy.py

@@ -9,10 +9,12 @@ class BaseStrategy(ABC):
     This class defines the interface for all strategies.
     """
 
-    def __init__(self, mission_id: str) -> None:
+    def __init__(self, mission_id: str, setup_version_id: str) -> None:
         """Initialize the strategy.
 
         Args:
             mission_id: The ID of the mission this strategy is associated with
+            setup_version_id: The ID of the setup version this strategy is associated with
         """
         self.mission_id: str = mission_id
+        self.setup_version_id: str = setup_version_id