digitalkin 0.3.0rc1__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/core/common/__init__.py +9 -0
- digitalkin/core/common/factories.py +156 -0
- digitalkin/core/job_manager/base_job_manager.py +128 -28
- digitalkin/core/job_manager/single_job_manager.py +80 -25
- digitalkin/core/job_manager/taskiq_broker.py +114 -19
- digitalkin/core/job_manager/taskiq_job_manager.py +291 -39
- digitalkin/core/task_manager/base_task_manager.py +539 -0
- digitalkin/core/task_manager/local_task_manager.py +108 -0
- digitalkin/core/task_manager/remote_task_manager.py +87 -0
- digitalkin/core/task_manager/surrealdb_repository.py +43 -4
- digitalkin/core/task_manager/task_executor.py +249 -0
- digitalkin/core/task_manager/task_session.py +107 -19
- digitalkin/grpc_servers/module_server.py +2 -2
- digitalkin/grpc_servers/module_servicer.py +21 -12
- digitalkin/grpc_servers/registry_server.py +1 -1
- digitalkin/grpc_servers/registry_servicer.py +4 -4
- digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
- digitalkin/models/core/task_monitor.py +17 -0
- digitalkin/models/grpc_servers/models.py +4 -4
- digitalkin/models/module/module_context.py +5 -0
- digitalkin/models/module/module_types.py +304 -16
- digitalkin/modules/_base_module.py +66 -28
- digitalkin/services/cost/grpc_cost.py +8 -41
- digitalkin/services/filesystem/grpc_filesystem.py +9 -38
- digitalkin/services/services_config.py +11 -0
- digitalkin/services/services_models.py +3 -1
- digitalkin/services/setup/default_setup.py +5 -6
- digitalkin/services/setup/grpc_setup.py +51 -14
- digitalkin/services/storage/grpc_storage.py +2 -2
- digitalkin/services/user_profile/__init__.py +12 -0
- digitalkin/services/user_profile/default_user_profile.py +55 -0
- digitalkin/services/user_profile/grpc_user_profile.py +69 -0
- digitalkin/services/user_profile/user_profile_strategy.py +40 -0
- digitalkin/utils/__init__.py +28 -0
- digitalkin/utils/dynamic_schema.py +483 -0
- {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/METADATA +9 -29
- {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/RECORD +42 -30
- modules/dynamic_setup_module.py +362 -0
- digitalkin/core/task_manager/task_manager.py +0 -439
- {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/WHEEL +0 -0
- {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.3.0rc1.dist-info → digitalkin-0.3.1.dist-info}/top_level.txt +0 -0
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
"""Taskiq broker & RSTREAM producer for the job manager."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import datetime
|
|
4
5
|
import json
|
|
5
6
|
import logging
|
|
6
7
|
import os
|
|
7
8
|
import pickle # noqa: S403
|
|
9
|
+
from typing import Any
|
|
8
10
|
|
|
9
11
|
from rstream import Producer
|
|
10
12
|
from rstream.exceptions import PreconditionFailed
|
|
@@ -14,7 +16,10 @@ from taskiq.compat import model_validate
|
|
|
14
16
|
from taskiq.message import BrokerMessage
|
|
15
17
|
from taskiq_aio_pika import AioPikaBroker
|
|
16
18
|
|
|
19
|
+
from digitalkin.core.common import ConnectionFactory, ModuleFactory
|
|
17
20
|
from digitalkin.core.job_manager.base_job_manager import BaseJobManager
|
|
21
|
+
from digitalkin.core.task_manager.task_executor import TaskExecutor
|
|
22
|
+
from digitalkin.core.task_manager.task_session import TaskSession
|
|
18
23
|
from digitalkin.logger import logger
|
|
19
24
|
from digitalkin.models.core.job_manager_models import StreamCodeModel
|
|
20
25
|
from digitalkin.models.module.module_types import OutputModelT
|
|
@@ -118,6 +123,24 @@ RSTREAM_PRODUCER = define_producer()
|
|
|
118
123
|
TASKIQ_BROKER = define_broker()
|
|
119
124
|
|
|
120
125
|
|
|
126
|
+
async def cleanup_global_resources() -> None:
    """Release the module-level RStream producer and Taskiq broker.

    The two resources are released one after the other and independently:
    an exception raised while releasing one of them is logged as a warning
    and does not stop the other from being released. Meant to be awaited
    during shutdown so that no connection is left open.
    """
    # (release call, message logged on success, message logged on failure)
    release_steps = (
        (
            lambda: RSTREAM_PRODUCER.close(),
            "RStream producer closed successfully",
            "Failed to close RStream producer: %s",
        ),
        (
            lambda: TASKIQ_BROKER.shutdown(),
            "Taskiq broker shut down successfully",
            "Failed to shutdown Taskiq broker: %s",
        ),
    )

    for release, success_message, failure_message in release_steps:
        try:
            await release()
            logger.info(success_message)
        except Exception as exc:
            # Best effort: shutdown must go on even if one resource fails to close.
            logger.warning(failure_message, exc)
|
|
142
|
+
|
|
143
|
+
|
|
121
144
|
async def send_message_to_stream(job_id: str, output_data: OutputModelT) -> None: # type: ignore
|
|
122
145
|
"""Callback define to add a message frame to the Rstream.
|
|
123
146
|
|
|
@@ -152,27 +175,70 @@ async def run_start_module(
|
|
|
152
175
|
setup_data: dict,
|
|
153
176
|
context: Allow TaskIQ context access
|
|
154
177
|
"""
|
|
155
|
-
logger.
|
|
178
|
+
logger.info("Starting module with services_mode: %s", services_mode)
|
|
156
179
|
services_config = ServicesConfig(
|
|
157
180
|
services_config_strategies=module_class.services_config_strategies,
|
|
158
181
|
services_config_params=module_class.services_config_params,
|
|
159
182
|
mode=services_mode,
|
|
160
183
|
)
|
|
161
184
|
setattr(module_class, "services_config", services_config)
|
|
162
|
-
logger.
|
|
185
|
+
logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)
|
|
186
|
+
module_class.discover()
|
|
163
187
|
|
|
164
188
|
job_id = context.message.task_id
|
|
165
189
|
callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
|
|
166
|
-
module = module_class
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
#
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
190
|
+
module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)
|
|
191
|
+
|
|
192
|
+
channel = None
|
|
193
|
+
try:
|
|
194
|
+
# Create TaskExecutor and supporting components for worker execution
|
|
195
|
+
executor = TaskExecutor()
|
|
196
|
+
# SurrealDB env vars are expected to be set in env.
|
|
197
|
+
channel = await ConnectionFactory.create_surreal_connection("taskiq_worker", datetime.timedelta(seconds=5))
|
|
198
|
+
session = TaskSession(job_id, mission_id, channel, module, datetime.timedelta(seconds=2))
|
|
199
|
+
|
|
200
|
+
# Execute the task using TaskExecutor
|
|
201
|
+
# Create a proper done callback that handles errors
|
|
202
|
+
async def send_end_of_stream(_: Any) -> None: # noqa: ANN401
|
|
203
|
+
try:
|
|
204
|
+
await callback(StreamCodeModel(code="__END_OF_STREAM__"))
|
|
205
|
+
except Exception as e:
|
|
206
|
+
logger.error("Error sending end of stream: %s", e, exc_info=True)
|
|
207
|
+
|
|
208
|
+
# Reconstruct Pydantic models from dicts for type safety
|
|
209
|
+
try:
|
|
210
|
+
input_model = module_class.create_input_model(input_data)
|
|
211
|
+
setup_model = await module_class.create_setup_model(setup_data)
|
|
212
|
+
except Exception as e:
|
|
213
|
+
logger.error("Failed to reconstruct models for job %s: %s", job_id, e, exc_info=True)
|
|
214
|
+
raise
|
|
215
|
+
|
|
216
|
+
supervisor_task = await executor.execute_task(
|
|
217
|
+
task_id=job_id,
|
|
218
|
+
mission_id=mission_id,
|
|
219
|
+
coro=module.start(
|
|
220
|
+
input_model,
|
|
221
|
+
setup_model,
|
|
222
|
+
callback,
|
|
223
|
+
done_callback=lambda result: asyncio.ensure_future(send_end_of_stream(result)),
|
|
224
|
+
),
|
|
225
|
+
session=session,
|
|
226
|
+
channel=channel,
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
# Wait for the supervisor task to complete
|
|
230
|
+
await supervisor_task
|
|
231
|
+
logger.info("Module task %s completed", job_id)
|
|
232
|
+
except Exception:
|
|
233
|
+
logger.exception("Error running module %s", job_id)
|
|
234
|
+
raise
|
|
235
|
+
finally:
|
|
236
|
+
# Cleanup channel
|
|
237
|
+
if channel is not None:
|
|
238
|
+
try:
|
|
239
|
+
await channel.close()
|
|
240
|
+
except Exception:
|
|
241
|
+
logger.exception("Error closing channel for job %s", job_id)
|
|
176
242
|
|
|
177
243
|
|
|
178
244
|
@TASKIQ_BROKER.task
|
|
@@ -196,20 +262,49 @@ async def run_config_module(
|
|
|
196
262
|
config_setup_data: dict,
|
|
197
263
|
context: Allow TaskIQ context access
|
|
198
264
|
"""
|
|
199
|
-
logger.
|
|
265
|
+
logger.info("Starting config module with services_mode: %s", services_mode)
|
|
200
266
|
services_config = ServicesConfig(
|
|
201
267
|
services_config_strategies=module_class.services_config_strategies,
|
|
202
268
|
services_config_params=module_class.services_config_params,
|
|
203
269
|
mode=services_mode,
|
|
204
270
|
)
|
|
205
271
|
setattr(module_class, "services_config", services_config)
|
|
206
|
-
logger.
|
|
272
|
+
logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)
|
|
207
273
|
|
|
208
274
|
job_id = context.message.task_id
|
|
209
275
|
callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
|
|
210
|
-
module = module_class
|
|
276
|
+
module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)
|
|
211
277
|
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
278
|
+
# Override environment variables temporarily to use manager's SurrealDB
|
|
279
|
+
channel = None
|
|
280
|
+
try:
|
|
281
|
+
# Create TaskExecutor and supporting components for worker execution
|
|
282
|
+
executor = TaskExecutor()
|
|
283
|
+
# SurrealDB env vars are expected to be set in env.
|
|
284
|
+
channel = await ConnectionFactory.create_surreal_connection("taskiq_worker", datetime.timedelta(seconds=5))
|
|
285
|
+
session = TaskSession(job_id, mission_id, channel, module, datetime.timedelta(seconds=2))
|
|
286
|
+
|
|
287
|
+
# Create and run the config setup task with TaskExecutor
|
|
288
|
+
setup_model = module_class.create_config_setup_model(config_setup_data)
|
|
289
|
+
|
|
290
|
+
supervisor_task = await executor.execute_task(
|
|
291
|
+
task_id=job_id,
|
|
292
|
+
mission_id=mission_id,
|
|
293
|
+
coro=module.start_config_setup(setup_model, callback),
|
|
294
|
+
session=session,
|
|
295
|
+
channel=channel,
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
# Wait for the supervisor task to complete
|
|
299
|
+
await supervisor_task
|
|
300
|
+
logger.info("Config module task %s completed", job_id)
|
|
301
|
+
except Exception:
|
|
302
|
+
logger.exception("Error running config module %s", job_id)
|
|
303
|
+
raise
|
|
304
|
+
finally:
|
|
305
|
+
# Cleanup channel
|
|
306
|
+
if channel is not None:
|
|
307
|
+
try:
|
|
308
|
+
await channel.close()
|
|
309
|
+
except Exception:
|
|
310
|
+
logger.exception("Error closing channel for job %s", job_id)
|
|
@@ -9,19 +9,22 @@ except ImportError:
|
|
|
9
9
|
|
|
10
10
|
import asyncio
|
|
11
11
|
import contextlib
|
|
12
|
+
import datetime
|
|
12
13
|
import json
|
|
13
14
|
import os
|
|
14
15
|
from collections.abc import AsyncGenerator, AsyncIterator
|
|
15
16
|
from contextlib import asynccontextmanager
|
|
16
|
-
from typing import TYPE_CHECKING, Any
|
|
17
|
+
from typing import TYPE_CHECKING, Any
|
|
17
18
|
|
|
18
19
|
from rstream import Consumer, ConsumerOffsetSpecification, MessageContext, OffsetType
|
|
19
20
|
|
|
21
|
+
from digitalkin.core.common import ConnectionFactory, QueueFactory
|
|
20
22
|
from digitalkin.core.job_manager.base_job_manager import BaseJobManager
|
|
21
|
-
from digitalkin.core.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER
|
|
23
|
+
from digitalkin.core.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER, cleanup_global_resources
|
|
24
|
+
from digitalkin.core.task_manager.remote_task_manager import RemoteTaskManager
|
|
22
25
|
from digitalkin.logger import logger
|
|
23
26
|
from digitalkin.models.core.task_monitor import TaskStatus
|
|
24
|
-
from digitalkin.models.module import InputModelT, SetupModelT
|
|
27
|
+
from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
|
|
25
28
|
from digitalkin.modules._base_module import BaseModule
|
|
26
29
|
from digitalkin.services.services_models import ServicesMode
|
|
27
30
|
|
|
@@ -29,7 +32,7 @@ if TYPE_CHECKING:
|
|
|
29
32
|
from taskiq.task import AsyncTaskiqTask
|
|
30
33
|
|
|
31
34
|
|
|
32
|
-
class TaskiqJobManager(BaseJobManager
|
|
35
|
+
class TaskiqJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
|
|
33
36
|
"""Taskiq job manager for running modules in Taskiq tasks."""
|
|
34
37
|
|
|
35
38
|
services_mode: ServicesMode
|
|
@@ -62,6 +65,13 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
62
65
|
if queue:
|
|
63
66
|
await queue.put(data.get("output_data"))
|
|
64
67
|
|
|
68
|
+
async def start(self) -> None:
    """Bring up the job manager, then open its SurrealDB connection.

    ``_start()`` is awaited first; afterwards the connection returned by
    ``ConnectionFactory.create_surreal_connection`` is stored on
    ``self.channel``.
    """
    await self._start()

    connect_timeout = datetime.timedelta(seconds=5)
    self.channel = await ConnectionFactory.create_surreal_connection(
        database="taskiq_job_manager",
        timeout=connect_timeout,
    )
|
|
74
|
+
|
|
65
75
|
async def _start(self) -> None:
|
|
66
76
|
await TASKIQ_BROKER.startup()
|
|
67
77
|
|
|
@@ -82,12 +92,34 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
82
92
|
callback=self._on_message, # type: ignore
|
|
83
93
|
offset_specification=start_spec,
|
|
84
94
|
)
|
|
95
|
+
|
|
96
|
+
# Wrap the consumer task with error handling
|
|
97
|
+
async def run_consumer_with_error_handling() -> None:
|
|
98
|
+
try:
|
|
99
|
+
await self.stream_consumer.run()
|
|
100
|
+
except asyncio.CancelledError:
|
|
101
|
+
logger.debug("Stream consumer task cancelled")
|
|
102
|
+
raise
|
|
103
|
+
except Exception as e:
|
|
104
|
+
logger.error("Stream consumer task failed: %s", e, exc_info=True, extra={"error": str(e)})
|
|
105
|
+
# Re-raise to ensure the error is not silently ignored
|
|
106
|
+
raise
|
|
107
|
+
|
|
85
108
|
self.stream_consumer_task = asyncio.create_task(
|
|
86
|
-
|
|
109
|
+
run_consumer_with_error_handling(),
|
|
87
110
|
name="stream_consumer_task",
|
|
88
111
|
)
|
|
89
112
|
|
|
90
113
|
async def _stop(self) -> None:
|
|
114
|
+
"""Stop the TaskiqJobManager and clean up all resources."""
|
|
115
|
+
# Close SurrealDB connection
|
|
116
|
+
if hasattr(self, "channel"):
|
|
117
|
+
try:
|
|
118
|
+
await self.channel.close()
|
|
119
|
+
logger.info("TaskiqJobManager: SurrealDB connection closed")
|
|
120
|
+
except Exception as e:
|
|
121
|
+
logger.warning("Failed to close SurrealDB connection: %s", e)
|
|
122
|
+
|
|
91
123
|
# Signal the consumer to stop
|
|
92
124
|
await self.stream_consumer.close()
|
|
93
125
|
# Cancel the background task
|
|
@@ -95,18 +127,40 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
95
127
|
with contextlib.suppress(asyncio.CancelledError):
|
|
96
128
|
await self.stream_consumer_task
|
|
97
129
|
|
|
130
|
+
# Clean up job queues
|
|
131
|
+
self.job_queues.clear()
|
|
132
|
+
logger.info("TaskiqJobManager: Cleared %d job queues", len(self.job_queues))
|
|
133
|
+
|
|
134
|
+
# Call global cleanup for producer and broker
|
|
135
|
+
await cleanup_global_resources()
|
|
136
|
+
|
|
98
137
|
def __init__(
    self,
    module_class: type[BaseModule],
    services_mode: ServicesMode,
    default_timeout: float = 10.0,
    max_concurrent_tasks: int = 100,
    stream_timeout: float = 30.0,
) -> None:
    """Initialize the Taskiq job manager.

    Args:
        module_class: The class of the module to be managed.
        services_mode: The mode of operation for the services.
        default_timeout: Default timeout passed to the ``RemoteTaskManager``.
        max_concurrent_tasks: Maximum number of concurrent tasks, passed to
            the ``RemoteTaskManager``.
        stream_timeout: Seconds a stream consumer waits for the next queued
            item before it re-checks the job state (default: 30.0).
    """
    # Create remote task manager for distributed execution
    task_manager = RemoteTaskManager(default_timeout, max_concurrent_tasks)

    # Initialize base job manager with task manager
    super().__init__(module_class, services_mode, task_manager)

    logger.warning("TaskiqJobManager initialized with app: %s", TASKIQ_BROKER)
    # One queue per job id; each queue is bounded by max_queue_size when created.
    self.job_queues: dict[str, asyncio.Queue] = {}
    self.max_queue_size = 1000
    self.stream_timeout = stream_timeout
|
|
110
164
|
|
|
111
165
|
async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT:
|
|
112
166
|
"""Generate a stream consumer for a module's output data.
|
|
@@ -120,12 +174,20 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
120
174
|
|
|
121
175
|
Returns:
|
|
122
176
|
SetupModelT: the SetupModelT object fully processed.
|
|
177
|
+
|
|
178
|
+
Raises:
|
|
179
|
+
asyncio.TimeoutError: If waiting for the setup response times out.
|
|
123
180
|
"""
|
|
124
|
-
queue
|
|
181
|
+
queue = QueueFactory.create_bounded_queue(maxsize=self.max_queue_size)
|
|
125
182
|
self.job_queues[job_id] = queue
|
|
126
183
|
|
|
127
184
|
try:
|
|
128
|
-
|
|
185
|
+
# Add timeout to prevent indefinite blocking
|
|
186
|
+
item = await asyncio.wait_for(queue.get(), timeout=30.0)
|
|
187
|
+
except asyncio.TimeoutError:
|
|
188
|
+
logger.error("Timeout waiting for config setup response for job %s", job_id)
|
|
189
|
+
raise
|
|
190
|
+
else:
|
|
129
191
|
queue.task_done()
|
|
130
192
|
return item
|
|
131
193
|
finally:
|
|
@@ -157,7 +219,7 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
157
219
|
TypeError: If the function is called with bad data type.
|
|
158
220
|
ValueError: If the module fails to start.
|
|
159
221
|
"""
|
|
160
|
-
task = TASKIQ_BROKER.find_task("digitalkin.core.taskiq_broker:run_config_module")
|
|
222
|
+
task = TASKIQ_BROKER.find_task("digitalkin.core.job_manager.taskiq_broker:run_config_module")
|
|
161
223
|
|
|
162
224
|
if task is None:
|
|
163
225
|
msg = "Task not found"
|
|
@@ -167,6 +229,7 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
167
229
|
msg = "config_setup_data must be a valid model with model_dump method"
|
|
168
230
|
raise TypeError(msg)
|
|
169
231
|
|
|
232
|
+
# Submit task to Taskiq
|
|
170
233
|
running_task: AsyncTaskiqTask[Any] = await task.kiq(
|
|
171
234
|
mission_id,
|
|
172
235
|
setup_id,
|
|
@@ -177,6 +240,27 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
177
240
|
)
|
|
178
241
|
|
|
179
242
|
job_id = running_task.task_id
|
|
243
|
+
|
|
244
|
+
# Create module instance for metadata
|
|
245
|
+
module = self.module_class(
|
|
246
|
+
job_id,
|
|
247
|
+
mission_id=mission_id,
|
|
248
|
+
setup_id=setup_id,
|
|
249
|
+
setup_version_id=setup_version_id,
|
|
250
|
+
)
|
|
251
|
+
|
|
252
|
+
# Register task in TaskManager (remote mode)
|
|
253
|
+
async def _dummy_coro() -> None:
|
|
254
|
+
"""Dummy coroutine - actual execution happens in worker."""
|
|
255
|
+
|
|
256
|
+
await self.create_task(
|
|
257
|
+
job_id,
|
|
258
|
+
mission_id,
|
|
259
|
+
module,
|
|
260
|
+
_dummy_coro(),
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
logger.info("Registered config task: %s, waiting for initial result", job_id)
|
|
180
264
|
result = await running_task.wait_result(timeout=10)
|
|
181
265
|
logger.info("Job %s with data %s", job_id, result)
|
|
182
266
|
return job_id
|
|
@@ -191,28 +275,75 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
191
275
|
Yields:
|
|
192
276
|
messages: The stream messages from the associated module.
|
|
193
277
|
"""
|
|
194
|
-
queue
|
|
278
|
+
queue = QueueFactory.create_bounded_queue(maxsize=self.max_queue_size)
|
|
195
279
|
self.job_queues[job_id] = queue
|
|
196
280
|
|
|
197
281
|
async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
|
|
198
|
-
"""Generate the stream
|
|
282
|
+
"""Generate the stream with batch-drain optimization.
|
|
283
|
+
|
|
284
|
+
This implementation uses a micro-batching pattern optimized for distributed
|
|
285
|
+
message streams from RabbitMQ:
|
|
286
|
+
1. Block waiting for the first item (with timeout for termination checks)
|
|
287
|
+
2. Drain all immediately available items without blocking (micro-batch)
|
|
288
|
+
3. Yield control back to event loop
|
|
289
|
+
|
|
290
|
+
This pattern provides:
|
|
291
|
+
- Better throughput for bursty message streams
|
|
292
|
+
- Reduced gRPC streaming overhead
|
|
293
|
+
- Lower latency when multiple messages arrive simultaneously
|
|
199
294
|
|
|
200
295
|
Yields:
|
|
201
296
|
dict: generated object from the module
|
|
202
297
|
"""
|
|
203
298
|
while True:
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
while True:
|
|
209
|
-
try:
|
|
210
|
-
item = queue.get_nowait()
|
|
211
|
-
except asyncio.QueueEmpty:
|
|
212
|
-
break
|
|
299
|
+
try:
|
|
300
|
+
# Block for first item with timeout to allow termination checks
|
|
301
|
+
item = await asyncio.wait_for(queue.get(), timeout=self.stream_timeout)
|
|
213
302
|
queue.task_done()
|
|
214
303
|
yield item
|
|
215
304
|
|
|
305
|
+
# Drain all immediately available items (micro-batch optimization)
|
|
306
|
+
# This reduces latency when messages arrive in bursts from RabbitMQ
|
|
307
|
+
batch_count = 0
|
|
308
|
+
max_batch_size = 100 # Safety limit to prevent memory spikes
|
|
309
|
+
while batch_count < max_batch_size:
|
|
310
|
+
try:
|
|
311
|
+
item = queue.get_nowait()
|
|
312
|
+
queue.task_done()
|
|
313
|
+
yield item
|
|
314
|
+
batch_count += 1
|
|
315
|
+
except asyncio.QueueEmpty: # noqa: PERF203
|
|
316
|
+
# No more items immediately available, break to next blocking wait
|
|
317
|
+
break
|
|
318
|
+
|
|
319
|
+
except asyncio.TimeoutError:
|
|
320
|
+
logger.warning("Stream consumer timeout for job %s, checking if job is still active", job_id)
|
|
321
|
+
|
|
322
|
+
# Check if job is registered
|
|
323
|
+
if job_id not in self.tasks_sessions:
|
|
324
|
+
logger.info("Job %s no longer registered, ending stream", job_id)
|
|
325
|
+
break
|
|
326
|
+
|
|
327
|
+
# Check job status to detect cancelled/failed jobs
|
|
328
|
+
status = await self.get_module_status(job_id)
|
|
329
|
+
|
|
330
|
+
if status in {TaskStatus.CANCELLED, TaskStatus.FAILED}:
|
|
331
|
+
logger.info("Job %s has terminal status %s, draining queue and ending stream", job_id, status)
|
|
332
|
+
|
|
333
|
+
# Drain remaining queue items before stopping
|
|
334
|
+
while not queue.empty():
|
|
335
|
+
try:
|
|
336
|
+
item = queue.get_nowait()
|
|
337
|
+
queue.task_done()
|
|
338
|
+
yield item
|
|
339
|
+
except asyncio.QueueEmpty: # noqa: PERF203
|
|
340
|
+
break
|
|
341
|
+
|
|
342
|
+
break
|
|
343
|
+
|
|
344
|
+
# Continue waiting for active/completed jobs
|
|
345
|
+
continue
|
|
346
|
+
|
|
216
347
|
try:
|
|
217
348
|
yield _stream()
|
|
218
349
|
finally:
|
|
@@ -241,12 +372,13 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
241
372
|
Raises:
|
|
242
373
|
ValueError: If the task is not found.
|
|
243
374
|
"""
|
|
244
|
-
task = TASKIQ_BROKER.find_task("digitalkin.core.taskiq_broker:run_start_module")
|
|
375
|
+
task = TASKIQ_BROKER.find_task("digitalkin.core.job_manager.taskiq_broker:run_start_module")
|
|
245
376
|
|
|
246
377
|
if task is None:
|
|
247
378
|
msg = "Task not found"
|
|
248
379
|
raise ValueError(msg)
|
|
249
380
|
|
|
381
|
+
# Submit task to Taskiq
|
|
250
382
|
running_task: AsyncTaskiqTask[Any] = await task.kiq(
|
|
251
383
|
mission_id,
|
|
252
384
|
setup_id,
|
|
@@ -257,33 +389,153 @@ class TaskiqJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
257
389
|
setup_data.model_dump(),
|
|
258
390
|
)
|
|
259
391
|
job_id = running_task.task_id
|
|
392
|
+
|
|
393
|
+
# Create module instance for metadata
|
|
394
|
+
module = self.module_class(
|
|
395
|
+
job_id,
|
|
396
|
+
mission_id=mission_id,
|
|
397
|
+
setup_id=setup_id,
|
|
398
|
+
setup_version_id=setup_version_id,
|
|
399
|
+
)
|
|
400
|
+
|
|
401
|
+
# Register task in TaskManager (remote mode)
|
|
402
|
+
# Dummy coroutine will be closed by TaskManager since execution_mode="remote"
|
|
403
|
+
async def _dummy_coro() -> None:
|
|
404
|
+
"""Dummy coroutine - actual execution happens in worker."""
|
|
405
|
+
|
|
406
|
+
await self.create_task(
|
|
407
|
+
job_id,
|
|
408
|
+
mission_id,
|
|
409
|
+
module,
|
|
410
|
+
_dummy_coro(), # Will be closed immediately by TaskManager in remote mode
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
logger.info("Registered remote task: %s, waiting for initial result", job_id)
|
|
260
414
|
result = await running_task.wait_result(timeout=10)
|
|
261
415
|
logger.debug("Job %s with data %s", job_id, result)
|
|
262
416
|
return job_id
|
|
263
417
|
|
|
418
|
+
async def get_module_status(self, job_id: str) -> TaskStatus:
    """Look up the status of a job through the SurrealDB channel.

    The ``tasks`` table is consulted first. When it holds no usable
    ``status`` for the job, a record in the ``heartbeats`` table is taken as
    proof that the job is running.

    Args:
        job_id: The unique identifier of the job.

    Returns:
        TaskStatus: The stored status when the ``tasks`` record has one,
        ``TaskStatus.RUNNING`` when only a heartbeat record exists, and
        ``TaskStatus.FAILED`` in every other case (unknown job, channel not
        initialized, no record at all, or an error during the lookup).
    """
    if job_id not in self.tasks_sessions:
        logger.warning("Job %s not found in registry", job_id)
        return TaskStatus.FAILED

    # The channel is only set by start(); without it nothing can be queried.
    if getattr(self, "channel", None) is None:
        logger.warning("Job %s status check failed - channel not initialized", job_id)
        return TaskStatus.FAILED

    try:
        tasks_row = await self.channel.select_by_task_id("tasks", job_id)
        if tasks_row and "status" in tasks_row:
            raw_status = tasks_row["status"]
            # Stored strings are converted to the enum; other values pass through as-is.
            if isinstance(raw_status, str):
                return TaskStatus(raw_status)
            return raw_status

        heartbeat_row = await self.channel.select_by_task_id("heartbeats", job_id)
        if heartbeat_row:
            return TaskStatus.RUNNING

        logger.debug("No task or heartbeat record found for job %s - task may still be initializing", job_id)
    except Exception:
        logger.exception("Error getting status for job %s", job_id)

    # NOTE(review): a job that "may still be initializing" is reported as FAILED,
    # which callers treat as a terminal state - confirm this is intended.
    return TaskStatus.FAILED
|
|
453
|
+
|
|
454
|
+
async def wait_for_completion(self, job_id: str, *, poll_interval: float = 0.5) -> None:
    """Wait for a task to complete by polling its status.

    The status is read with ``get_module_status`` until it is one of
    ``COMPLETED``, ``FAILED`` or ``CANCELLED``; between two reads the
    coroutine sleeps for ``poll_interval`` seconds.

    Args:
        job_id: The unique identifier of the job to wait for.
        poll_interval: Seconds to sleep between two status checks. The
            default of 0.5 balances responsiveness and resource usage.

    Raises:
        KeyError: If the job_id is not found in tasks_sessions.
    """
    if job_id not in self.tasks_sessions:
        msg = f"Job {job_id} not found"
        raise KeyError(msg)

    terminal_states = {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED}
    # get_module_status reports FAILED for a job that is no longer registered,
    # so the loop also ends if the job disappears while waiting.
    while (status := await self.get_module_status(job_id)) not in terminal_states:
        await asyncio.sleep(poll_interval)

    logger.debug("Job %s reached terminal state: %s", job_id, status)
|
|
478
|
+
|
|
264
479
|
async def stop_module(self, job_id: str) -> bool:
    """Request cancellation of a tracked job and drop its output queue.

    Args:
        job_id: The Taskiq task id to stop.

    Returns:
        bool: True if the cancel request went through ``cancel_task`` without
        raising, False if the job is not tracked or an error occurred.
    """
    if job_id not in self.tasks_sessions:
        logger.warning("Job %s not found in registry", job_id)
        return False

    stopped = False
    try:
        mission_id = self.tasks_sessions[job_id].mission_id
        # cancel_task is the TaskManager entry point that sends the signal.
        await self.cancel_task(job_id, mission_id)
        logger.info("Cancel signal sent for job %s via TaskManager", job_id)

        # The queue is only discarded once the cancellation call has succeeded.
        self.job_queues.pop(job_id, None)
        logger.debug("Cleaned up queue for job %s", job_id)
        stopped = True
    except Exception:
        logger.exception("Error stopping job %s", job_id)
    return stopped
|
|
505
|
+
|
|
506
|
+
async def stop_all_modules(self) -> None:
    """Stop every job currently tracked in ``tasks_sessions``.

    The stop requests run concurrently; exceptions are collected alongside
    the boolean results instead of being raised. Nothing is logged when no
    job is tracked.
    """
    # Snapshot the ids first so the registry can change while jobs are stopped.
    tracked_job_ids = tuple(self.tasks_sessions)
    if not tracked_job_ids:
        return

    outcomes = await asyncio.gather(
        *map(self.stop_module, tracked_job_ids),
        return_exceptions=True,
    )
    logger.info("Stopped %d modules, results: %s", len(outcomes), outcomes)
|
|
285
512
|
|
|
286
513
|
async def list_modules(self) -> dict[str, dict[str, Any]]:
    """List all modules tracked in the registry with their statuses.

    For each tracked job the status comes from ``get_module_status`` and the
    mission id from the job's record in the ``tasks`` table. A job whose
    lookup raises is still listed, with status ``FAILED`` and an ``error``
    entry instead of a ``mission_id``.

    Returns:
        dict[str, dict[str, Any]]: Job id mapped to a dictionary with the
        keys ``name``, ``status``, ``class`` and either ``mission_id`` or
        ``error``.
    """
    module_name = self.module_class.__name__
    modules_info: dict[str, dict[str, Any]] = {}

    # Iterate over a snapshot of the job ids: the loop body awaits, so other
    # coroutines may add or remove sessions in the meantime, and a dict that
    # changes size while being iterated raises RuntimeError.
    for job_id in list(self.tasks_sessions):
        try:
            status = await self.get_module_status(job_id)
            task_record = await self.channel.select_by_task_id("tasks", job_id)

            modules_info[job_id] = {
                "name": module_name,
                "status": status,
                "class": module_name,
                "mission_id": task_record.get("mission_id") if task_record else "unknown",
            }
        except Exception:  # noqa: PERF203
            # One unreadable job must not hide the others from the listing.
            logger.exception("Error getting info for job %s", job_id)
            modules_info[job_id] = {
                "name": module_name,
                "status": TaskStatus.FAILED,
                "class": module_name,
                "error": "Failed to retrieve status",
            }

    return modules_info
|