digitalkin 0.2.23__py3-none-any.whl → 0.3.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/core/__init__.py +1 -0
- digitalkin/core/common/__init__.py +9 -0
- digitalkin/core/common/factories.py +156 -0
- digitalkin/core/job_manager/__init__.py +1 -0
- digitalkin/{modules → core}/job_manager/base_job_manager.py +137 -31
- digitalkin/core/job_manager/single_job_manager.py +354 -0
- digitalkin/{modules → core}/job_manager/taskiq_broker.py +116 -22
- digitalkin/core/job_manager/taskiq_job_manager.py +541 -0
- digitalkin/core/task_manager/__init__.py +1 -0
- digitalkin/core/task_manager/base_task_manager.py +539 -0
- digitalkin/core/task_manager/local_task_manager.py +108 -0
- digitalkin/core/task_manager/remote_task_manager.py +87 -0
- digitalkin/core/task_manager/surrealdb_repository.py +266 -0
- digitalkin/core/task_manager/task_executor.py +249 -0
- digitalkin/core/task_manager/task_session.py +406 -0
- digitalkin/grpc_servers/__init__.py +1 -19
- digitalkin/grpc_servers/_base_server.py +3 -3
- digitalkin/grpc_servers/module_server.py +27 -43
- digitalkin/grpc_servers/module_servicer.py +51 -36
- digitalkin/grpc_servers/registry_server.py +2 -2
- digitalkin/grpc_servers/registry_servicer.py +4 -4
- digitalkin/grpc_servers/utils/__init__.py +1 -0
- digitalkin/grpc_servers/utils/exceptions.py +0 -8
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +4 -4
- digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
- digitalkin/logger.py +73 -24
- digitalkin/mixins/__init__.py +19 -0
- digitalkin/mixins/base_mixin.py +10 -0
- digitalkin/mixins/callback_mixin.py +24 -0
- digitalkin/mixins/chat_history_mixin.py +110 -0
- digitalkin/mixins/cost_mixin.py +76 -0
- digitalkin/mixins/file_history_mixin.py +93 -0
- digitalkin/mixins/filesystem_mixin.py +46 -0
- digitalkin/mixins/logger_mixin.py +51 -0
- digitalkin/mixins/storage_mixin.py +79 -0
- digitalkin/models/core/__init__.py +1 -0
- digitalkin/{modules/job_manager → models/core}/job_manager_models.py +3 -3
- digitalkin/models/core/task_monitor.py +70 -0
- digitalkin/models/grpc_servers/__init__.py +1 -0
- digitalkin/{grpc_servers/utils → models/grpc_servers}/models.py +5 -5
- digitalkin/models/module/__init__.py +2 -0
- digitalkin/models/module/module.py +9 -1
- digitalkin/models/module/module_context.py +122 -6
- digitalkin/models/module/module_types.py +307 -19
- digitalkin/models/services/__init__.py +9 -0
- digitalkin/models/services/cost.py +1 -0
- digitalkin/models/services/storage.py +39 -5
- digitalkin/modules/_base_module.py +123 -118
- digitalkin/modules/tool_module.py +10 -2
- digitalkin/modules/trigger_handler.py +7 -6
- digitalkin/services/cost/__init__.py +9 -2
- digitalkin/services/cost/grpc_cost.py +9 -42
- digitalkin/services/filesystem/default_filesystem.py +0 -2
- digitalkin/services/filesystem/grpc_filesystem.py +10 -39
- digitalkin/services/setup/default_setup.py +5 -6
- digitalkin/services/setup/grpc_setup.py +52 -15
- digitalkin/services/storage/grpc_storage.py +4 -4
- digitalkin/services/user_profile/__init__.py +1 -0
- digitalkin/services/user_profile/default_user_profile.py +55 -0
- digitalkin/services/user_profile/grpc_user_profile.py +69 -0
- digitalkin/services/user_profile/user_profile_strategy.py +40 -0
- digitalkin/utils/__init__.py +28 -0
- digitalkin/utils/arg_parser.py +1 -1
- digitalkin/utils/development_mode_action.py +2 -2
- digitalkin/utils/dynamic_schema.py +483 -0
- digitalkin/utils/package_discover.py +1 -2
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/METADATA +11 -30
- digitalkin-0.3.1.dev2.dist-info/RECORD +119 -0
- modules/dynamic_setup_module.py +362 -0
- digitalkin/grpc_servers/utils/factory.py +0 -180
- digitalkin/modules/job_manager/single_job_manager.py +0 -294
- digitalkin/modules/job_manager/taskiq_job_manager.py +0 -290
- digitalkin-0.2.23.dist-info/RECORD +0 -89
- /digitalkin/{grpc_servers/utils → models/grpc_servers}/types.py +0 -0
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/WHEEL +0 -0
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,354 @@
|
|
|
1
|
+
"""Background module manager with single instance."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import datetime
|
|
5
|
+
import uuid
|
|
6
|
+
from collections.abc import AsyncGenerator, AsyncIterator
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import grpc
|
|
11
|
+
|
|
12
|
+
from digitalkin.core.common import ConnectionFactory, ModuleFactory
|
|
13
|
+
from digitalkin.core.job_manager.base_job_manager import BaseJobManager
|
|
14
|
+
from digitalkin.core.task_manager.local_task_manager import LocalTaskManager
|
|
15
|
+
from digitalkin.core.task_manager.task_session import TaskSession
|
|
16
|
+
from digitalkin.logger import logger
|
|
17
|
+
from digitalkin.models.core.task_monitor import TaskStatus
|
|
18
|
+
from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
|
|
19
|
+
from digitalkin.models.module.module import ModuleCodeModel
|
|
20
|
+
from digitalkin.modules._base_module import BaseModule
|
|
21
|
+
from digitalkin.services.services_models import ServicesMode
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SingleJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
    """Manage module jobs that execute inside the current process.

    Each job gets a unique job ID and is tracked by a ``TaskSession`` stored in
    ``self.tasks_sessions`` under that ID. Module output is pushed onto the
    session queue by :meth:`add_to_queue` and read back through
    :meth:`generate_stream_consumer` (regular jobs) or
    :meth:`generate_config_setup_module_response` (config-setup jobs).

    Regular jobs are handed to ``create_task`` (BaseJobManager), which was given
    a ``LocalTaskManager``; config-setup jobs are awaited directly. Several jobs
    can be tracked at the same time (bounded by ``max_concurrent_tasks`` in the
    task manager).
    """

    def __init__(
        self,
        module_class: type[BaseModule],
        services_mode: ServicesMode,
        default_timeout: float = 10.0,
        max_concurrent_tasks: int = 100,
    ) -> None:
        """Initialize the job manager.

        Args:
            module_class: The class of the module to be managed.
            services_mode: The mode of operation for the services.
            default_timeout: Default timeout for task operations, forwarded to
                ``LocalTaskManager``.
            max_concurrent_tasks: Maximum number of concurrent tasks, forwarded
                to ``LocalTaskManager``.
        """
        # Create local task manager for same-process execution
        task_manager = LocalTaskManager(default_timeout, max_concurrent_tasks)

        # Initialize base job manager with task manager
        super().__init__(module_class, services_mode, task_manager)

        # Guards stop_module and the job-ID snapshot taken in stop_all_modules.
        self._lock = asyncio.Lock()

    async def start(self) -> None:
        """Open the SurrealDB connection used by this manager's task sessions.

        Sets ``self.channel``, which ``create_config_setup_instance_job`` passes to
        every ``TaskSession`` it creates and which ``stop_all_modules`` closes.
        Await this before creating config-setup jobs.
        """
        self.channel = await ConnectionFactory.create_surreal_connection("task_manager", datetime.timedelta(seconds=5))

    async def generate_config_setup_module_response(
        self,
        job_id: str,
        timeout: float = 30.0,
    ) -> SetupModelT | ModuleCodeModel:
        """Wait for the first item a config-setup job puts on its queue and return it.

        Args:
            job_id: The unique identifier of the job.
            timeout: Seconds to wait for the module's response before giving up.

        Returns:
            SetupModelT | ModuleCodeModel: The item taken from the job's queue; a
            ``ModuleCodeModel`` with ``NOT_FOUND`` if the job is unknown; or a
            ``ModuleCodeModel`` with ``DEADLINE_EXCEEDED`` if nothing arrives
            within ``timeout`` seconds.
        """
        if (session := self.tasks_sessions.get(job_id, None)) is None:
            return ModuleCodeModel(
                code=str(grpc.StatusCode.NOT_FOUND),
                message=f"Module {job_id} not found",
            )

        logger.debug("Module %s found: %s", job_id, session.module)
        try:
            # Bounded wait so a silent module cannot block the caller forever.
            # NOTE(review): add_to_queue enqueues ``model_dump()`` dicts, so the value
            # returned here is presumably a dict rather than a model instance —
            # confirm against callers before relying on the annotation.
            return await asyncio.wait_for(session.queue.get(), timeout=timeout)
        except asyncio.TimeoutError:
            logger.error("Timeout waiting for config setup response from module %s", job_id)
            return ModuleCodeModel(
                code=str(grpc.StatusCode.DEADLINE_EXCEEDED),
                message=f"Module {job_id} did not respond within {timeout:g} seconds",
            )
        finally:
            logger.info("job_id=%r: %s", job_id, session.queue.empty())

    async def create_config_setup_instance_job(
        self,
        config_setup_data: SetupModelT,
        mission_id: str,
        setup_id: str,
        setup_version_id: str,
    ) -> str:
        """Create a module instance and run its config setup.

        A new job ID is generated and a ``TaskSession`` is registered for it
        before ``module.start_config_setup`` is awaited, so the module's output
        callback can enqueue into that session.

        Args:
            config_setup_data: The config setup data passed to the module.
            mission_id: The mission ID associated with the job.
            setup_id: The setup ID associated with the module.
            setup_version_id: The setup version ID associated with the module.

        Returns:
            str: The unique identifier (job ID) of the created job.

        Raises:
            Exception: Re-raised if the module fails to start; the job's session
                is removed first.
        """
        job_id = str(uuid.uuid4())
        # TODO: Ensure the job_id is unique.
        module = ModuleFactory.create_module_instance(self.module_class, job_id, mission_id, setup_id, setup_version_id)
        self.tasks_sessions[job_id] = TaskSession(job_id, mission_id, self.channel, module)

        try:
            await module.start_config_setup(
                config_setup_data,
                await self.job_specific_callback(self.add_to_queue, job_id),
            )
            logger.debug("Module %s (%s) started successfully", job_id, module.name)
        except Exception:
            # Remove the module from the manager in case of an error. pop() rather
            # than del so a session already removed elsewhere cannot raise KeyError
            # here and mask the original exception.
            self.tasks_sessions.pop(job_id, None)
            logger.exception("Failed to start module %s", job_id)
            raise
        else:
            return job_id

    async def add_to_queue(self, job_id: str, output_data: OutputModelT | ModuleCodeModel) -> None:
        """Push a job's output onto its session queue as a plain dict.

        Used, wrapped by ``job_specific_callback``, as the output callback of the
        jobs created by this manager.

        Args:
            job_id: The unique identifier of the job.
            output_data: The output data produced by the job; stored as
                ``output_data.model_dump()``.

        Raises:
            KeyError: If no session is registered for ``job_id``.
        """
        await self.tasks_sessions[job_id].queue.put(output_data.model_dump())

    @asynccontextmanager  # type: ignore
    async def generate_stream_consumer(self, job_id: str) -> AsyncIterator[AsyncGenerator[dict[str, Any], None]]:  # type: ignore
        """Provide an async generator over a job's queued output.

        If the job does not exist, the provided generator yields a single error
        dict instead.

        Args:
            job_id: The unique identifier of the job.

        Yields:
            AsyncGenerator: A stream of output data or error messages.
        """
        if (session := self.tasks_sessions.get(job_id, None)) is None:

            async def _error_gen() -> AsyncGenerator[dict[str, Any], None]:  # noqa: RUF029
                """Generate an error message for a non-existent module.

                Yields:
                    AsyncGenerator: A generator yielding an error message.
                """
                yield {
                    "error": {
                        "error_message": f"Module {job_id} not found",
                        "code": grpc.StatusCode.NOT_FOUND,
                    }
                }

            yield _error_gen()
            return

        logger.debug("Session: %s with Module %s", job_id, session.module)

        async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
            """Yield queued items one at a time until the session is finished.

            After each item the session is checked and the stream stops when the
            session is cancelled, has FAILED, or has COMPLETED with an empty queue.

            Yields:
                dict: Output data generated by the module.
            """
            while True:
                # NOTE(review): if the task finishes after the last item was consumed
                # and nothing further is enqueued, this get() waits indefinitely;
                # presumably the session/task manager enqueues a final item — confirm.
                msg = await session.queue.get()
                try:
                    yield msg
                finally:
                    # Always mark task as done, even if consumer raises exception
                    session.queue.task_done()

                # Check termination conditions after each message
                if (
                    session.is_cancelled.is_set()
                    or (session.status is TaskStatus.COMPLETED and session.queue.empty())
                    or session.status is TaskStatus.FAILED
                ):
                    logger.debug(
                        "Stream ending for job %s: cancelled=%s, status=%s, queue_empty=%s",
                        job_id,
                        session.is_cancelled.is_set(),
                        session.status,
                        session.queue.empty(),
                    )
                    break

        yield _stream()

    async def create_module_instance_job(
        self,
        input_data: InputModelT,
        setup_data: SetupModelT,
        mission_id: str,
        setup_id: str,
        setup_version_id: str,
    ) -> str:
        """Create a module instance and schedule its run through ``create_task``.

        Args:
            input_data: The input data required to start the job.
            setup_data: The setup configuration for the module.
            mission_id: The mission ID associated with the job.
            setup_id: The setup ID associated with the module.
            setup_version_id: The setup Version ID associated with the module.

        Returns:
            str: The unique identifier (job ID) of the created job.

        Raises:
            Exception: If the module fails to start.
        """
        job_id = str(uuid.uuid4())
        module = ModuleFactory.create_module_instance(self.module_class, job_id, mission_id, setup_id, setup_version_id)
        callback = await self.job_specific_callback(self.add_to_queue, job_id)

        # NOTE(review): session registration is presumably done by create_task
        # (BaseJobManager); add_to_queue requires it — confirm.
        await self.create_task(
            job_id,
            mission_id,
            module,
            module.start(input_data, setup_data, callback, done_callback=None),
        )
        logger.info("Managed task started: '%s'", job_id, extra={"task_id": job_id})
        return job_id

    async def stop_module(self, job_id: str) -> bool:
        """Stop a running module job and cancel its task.

        Args:
            job_id: The unique identifier of the job to stop.

        Returns:
            bool: True if the module was stopped, False if no session exists for
            ``job_id``.

        Raises:
            Exception: Re-raised (after logging with traceback) if stopping the
                module or cancelling its task fails.
        """
        logger.info("STOP required for job_id=%r", job_id)

        async with self._lock:
            session = self.tasks_sessions.get(job_id)

            if not session:
                logger.warning("session with id: %s not found", job_id)
                return False
            try:
                await session.module.stop()
                await self.cancel_task(job_id, session.mission_id)
                logger.debug("session %s (%s) stopped successfully", job_id, session.module.name)
            except Exception as e:
                # logger.exception keeps the traceback that logger.error dropped.
                logger.exception("Error while stopping module %s: %s", job_id, e)
                raise
            else:
                return True

    async def get_module_status(self, job_id: str) -> TaskStatus:
        """Retrieve the status of a module job.

        Args:
            job_id: The unique identifier of the job.

        Returns:
            TaskStatus: The session's status, or ``TaskStatus.FAILED`` when no
            session exists for ``job_id``.
        """
        session = self.tasks_sessions.get(job_id, None)
        return session.status if session is not None else TaskStatus.FAILED

    async def wait_for_completion(self, job_id: str) -> None:
        """Wait for a task to complete by awaiting its asyncio.Task.

        Args:
            job_id: The unique identifier of the job to wait for.

        Raises:
            KeyError: If the job_id is not found in tasks.
        """
        if job_id not in self._task_manager.tasks:
            msg = f"Job {job_id} not found"
            raise KeyError(msg)
        await self._task_manager.tasks[job_id]

    async def stop_all_modules(self) -> None:
        """Stop all tracked module jobs, then close the SurrealDB connection.

        Individual stop failures are collected by ``asyncio.gather`` and do not
        prevent the other jobs from being stopped or the connection from closing.
        """
        # Snapshot job IDs while holding lock
        async with self._lock:
            job_ids = list(self.tasks_sessions.keys())

        # Release lock before calling stop_module (which has its own lock)
        if job_ids:
            stop_tasks = [self.stop_module(job_id) for job_id in job_ids]
            await asyncio.gather(*stop_tasks, return_exceptions=True)

        # ``channel`` only exists once start() has been awaited.
        if hasattr(self, "channel"):
            try:
                await self.channel.close()
                logger.info("SingleJobManager: SurrealDB connection closed")
            except Exception as e:
                logger.warning("Failed to close SurrealDB connection: %s", e)

    async def list_modules(self) -> dict[str, dict[str, Any]]:
        """List all tracked modules along with their statuses.

        Returns:
            dict[str, dict[str, Any]]: Mapping of job ID to the module's name,
            status and class name.
        """
        return {
            job_id: {
                "name": session.module.name,
                "status": session.module.status,
                "class": session.module.__class__.__name__,
            }
            for job_id, session in self.tasks_sessions.items()
        }
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
"""Taskiq broker & RSTREAM producer for the job manager."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import datetime
|
|
4
5
|
import json
|
|
5
6
|
import logging
|
|
6
7
|
import os
|
|
7
8
|
import pickle # noqa: S403
|
|
9
|
+
from typing import Any
|
|
8
10
|
|
|
9
11
|
from rstream import Producer
|
|
10
12
|
from rstream.exceptions import PreconditionFailed
|
|
@@ -14,11 +16,14 @@ from taskiq.compat import model_validate
|
|
|
14
16
|
from taskiq.message import BrokerMessage
|
|
15
17
|
from taskiq_aio_pika import AioPikaBroker
|
|
16
18
|
|
|
19
|
+
from digitalkin.core.common import ConnectionFactory, ModuleFactory
|
|
20
|
+
from digitalkin.core.job_manager.base_job_manager import BaseJobManager
|
|
21
|
+
from digitalkin.core.task_manager.task_executor import TaskExecutor
|
|
22
|
+
from digitalkin.core.task_manager.task_session import TaskSession
|
|
17
23
|
from digitalkin.logger import logger
|
|
24
|
+
from digitalkin.models.core.job_manager_models import StreamCodeModel
|
|
18
25
|
from digitalkin.models.module.module_types import OutputModelT
|
|
19
26
|
from digitalkin.modules._base_module import BaseModule
|
|
20
|
-
from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
|
|
21
|
-
from digitalkin.modules.job_manager.job_manager_models import StreamCodeModel
|
|
22
27
|
from digitalkin.services.services_config import ServicesConfig
|
|
23
28
|
from digitalkin.services.services_models import ServicesMode
|
|
24
29
|
|
|
@@ -118,6 +123,24 @@ RSTREAM_PRODUCER = define_producer()
|
|
|
118
123
|
TASKIQ_BROKER = define_broker()
|
|
119
124
|
|
|
120
125
|
|
|
126
|
+
async def cleanup_global_resources() -> None:
    """Release the module-level RStream producer and Taskiq broker.

    Meant to be awaited during shutdown so connections are not leaked. The two
    resources are released one after the other and independently: if releasing
    one raises, a warning is logged and the next one is still attempted.
    """
    # (release action, success log line, failure log format) — evaluated lazily
    # so any error while releasing is handled inside the try below.
    shutdown_steps = (
        (
            lambda: RSTREAM_PRODUCER.close(),
            "RStream producer closed successfully",
            "Failed to close RStream producer: %s",
        ),
        (
            lambda: TASKIQ_BROKER.shutdown(),
            "Taskiq broker shut down successfully",
            "Failed to shutdown Taskiq broker: %s",
        ),
    )

    for release, done_message, failure_template in shutdown_steps:
        try:
            await release()
            logger.info(done_message)
        except Exception as err:
            logger.warning(failure_template, err)
