digitalkin 0.3.2.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. base_server/__init__.py +1 -0
  2. base_server/mock/__init__.py +5 -0
  3. base_server/mock/mock_pb2.py +39 -0
  4. base_server/mock/mock_pb2_grpc.py +102 -0
  5. base_server/server_async_insecure.py +125 -0
  6. base_server/server_async_secure.py +143 -0
  7. base_server/server_sync_insecure.py +103 -0
  8. base_server/server_sync_secure.py +122 -0
  9. digitalkin/__init__.py +8 -0
  10. digitalkin/__version__.py +8 -0
  11. digitalkin/core/__init__.py +1 -0
  12. digitalkin/core/common/__init__.py +9 -0
  13. digitalkin/core/common/factories.py +156 -0
  14. digitalkin/core/job_manager/__init__.py +1 -0
  15. digitalkin/core/job_manager/base_job_manager.py +288 -0
  16. digitalkin/core/job_manager/single_job_manager.py +354 -0
  17. digitalkin/core/job_manager/taskiq_broker.py +311 -0
  18. digitalkin/core/job_manager/taskiq_job_manager.py +541 -0
  19. digitalkin/core/task_manager/__init__.py +1 -0
  20. digitalkin/core/task_manager/base_task_manager.py +539 -0
  21. digitalkin/core/task_manager/local_task_manager.py +108 -0
  22. digitalkin/core/task_manager/remote_task_manager.py +87 -0
  23. digitalkin/core/task_manager/surrealdb_repository.py +266 -0
  24. digitalkin/core/task_manager/task_executor.py +249 -0
  25. digitalkin/core/task_manager/task_session.py +406 -0
  26. digitalkin/grpc_servers/__init__.py +1 -0
  27. digitalkin/grpc_servers/_base_server.py +486 -0
  28. digitalkin/grpc_servers/module_server.py +208 -0
  29. digitalkin/grpc_servers/module_servicer.py +516 -0
  30. digitalkin/grpc_servers/utils/__init__.py +1 -0
  31. digitalkin/grpc_servers/utils/exceptions.py +29 -0
  32. digitalkin/grpc_servers/utils/grpc_client_wrapper.py +88 -0
  33. digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
  34. digitalkin/grpc_servers/utils/utility_schema_extender.py +97 -0
  35. digitalkin/logger.py +157 -0
  36. digitalkin/mixins/__init__.py +19 -0
  37. digitalkin/mixins/base_mixin.py +10 -0
  38. digitalkin/mixins/callback_mixin.py +24 -0
  39. digitalkin/mixins/chat_history_mixin.py +110 -0
  40. digitalkin/mixins/cost_mixin.py +76 -0
  41. digitalkin/mixins/file_history_mixin.py +93 -0
  42. digitalkin/mixins/filesystem_mixin.py +46 -0
  43. digitalkin/mixins/logger_mixin.py +51 -0
  44. digitalkin/mixins/storage_mixin.py +79 -0
  45. digitalkin/models/__init__.py +8 -0
  46. digitalkin/models/core/__init__.py +1 -0
  47. digitalkin/models/core/job_manager_models.py +36 -0
  48. digitalkin/models/core/task_monitor.py +70 -0
  49. digitalkin/models/grpc_servers/__init__.py +1 -0
  50. digitalkin/models/grpc_servers/models.py +275 -0
  51. digitalkin/models/grpc_servers/types.py +24 -0
  52. digitalkin/models/module/__init__.py +25 -0
  53. digitalkin/models/module/module.py +40 -0
  54. digitalkin/models/module/module_context.py +149 -0
  55. digitalkin/models/module/module_types.py +393 -0
  56. digitalkin/models/module/utility.py +146 -0
  57. digitalkin/models/services/__init__.py +10 -0
  58. digitalkin/models/services/cost.py +54 -0
  59. digitalkin/models/services/registry.py +42 -0
  60. digitalkin/models/services/storage.py +44 -0
  61. digitalkin/modules/__init__.py +11 -0
  62. digitalkin/modules/_base_module.py +517 -0
  63. digitalkin/modules/archetype_module.py +23 -0
  64. digitalkin/modules/tool_module.py +23 -0
  65. digitalkin/modules/trigger_handler.py +48 -0
  66. digitalkin/modules/triggers/__init__.py +12 -0
  67. digitalkin/modules/triggers/healthcheck_ping_trigger.py +45 -0
  68. digitalkin/modules/triggers/healthcheck_services_trigger.py +63 -0
  69. digitalkin/modules/triggers/healthcheck_status_trigger.py +52 -0
  70. digitalkin/py.typed +0 -0
  71. digitalkin/services/__init__.py +30 -0
  72. digitalkin/services/agent/__init__.py +6 -0
  73. digitalkin/services/agent/agent_strategy.py +19 -0
  74. digitalkin/services/agent/default_agent.py +13 -0
  75. digitalkin/services/base_strategy.py +22 -0
  76. digitalkin/services/communication/__init__.py +7 -0
  77. digitalkin/services/communication/communication_strategy.py +76 -0
  78. digitalkin/services/communication/default_communication.py +101 -0
  79. digitalkin/services/communication/grpc_communication.py +223 -0
  80. digitalkin/services/cost/__init__.py +14 -0
  81. digitalkin/services/cost/cost_strategy.py +100 -0
  82. digitalkin/services/cost/default_cost.py +114 -0
  83. digitalkin/services/cost/grpc_cost.py +138 -0
  84. digitalkin/services/filesystem/__init__.py +7 -0
  85. digitalkin/services/filesystem/default_filesystem.py +417 -0
  86. digitalkin/services/filesystem/filesystem_strategy.py +252 -0
  87. digitalkin/services/filesystem/grpc_filesystem.py +317 -0
  88. digitalkin/services/identity/__init__.py +6 -0
  89. digitalkin/services/identity/default_identity.py +15 -0
  90. digitalkin/services/identity/identity_strategy.py +14 -0
  91. digitalkin/services/registry/__init__.py +27 -0
  92. digitalkin/services/registry/default_registry.py +141 -0
  93. digitalkin/services/registry/exceptions.py +47 -0
  94. digitalkin/services/registry/grpc_registry.py +306 -0
  95. digitalkin/services/registry/registry_models.py +43 -0
  96. digitalkin/services/registry/registry_strategy.py +98 -0
  97. digitalkin/services/services_config.py +200 -0
  98. digitalkin/services/services_models.py +65 -0
  99. digitalkin/services/setup/__init__.py +1 -0
  100. digitalkin/services/setup/default_setup.py +219 -0
  101. digitalkin/services/setup/grpc_setup.py +343 -0
  102. digitalkin/services/setup/setup_strategy.py +145 -0
  103. digitalkin/services/snapshot/__init__.py +6 -0
  104. digitalkin/services/snapshot/default_snapshot.py +39 -0
  105. digitalkin/services/snapshot/snapshot_strategy.py +30 -0
  106. digitalkin/services/storage/__init__.py +7 -0
  107. digitalkin/services/storage/default_storage.py +228 -0
  108. digitalkin/services/storage/grpc_storage.py +214 -0
  109. digitalkin/services/storage/storage_strategy.py +273 -0
  110. digitalkin/services/user_profile/__init__.py +12 -0
  111. digitalkin/services/user_profile/default_user_profile.py +55 -0
  112. digitalkin/services/user_profile/grpc_user_profile.py +69 -0
  113. digitalkin/services/user_profile/user_profile_strategy.py +40 -0
  114. digitalkin/utils/__init__.py +29 -0
  115. digitalkin/utils/arg_parser.py +92 -0
  116. digitalkin/utils/development_mode_action.py +51 -0
  117. digitalkin/utils/dynamic_schema.py +483 -0
  118. digitalkin/utils/llm_ready_schema.py +75 -0
  119. digitalkin/utils/package_discover.py +357 -0
  120. digitalkin-0.3.2.dev2.dist-info/METADATA +602 -0
  121. digitalkin-0.3.2.dev2.dist-info/RECORD +131 -0
  122. digitalkin-0.3.2.dev2.dist-info/WHEEL +5 -0
  123. digitalkin-0.3.2.dev2.dist-info/licenses/LICENSE +430 -0
  124. digitalkin-0.3.2.dev2.dist-info/top_level.txt +4 -0
  125. modules/__init__.py +0 -0
  126. modules/cpu_intensive_module.py +280 -0
  127. modules/dynamic_setup_module.py +338 -0
  128. modules/minimal_llm_module.py +347 -0
  129. modules/text_transform_module.py +203 -0
  130. services/filesystem_module.py +200 -0
  131. services/storage_module.py +206 -0
@@ -0,0 +1,354 @@
+ """Background module manager with single instance."""
+
+ import asyncio
+ import datetime
+ import uuid
+ from collections.abc import AsyncGenerator, AsyncIterator
+ from contextlib import asynccontextmanager
+ from typing import Any
+
+ import grpc
+
+ from digitalkin.core.common import ConnectionFactory, ModuleFactory
+ from digitalkin.core.job_manager.base_job_manager import BaseJobManager
+ from digitalkin.core.task_manager.local_task_manager import LocalTaskManager
+ from digitalkin.core.task_manager.task_session import TaskSession
+ from digitalkin.logger import logger
+ from digitalkin.models.core.task_monitor import TaskStatus
+ from digitalkin.models.module.module import ModuleCodeModel
+ from digitalkin.models.module.module_types import InputModelT, OutputModelT, SetupModelT
+ from digitalkin.modules._base_module import BaseModule
+ from digitalkin.services.services_models import ServicesMode
+
+
+ class SingleJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
+     """Manages a single instance of a module job.
+
+     This class ensures that only one instance of a module job is active at a time.
+     It provides functionality to create, stop, and monitor module jobs, as well as
+     to handle their output data.
+     """
+
+     async def start(self) -> None:
+         """Start manager."""
+         self.channel = await ConnectionFactory.create_surreal_connection("task_manager", datetime.timedelta(seconds=5))
+
+     def __init__(
+         self,
+         module_class: type[BaseModule],
+         services_mode: ServicesMode,
+         default_timeout: float = 10.0,
+         max_concurrent_tasks: int = 100,
+     ) -> None:
+         """Initialize the job manager.
+
+         Args:
+             module_class: The class of the module to be managed.
+             services_mode: The mode of operation for the services (e.g., ASYNC or SYNC).
+             default_timeout: Default timeout for task operations.
+             max_concurrent_tasks: Maximum number of concurrent tasks.
+         """
+         # Create local task manager for same-process execution
+         task_manager = LocalTaskManager(default_timeout, max_concurrent_tasks)
+
+         # Initialize base job manager with task manager
+         super().__init__(module_class, services_mode, task_manager)
+
+         self._lock = asyncio.Lock()
+
+     async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT | ModuleCodeModel:
+         """Wait for and return the config setup response of a module job.
+
+         This method waits for a single config setup result produced by a specific
+         module job. If the module does not exist or does not respond in time, it
+         returns an error model instead.
+
+         Args:
+             job_id: The unique identifier of the job.
+
+         Returns:
+             SetupModelT | ModuleCodeModel: The fully processed setup model, or an error code model.
+         """
+         if (session := self.tasks_sessions.get(job_id, None)) is None:
+             return ModuleCodeModel(
+                 code=str(grpc.StatusCode.NOT_FOUND),
+                 message=f"Module {job_id} not found",
+             )
+
+         logger.debug("Module %s found: %s", job_id, session.module)
+         try:
+             # Add a timeout to prevent indefinite blocking
+             return await asyncio.wait_for(session.queue.get(), timeout=30.0)
+         except asyncio.TimeoutError:
+             logger.error("Timeout waiting for config setup response from module %s", job_id)
+             return ModuleCodeModel(
+                 code=str(grpc.StatusCode.DEADLINE_EXCEEDED),
+                 message=f"Module {job_id} did not respond within 30 seconds",
+             )
+         finally:
+             logger.info(f"{job_id=}: {session.queue.empty()}")
+
+     async def create_config_setup_instance_job(
+         self,
+         config_setup_data: SetupModelT,
+         mission_id: str,
+         setup_id: str,
+         setup_version_id: str,
+     ) -> str:
+         """Create and start a new module setup configuration job.
+
+         This method initializes a new module job, assigns it a unique job ID,
+         and starts the config setup in the background.
+
+         Args:
+             config_setup_data: The input data required to start the job.
+             mission_id: The mission ID associated with the job.
+             setup_id: The setup ID associated with the module.
+             setup_version_id: The setup version ID associated with the module.
+
+         Returns:
+             str: The unique identifier (job ID) of the created job.
+
+         Raises:
+             Exception: If the module fails to start.
+         """
+         job_id = str(uuid.uuid4())
+         # TODO: Ensure the job_id is unique.
+         module = ModuleFactory.create_module_instance(self.module_class, job_id, mission_id, setup_id, setup_version_id)
+         self.tasks_sessions[job_id] = TaskSession(job_id, mission_id, self.channel, module)
+
+         try:
+             await module.start_config_setup(
+                 config_setup_data,
+                 await self.job_specific_callback(self.add_to_queue, job_id),
+             )
+             logger.debug("Module %s (%s) started successfully", job_id, module.name)
+         except Exception:
+             # Remove the module from the manager in case of an error.
+             del self.tasks_sessions[job_id]
+             logger.exception("Failed to start module %s", job_id)
+             raise
+         else:
+             return job_id
+
+     async def add_to_queue(self, job_id: str, output_data: OutputModelT | ModuleCodeModel) -> None:
+         """Add output data to the queue for a specific job.
+
+         This method is used as a callback to handle output data generated by a module job.
+
+         Args:
+             job_id: The unique identifier of the job.
+             output_data: The output data produced by the job.
+         """
+         await self.tasks_sessions[job_id].queue.put(output_data.model_dump())
+
+     @asynccontextmanager  # type: ignore
+     async def generate_stream_consumer(self, job_id: str) -> AsyncIterator[AsyncGenerator[dict[str, Any], None]]:  # type: ignore
+         """Generate a stream consumer for a module's output data.
+
+         This method creates an asynchronous generator that streams output data
+         from a specific module job. If the module does not exist, it generates
+         an error message.
+
+         Args:
+             job_id: The unique identifier of the job.
+
+         Yields:
+             AsyncGenerator: A stream of output data or error messages.
+         """
+         if (session := self.tasks_sessions.get(job_id, None)) is None:
+
+             async def _error_gen() -> AsyncGenerator[dict[str, Any], None]:  # noqa: RUF029
+                 """Generate an error message for a non-existent module.
+
+                 Yields:
+                     AsyncGenerator: A generator yielding an error message.
+                 """
+                 yield {
+                     "error": {
+                         "error_message": f"Module {job_id} not found",
+                         "code": grpc.StatusCode.NOT_FOUND,
+                     }
+                 }
+
+             yield _error_gen()
+             return
+
+         logger.debug("Session: %s with Module %s", job_id, session.module)
+
+         async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
+             """Stream output data from the module with a simple blocking pattern.
+
+             This implementation uses a simple one-item-at-a-time pattern optimized
+             for local execution where we have direct access to session status:
+             1. Block waiting for each item
+             2. Check termination conditions after each item
+             3. Clean shutdown when the task completes
+
+             This pattern provides:
+             - Immediate termination when the task completes
+             - Direct session status monitoring
+             - Simple, predictable behavior for local tasks
+
+             Yields:
+                 dict: Output data generated by the module.
+             """
+             while True:
+                 # Block for the next item; waits while the queue is empty but the producer has not finished yet
+                 msg = await session.queue.get()
+                 try:
+                     yield msg
+                 finally:
+                     # Always mark the item as done, even if the consumer raises an exception
+                     session.queue.task_done()
+
+                 # Check termination conditions after each message
+                 # This allows immediate shutdown when the task completes
+                 if (
+                     session.is_cancelled.is_set()
+                     or (session.status is TaskStatus.COMPLETED and session.queue.empty())
+                     or session.status is TaskStatus.FAILED
+                 ):
+                     logger.debug(
+                         "Stream ending for job %s: cancelled=%s, status=%s, queue_empty=%s",
+                         job_id,
+                         session.is_cancelled.is_set(),
+                         session.status,
+                         session.queue.empty(),
+                     )
+                     break
+
+         yield _stream()
+
223
+ async def create_module_instance_job(
224
+ self,
225
+ input_data: InputModelT,
226
+ setup_data: SetupModelT,
227
+ mission_id: str,
228
+ setup_id: str,
229
+ setup_version_id: str,
230
+ ) -> str:
231
+ """Create and start a new module job.
232
+
233
+ This method initializes a new module job, assigns it a unique job ID,
234
+ and starts it in the background.
235
+
236
+ Args:
237
+ input_data: The input data required to start the job.
238
+ setup_data: The setup configuration for the module.
239
+ mission_id: The mission ID associated with the job.
240
+ setup_id: The setup ID associated with the module.
241
+ setup_version_id: The setup Version ID associated with the module.
242
+
243
+ Returns:
244
+ str: The unique identifier (job ID) of the created job.
245
+
246
+ Raises:
247
+ Exception: If the module fails to start.
248
+ """
249
+ job_id = str(uuid.uuid4())
250
+ module = ModuleFactory.create_module_instance(self.module_class, job_id, mission_id, setup_id, setup_version_id)
251
+ callback = await self.job_specific_callback(self.add_to_queue, job_id)
252
+
253
+ await self.create_task(
254
+ job_id,
255
+ mission_id,
256
+ module,
257
+ module.start(input_data, setup_data, callback, done_callback=None),
258
+ )
259
+ logger.info("Managed task started: '%s'", job_id, extra={"task_id": job_id})
260
+ return job_id
261
+
+     async def stop_module(self, job_id: str) -> bool:
+         """Stop a running module job.
+
+         Args:
+             job_id: The unique identifier of the job to stop.
+
+         Returns:
+             bool: True if the module was successfully stopped, False if it does not exist.
+
+         Raises:
+             Exception: If an error occurs while stopping the module.
+         """
+         logger.info(f"STOP required for {job_id=}")
+
+         async with self._lock:
+             session = self.tasks_sessions.get(job_id)
+
+         if not session:
+             logger.warning(f"session with id: {job_id} not found")
+             return False
+         try:
+             await session.module.stop()
+             await self.cancel_task(job_id, session.mission_id)
+             logger.debug(f"session {job_id} ({session.module.name}) stopped successfully")
+         except Exception as e:
+             logger.error(f"Error while stopping module {job_id}: {e}")
+             raise
+         else:
+             return True
+
+     async def get_module_status(self, job_id: str) -> TaskStatus:
+         """Retrieve the status of a module job.
+
+         Args:
+             job_id: The unique identifier of the job.
+
+         Returns:
+             TaskStatus: The status of the module job, or TaskStatus.FAILED if the job is unknown.
+         """
+         session = self.tasks_sessions.get(job_id, None)
+         return session.status if session is not None else TaskStatus.FAILED
+
+     async def wait_for_completion(self, job_id: str) -> None:
+         """Wait for a task to complete by awaiting its asyncio.Task.
+
+         Args:
+             job_id: The unique identifier of the job to wait for.
+
+         Raises:
+             KeyError: If the job_id is not found in tasks.
+         """
+         if job_id not in self._task_manager.tasks:
+             msg = f"Job {job_id} not found"
+             raise KeyError(msg)
+         await self._task_manager.tasks[job_id]
+
+     async def stop_all_modules(self) -> None:
+         """Stop all currently running module jobs.
+
+         This method ensures that all active jobs are gracefully terminated
+         and closes the SurrealDB connection.
+         """
+         # Snapshot job IDs while holding the lock
+         async with self._lock:
+             job_ids = list(self.tasks_sessions.keys())
+
+         # Release the lock before calling stop_module (which takes its own lock)
+         if job_ids:
+             stop_tasks = [self.stop_module(job_id) for job_id in job_ids]
+             await asyncio.gather(*stop_tasks, return_exceptions=True)
+
+         # Close the SurrealDB connection after stopping all modules
+         if hasattr(self, "channel"):
+             try:
+                 await self.channel.close()
+                 logger.info("SingleJobManager: SurrealDB connection closed")
+             except Exception as e:
+                 logger.warning("Failed to close SurrealDB connection: %s", e)
+
+     async def list_modules(self) -> dict[str, dict[str, Any]]:
+         """List all modules along with their statuses.
+
+         Returns:
+             dict[str, dict[str, Any]]: A dictionary containing information about all modules and their statuses.
+         """
+         return {
+             job_id: {
+                 "name": session.module.name,
+                 "status": session.module.status,
+                 "class": session.module.__class__.__name__,
+             }
+             for job_id, session in self.tasks_sessions.items()
+         }
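
As a usage sketch (not part of the package), a SingleJobManager might be driven roughly like this; the EchoModule class, the input/setup model instances, and the chosen ServicesMode value are illustrative assumptions, and only methods defined in the diff above are used.

# Hedged usage sketch; EchoModule, input_data, setup_data and services_mode are placeholders.
from digitalkin.core.job_manager.single_job_manager import SingleJobManager
from digitalkin.services.services_models import ServicesMode


async def run_once(services_mode: ServicesMode) -> None:
    manager = SingleJobManager(EchoModule, services_mode)  # EchoModule: your BaseModule subclass
    await manager.start()  # opens the SurrealDB channel used by the task sessions
    job_id = await manager.create_module_instance_job(
        input_data,  # an InputModelT instance for EchoModule
        setup_data,  # a SetupModelT instance for EchoModule
        mission_id="mission-1",
        setup_id="setup-1",
        setup_version_id="setup-v1",
    )
    # generate_stream_consumer is an async context manager yielding an async generator
    async with manager.generate_stream_consumer(job_id) as stream:
        async for item in stream:
            print(item)  # each item is an output_data.model_dump() dict
    await manager.stop_all_modules()
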
@@ -0,0 +1,311 @@
+ """Taskiq broker & RSTREAM producer for the job manager."""
+
+ import asyncio
+ import datetime
+ import json
+ import logging
+ import os
+ import pickle  # noqa: S403
+ from typing import Any
+
+ from rstream import Producer
+ from rstream.exceptions import PreconditionFailed
+ from taskiq import Context, TaskiqDepends, TaskiqMessage
+ from taskiq.abc.formatter import TaskiqFormatter
+ from taskiq.compat import model_validate
+ from taskiq.message import BrokerMessage
+ from taskiq_aio_pika import AioPikaBroker
+
+ from digitalkin.core.common import ConnectionFactory, ModuleFactory
+ from digitalkin.core.job_manager.base_job_manager import BaseJobManager
+ from digitalkin.core.task_manager.task_executor import TaskExecutor
+ from digitalkin.core.task_manager.task_session import TaskSession
+ from digitalkin.logger import logger
+ from digitalkin.models.module.module import ModuleCodeModel
+ from digitalkin.models.module.module_types import DataModel, OutputModelT
+ from digitalkin.models.module.utility import EndOfStreamOutput
+ from digitalkin.modules._base_module import BaseModule
+ from digitalkin.services.services_config import ServicesConfig
+ from digitalkin.services.services_models import ServicesMode
+
+ logging.getLogger("taskiq").setLevel(logging.INFO)
+ logging.getLogger("aiormq").setLevel(logging.INFO)
+ logging.getLogger("aio_pika").setLevel(logging.INFO)
+ logging.getLogger("rstream").setLevel(logging.INFO)
+
+
+ class PickleFormatter(TaskiqFormatter):
+     """Formatter that pickles the whole TaskiqMessage.
+
+     This lets you send arbitrary Python objects (classes, functions, etc.)
+     by pickling the message instead of restricting it to JSON-safe primitives.
+     """
+
+     def dumps(self, message: TaskiqMessage) -> BrokerMessage:  # noqa: PLR6301
+         """Dump a TaskiqMessage into a pickled broker payload.
+
+         Args:
+             message: TaskIQ message.
+
+         Returns:
+             BrokerMessage with the mandatory information for TaskIQ.
+         """
+         payload: bytes = pickle.dumps(message)
+
+         return BrokerMessage(
+             task_id=message.task_id,
+             task_name=message.task_name,
+             message=payload,
+             labels=message.labels,
+         )
+
+     def loads(self, message: bytes) -> TaskiqMessage:  # noqa: PLR6301
+         """Recreate the TaskiqMessage from pickled bytes.
+
+         Args:
+             message: Broker message as bytes.
+
+         Returns:
+             The message in TaskIQ format.
+         """
+         raw_message = pickle.loads(message)  # noqa: S301
+         return model_validate(TaskiqMessage, raw_message)
+
+
+ def define_producer() -> Producer:
+     """Build the RStream producer from the RabbitMQ connection parameters in the environment.
+
+     Returns:
+         Producer
+     """
+     host: str = os.environ.get("RABBITMQ_RSTREAM_HOST", "localhost")
+     port: str = os.environ.get("RABBITMQ_RSTREAM_PORT", "5552")
+     username: str = os.environ.get("RABBITMQ_RSTREAM_USERNAME", "guest")
+     password: str = os.environ.get("RABBITMQ_RSTREAM_PASSWORD", "guest")
+
+     logger.info("Connection to RabbitMQ: %s:%s.", host, port)
+     return Producer(host=host, port=int(port), username=username, password=password)
+
+
+ async def init_rstream() -> None:
+     """Initialize the stream used by all tasks."""
+     try:
+         await RSTREAM_PRODUCER.create_stream(
+             STREAM,
+             exists_ok=True,
+             arguments={"max-length-bytes": STREAM_RETENTION},
+         )
+     except PreconditionFailed:
+         logger.warning("Stream already exists")
+
+
+ def define_broker() -> AioPikaBroker:
+     """Define the broker from environment parameters.
+
+     Returns:
+         Broker: connected to RabbitMQ and using the custom pickle formatter.
+     """
+     host: str = os.environ.get("RABBITMQ_BROKER_HOST", "localhost")
+     port: str = os.environ.get("RABBITMQ_BROKER_PORT", "5672")
+     username: str = os.environ.get("RABBITMQ_BROKER_USERNAME", "guest")
+     password: str = os.environ.get("RABBITMQ_BROKER_PASSWORD", "guest")
+
+     broker = AioPikaBroker(
+         f"amqp://{username}:{password}@{host}:{port}",
+         startup=[init_rstream],
+     )
+     broker.formatter = PickleFormatter()
+     return broker
+
+
+ STREAM = "taskiq_data"
+ STREAM_RETENTION = 200_000
+ RSTREAM_PRODUCER = define_producer()
+ TASKIQ_BROKER = define_broker()
+
+
+ async def cleanup_global_resources() -> None:
+     """Clean up global resources (producer and broker connections).
+
+     This should be called during shutdown to prevent connection leaks.
+     """
+     try:
+         await RSTREAM_PRODUCER.close()
+         logger.info("RStream producer closed successfully")
+     except Exception as e:
+         logger.warning("Failed to close RStream producer: %s", e)
+
+     try:
+         await TASKIQ_BROKER.shutdown()
+         logger.info("Taskiq broker shut down successfully")
+     except Exception as e:
+         logger.warning("Failed to shut down Taskiq broker: %s", e)
+
+
+ async def send_message_to_stream(job_id: str, output_data: OutputModelT | ModuleCodeModel) -> None:  # type: ignore[type-var]
+     """Callback used to append a message frame to the RStream.
+
+     Args:
+         job_id: ID of the job that sent the message.
+         output_data: Message body as an OutputModelT, or an error / stream code model.
+     """
+     body = json.dumps({"job_id": job_id, "output_data": output_data.model_dump()}).encode("utf-8")
+     await RSTREAM_PRODUCER.send(stream=STREAM, message=body)
+
+
+ @TASKIQ_BROKER.task
+ async def run_start_module(
+     mission_id: str,
+     setup_id: str,
+     setup_version_id: str,
+     module_class: type[BaseModule],
+     services_mode: ServicesMode,
+     input_data: dict,
+     setup_data: dict,
+     context: Context = TaskiqDepends(),
+ ) -> None:
+     """TaskIQ task allowing a module to compute in the background asynchronously.
+
+     Args:
+         mission_id: The mission ID associated with the job.
+         setup_id: The setup ID associated with the module.
+         setup_version_id: The setup version ID associated with the module.
+         module_class: The module class to instantiate and run.
+         services_mode: The mode of operation for the services.
+         input_data: The module input data as a plain dict.
+         setup_data: The module setup data as a plain dict.
+         context: Allows TaskIQ context access.
+     """
+     logger.info("Starting module with services_mode: %s", services_mode)
+     services_config = ServicesConfig(
+         services_config_strategies=module_class.services_config_strategies,
+         services_config_params=module_class.services_config_params,
+         mode=services_mode,
+     )
+     setattr(module_class, "services_config", services_config)
+     logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)
+     module_class.discover()
+
+     job_id = context.message.task_id
+     callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)  # type: ignore[type-var]
+     module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)
+
+     channel = None
+     try:
+         # Create TaskExecutor and supporting components for worker execution
+         executor = TaskExecutor()
+         # SurrealDB connection parameters are expected to be set in the environment.
+         channel = await ConnectionFactory.create_surreal_connection("taskiq_worker", datetime.timedelta(seconds=5))
+         session = TaskSession(job_id, mission_id, channel, module, datetime.timedelta(seconds=2))
+
+         # Execute the task using TaskExecutor
+         # Create a proper done callback that handles errors
+         async def send_end_of_stream(_: Any) -> None:  # noqa: ANN401
+             try:
+                 await callback(DataModel(root=EndOfStreamOutput()))
+             except Exception as e:
+                 logger.error("Error sending end of stream: %s", e, exc_info=True)
+
+         # Reconstruct Pydantic models from dicts for type safety
+         try:
+             input_model = module_class.create_input_model(input_data)
+             setup_model = await module_class.create_setup_model(setup_data)
+         except Exception as e:
+             logger.error("Failed to reconstruct models for job %s: %s", job_id, e, exc_info=True)
+             raise
+
+         supervisor_task = await executor.execute_task(
+             task_id=job_id,
+             mission_id=mission_id,
+             coro=module.start(
+                 input_model,
+                 setup_model,
+                 callback,
+                 done_callback=lambda result: asyncio.ensure_future(send_end_of_stream(result)),
+             ),
+             session=session,
+             channel=channel,
+         )
+
+         # Wait for the supervisor task to complete
+         await supervisor_task
+         logger.info("Module task %s completed", job_id)
+     except Exception:
+         logger.exception("Error running module %s", job_id)
+         raise
+     finally:
+         # Clean up the channel
+         if channel is not None:
+             try:
+                 await channel.close()
+             except Exception:
+                 logger.exception("Error closing channel for job %s", job_id)
+
+
+ @TASKIQ_BROKER.task
+ async def run_config_module(
+     mission_id: str,
+     setup_id: str,
+     setup_version_id: str,
+     module_class: type[BaseModule],
+     services_mode: ServicesMode,
+     config_setup_data: dict,
+     context: Context = TaskiqDepends(),
+ ) -> None:
+     """TaskIQ task running a module's config setup in the background asynchronously.
+
+     Args:
+         mission_id: The mission ID associated with the job.
+         setup_id: The setup ID associated with the module.
+         setup_version_id: The setup version ID associated with the module.
+         module_class: The module class to instantiate and configure.
+         services_mode: The mode of operation for the services.
+         config_setup_data: The config setup data as a plain dict.
+         context: Allows TaskIQ context access.
+     """
+     logger.info("Starting config module with services_mode: %s", services_mode)
+     services_config = ServicesConfig(
+         services_config_strategies=module_class.services_config_strategies,
+         services_config_params=module_class.services_config_params,
+         mode=services_mode,
+     )
+     setattr(module_class, "services_config", services_config)
+     logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)
+
+     job_id = context.message.task_id
+     callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)  # type: ignore[type-var]
+     module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)
+
+     # Override environment variables temporarily to use manager's SurrealDB
+     channel = None
+     try:
+         # Create TaskExecutor and supporting components for worker execution
+         executor = TaskExecutor()
+         # SurrealDB connection parameters are expected to be set in the environment.
+         channel = await ConnectionFactory.create_surreal_connection("taskiq_worker", datetime.timedelta(seconds=5))
+         session = TaskSession(job_id, mission_id, channel, module, datetime.timedelta(seconds=2))
+
+         # Create and run the config setup task with TaskExecutor
+         setup_model = module_class.create_config_setup_model(config_setup_data)
+
+         supervisor_task = await executor.execute_task(
+             task_id=job_id,
+             mission_id=mission_id,
+             coro=module.start_config_setup(setup_model, callback),
+             session=session,
+             channel=channel,
+         )
+
+         # Wait for the supervisor task to complete
+         await supervisor_task
+         logger.info("Config module task %s completed", job_id)
+     except Exception:
+         logger.exception("Error running config module %s", job_id)
+         raise
+     finally:
+         # Clean up the channel
+         if channel is not None:
+             try:
+                 await channel.close()
+             except Exception:
+                 logger.exception("Error closing channel for job %s", job_id)
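
As a dispatch sketch (not part of the package), a manager process could enqueue run_start_module roughly as below; the broker startup call, the taskiq .kiq() dispatch, the placeholder MyModule class, and the payload dicts are illustrative assumptions, and the package's own TaskiqJobManager may wire this differently.

# Hedged dispatch sketch; MyModule and the payload dicts are placeholders.
# A separate taskiq worker process consumes the queue and executes the task body above.
from digitalkin.core.job_manager.taskiq_broker import TASKIQ_BROKER, run_start_module
from digitalkin.services.services_models import ServicesMode


async def dispatch(services_mode: ServicesMode) -> str:
    await TASKIQ_BROKER.startup()  # connects to RabbitMQ; the startup hook creates the RStream
    task = await run_start_module.kiq(  # enqueue the job for a worker to execute
        mission_id="mission-1",
        setup_id="setup-1",
        setup_version_id="setup-v1",
        module_class=MyModule,  # a BaseModule subclass (placeholder)
        services_mode=services_mode,
        input_data={"text": "hello"},  # plain dicts; the worker rebuilds the Pydantic models
        setup_data={},
    )
    return task.task_id  # this task_id becomes the job_id used inside the worker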