|
|
142
|
+
|
|
143
|
+
|
|
121
144
|
async def send_message_to_stream(job_id: str, output_data: OutputModelT) -> None: # type: ignore
|
|
122
145
|
"""Callback define to add a message frame to the Rstream.
|
|
123
146
|
|
|
@@ -152,27 +175,70 @@ async def run_start_module(
|
|
|
152
175
|
setup_data: dict,
|
|
153
176
|
context: Allow TaskIQ context access
|
|
154
177
|
"""
|
|
155
|
-
logger.
|
|
178
|
+
logger.info("Starting module with services_mode: %s", services_mode)
|
|
156
179
|
services_config = ServicesConfig(
|
|
157
180
|
services_config_strategies=module_class.services_config_strategies,
|
|
158
181
|
services_config_params=module_class.services_config_params,
|
|
159
182
|
mode=services_mode,
|
|
160
183
|
)
|
|
161
184
|
setattr(module_class, "services_config", services_config)
|
|
162
|
-
logger.
|
|
185
|
+
logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)
|
|
186
|
+
module_class.discover()
|
|
163
187
|
|
|
164
188
|
job_id = context.message.task_id
|
|
165
189
|
callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
|
|
166
|
-
module = module_class
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
#
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
190
|
+
module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)
|
|
191
|
+
|
|
192
|
+
channel = None
|
|
193
|
+
try:
|
|
194
|
+
# Create TaskExecutor and supporting components for worker execution
|
|
195
|
+
executor = TaskExecutor()
|
|
196
|
+
# SurrealDB env vars are expected to be set in env.
|
|
197
|
+
channel = await ConnectionFactory.create_surreal_connection("taskiq_worker", datetime.timedelta(seconds=5))
|
|
198
|
+
session = TaskSession(job_id, mission_id, channel, module, datetime.timedelta(seconds=2))
|
|
199
|
+
|
|
200
|
+
# Execute the task using TaskExecutor
|
|
201
|
+
# Create a proper done callback that handles errors
|
|
202
|
+
async def send_end_of_stream(_: Any) -> None: # noqa: ANN401
|
|
203
|
+
try:
|
|
204
|
+
await callback(StreamCodeModel(code="__END_OF_STREAM__"))
|
|
205
|
+
except Exception as e:
|
|
206
|
+
logger.error("Error sending end of stream: %s", e, exc_info=True)
|
|
207
|
+
|
|
208
|
+
# Reconstruct Pydantic models from dicts for type safety
|
|
209
|
+
try:
|
|
210
|
+
input_model = module_class.create_input_model(input_data)
|
|
211
|
+
setup_model = await module_class.create_setup_model(setup_data)
|
|
212
|
+
except Exception as e:
|
|
213
|
+
logger.error("Failed to reconstruct models for job %s: %s", job_id, e, exc_info=True)
|
|
214
|
+
raise
|
|
215
|
+
|
|
216
|
+
supervisor_task = await executor.execute_task(
|
|
217
|
+
task_id=job_id,
|
|
218
|
+
mission_id=mission_id,
|
|
219
|
+
coro=module.start(
|
|
220
|
+
input_model,
|
|
221
|
+
setup_model,
|
|
222
|
+
callback,
|
|
223
|
+
done_callback=lambda result: asyncio.ensure_future(send_end_of_stream(result)),
|
|
224
|
+
),
|
|
225
|
+
session=session,
|
|
226
|
+
channel=channel,
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
# Wait for the supervisor task to complete
|
|
230
|
+
await supervisor_task
|
|
231
|
+
logger.info("Module task %s completed", job_id)
|
|
232
|
+
except Exception:
|
|
233
|
+
logger.exception("Error running module %s", job_id)
|
|
234
|
+
raise
|
|
235
|
+
finally:
|
|
236
|
+
# Cleanup channel
|
|
237
|
+
if channel is not None:
|
|
238
|
+
try:
|
|
239
|
+
await channel.close()
|
|
240
|
+
except Exception:
|
|
241
|
+
logger.exception("Error closing channel for job %s", job_id)
|
|
176
242
|
|
|
177
243
|
|
|
178
244
|
@TASKIQ_BROKER.task
|
|
@@ -194,23 +260,51 @@ async def run_config_module(
|
|
|
194
260
|
module_class: type[BaseModule],
|
|
195
261
|
services_mode: ServicesMode,
|
|
196
262
|
config_setup_data: dict,
|
|
197
|
-
setup_data: dict,
|
|
198
263
|
context: Allow TaskIQ context access
|
|
199
264
|
"""
|
|
200
|
-
logger.
|
|
265
|
+
logger.info("Starting config module with services_mode: %s", services_mode)
|
|
201
266
|
services_config = ServicesConfig(
|
|
202
267
|
services_config_strategies=module_class.services_config_strategies,
|
|
203
268
|
services_config_params=module_class.services_config_params,
|
|
204
269
|
mode=services_mode,
|
|
205
270
|
)
|
|
206
271
|
setattr(module_class, "services_config", services_config)
|
|
207
|
-
logger.
|
|
272
|
+
logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)
|
|
208
273
|
|
|
209
274
|
job_id = context.message.task_id
|
|
210
275
|
callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
|
|
211
|
-
module = module_class
|
|
276
|
+
module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)
|
|
212
277
|
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
278
|
+
# Override environment variables temporarily to use manager's SurrealDB
|
|
279
|
+
channel = None
|
|
280
|
+
try:
|
|
281
|
+
# Create TaskExecutor and supporting components for worker execution
|
|
282
|
+
executor = TaskExecutor()
|
|
283
|
+
# SurrealDB env vars are expected to be set in env.
|
|
284
|
+
channel = await ConnectionFactory.create_surreal_connection("taskiq_worker", datetime.timedelta(seconds=5))
|
|
285
|
+
session = TaskSession(job_id, mission_id, channel, module, datetime.timedelta(seconds=2))
|
|
286
|
+
|
|
287
|
+
# Create and run the config setup task with TaskExecutor
|
|
288
|
+
setup_model = module_class.create_config_setup_model(config_setup_data)
|
|
289
|
+
|
|
290
|
+
supervisor_task = await executor.execute_task(
|
|
291
|
+
task_id=job_id,
|
|
292
|
+
mission_id=mission_id,
|
|
293
|
+
coro=module.start_config_setup(setup_model, callback),
|
|
294
|
+
session=session,
|
|
295
|
+
channel=channel,
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
# Wait for the supervisor task to complete
|
|
299
|
+
await supervisor_task
|
|
300
|
+
logger.info("Config module task %s completed", job_id)
|
|
301
|
+
except Exception:
|
|
302
|
+
logger.exception("Error running config module %s", job_id)
|
|
303
|
+
raise
|
|
304
|
+
finally:
|
|
305
|
+
# Cleanup channel
|
|
306
|
+
if channel is not None:
|
|
307
|
+
try:
|
|
308
|
+
await channel.close()
|
|
309
|
+
except Exception:
|
|
310
|
+
logger.exception("Error closing channel for job %s", job_id)
|