digitalkin 0.3.1.dev1__py3-none-any.whl → 0.3.2a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- base_server/server_async_insecure.py +6 -5
- base_server/server_async_secure.py +6 -5
- base_server/server_sync_insecure.py +5 -4
- base_server/server_sync_secure.py +5 -4
- digitalkin/__version__.py +1 -1
- digitalkin/core/job_manager/base_job_manager.py +1 -1
- digitalkin/core/job_manager/single_job_manager.py +78 -36
- digitalkin/core/job_manager/taskiq_broker.py +8 -7
- digitalkin/core/job_manager/taskiq_job_manager.py +9 -5
- digitalkin/core/task_manager/base_task_manager.py +3 -1
- digitalkin/core/task_manager/surrealdb_repository.py +13 -7
- digitalkin/core/task_manager/task_executor.py +27 -10
- digitalkin/core/task_manager/task_session.py +133 -101
- digitalkin/grpc_servers/module_server.py +95 -171
- digitalkin/grpc_servers/module_servicer.py +133 -27
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +36 -10
- digitalkin/grpc_servers/utils/utility_schema_extender.py +106 -0
- digitalkin/models/__init__.py +1 -1
- digitalkin/models/core/job_manager_models.py +0 -8
- digitalkin/models/core/task_monitor.py +23 -1
- digitalkin/models/grpc_servers/models.py +95 -8
- digitalkin/models/module/__init__.py +26 -13
- digitalkin/models/module/base_types.py +61 -0
- digitalkin/models/module/module_context.py +279 -13
- digitalkin/models/module/module_types.py +29 -109
- digitalkin/models/module/setup_types.py +547 -0
- digitalkin/models/module/tool_cache.py +230 -0
- digitalkin/models/module/tool_reference.py +160 -0
- digitalkin/models/module/utility.py +167 -0
- digitalkin/models/services/cost.py +22 -1
- digitalkin/models/services/registry.py +77 -0
- digitalkin/modules/__init__.py +5 -1
- digitalkin/modules/_base_module.py +253 -90
- digitalkin/modules/archetype_module.py +6 -1
- digitalkin/modules/tool_module.py +6 -1
- digitalkin/modules/triggers/__init__.py +8 -0
- digitalkin/modules/triggers/healthcheck_ping_trigger.py +45 -0
- digitalkin/modules/triggers/healthcheck_services_trigger.py +63 -0
- digitalkin/modules/triggers/healthcheck_status_trigger.py +52 -0
- digitalkin/services/__init__.py +4 -0
- digitalkin/services/communication/__init__.py +7 -0
- digitalkin/services/communication/communication_strategy.py +87 -0
- digitalkin/services/communication/default_communication.py +104 -0
- digitalkin/services/communication/grpc_communication.py +264 -0
- digitalkin/services/cost/cost_strategy.py +36 -14
- digitalkin/services/cost/default_cost.py +61 -1
- digitalkin/services/cost/grpc_cost.py +98 -2
- digitalkin/services/filesystem/grpc_filesystem.py +9 -2
- digitalkin/services/registry/__init__.py +22 -1
- digitalkin/services/registry/default_registry.py +156 -4
- digitalkin/services/registry/exceptions.py +47 -0
- digitalkin/services/registry/grpc_registry.py +382 -0
- digitalkin/services/registry/registry_models.py +15 -0
- digitalkin/services/registry/registry_strategy.py +106 -4
- digitalkin/services/services_config.py +25 -3
- digitalkin/services/services_models.py +5 -1
- digitalkin/services/setup/default_setup.py +1 -1
- digitalkin/services/setup/grpc_setup.py +1 -1
- digitalkin/services/storage/grpc_storage.py +1 -1
- digitalkin/services/user_profile/__init__.py +11 -0
- digitalkin/services/user_profile/grpc_user_profile.py +2 -2
- digitalkin/services/user_profile/user_profile_strategy.py +0 -15
- digitalkin/utils/__init__.py +40 -0
- digitalkin/utils/conditional_schema.py +260 -0
- digitalkin/utils/dynamic_schema.py +487 -0
- digitalkin/utils/schema_splitter.py +290 -0
- {digitalkin-0.3.1.dev1.dist-info → digitalkin-0.3.2a2.dist-info}/METADATA +13 -13
- digitalkin-0.3.2a2.dist-info/RECORD +144 -0
- {digitalkin-0.3.1.dev1.dist-info → digitalkin-0.3.2a2.dist-info}/WHEEL +1 -1
- {digitalkin-0.3.1.dev1.dist-info → digitalkin-0.3.2a2.dist-info}/top_level.txt +1 -0
- modules/archetype_with_tools_module.py +232 -0
- modules/cpu_intensive_module.py +1 -1
- modules/dynamic_setup_module.py +338 -0
- modules/minimal_llm_module.py +1 -1
- modules/text_transform_module.py +1 -1
- monitoring/digitalkin_observability/__init__.py +46 -0
- monitoring/digitalkin_observability/http_server.py +150 -0
- monitoring/digitalkin_observability/interceptors.py +176 -0
- monitoring/digitalkin_observability/metrics.py +201 -0
- monitoring/digitalkin_observability/prometheus.py +137 -0
- monitoring/tests/test_metrics.py +172 -0
- services/filesystem_module.py +7 -5
- services/storage_module.py +4 -2
- digitalkin/grpc_servers/registry_server.py +0 -65
- digitalkin/grpc_servers/registry_servicer.py +0 -456
- digitalkin-0.3.1.dev1.dist-info/RECORD +0 -117
- {digitalkin-0.3.1.dev1.dist-info → digitalkin-0.3.2a2.dist-info}/licenses/LICENSE +0 -0
base_server/server_async_insecure.py
CHANGED
@@ -9,8 +9,9 @@ from pathlib import Path
 # Add parent directory to path to enable imports
 sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
 
-from digitalkin.grpc_servers._base_server import BaseServer
 from digitalkin.grpc_servers.utils.models import SecurityMode, ServerConfig, ServerMode
+
+from digitalkin.grpc_servers._base_server import BaseServer
 from examples.base_server.mock.mock_pb2 import DESCRIPTOR, HelloReply  # type: ignore
 from examples.base_server.mock.mock_pb2_grpc import (
     Greeter,
@@ -30,7 +31,7 @@ class AsyncGreeterImpl(Greeter):
 
     async def SayHello(self, request, context):  # noqa: N802
         """Asynchronous implementation of SayHello method."""
-        logger.info(
+        logger.info("Received request object: %s", request)
         logger.info(f"Request attributes: {vars(request)}")
         logger.info(f"Received request with name: {request.name}")
 
@@ -40,7 +41,7 @@ class AsyncGreeterImpl(Greeter):
         name = "unknown"
         # Check context metadata
         for key, value in context.invocation_metadata():
-            logger.info(
+            logger.info("Metadata: %s=%s", key, value)
             if key.lower() == "name":
                 name = value
 
@@ -97,7 +98,7 @@ async def main_async() -> int:
         # as the KeyboardInterrupt usually breaks out of asyncio.run()
         logger.info("Server stopping due to keyboard interrupt...")
     except Exception as e:
-        logger.exception(
+        logger.exception("Error running server: %s", e)
         return 1
     finally:
         # Clean up resources if server was started
@@ -116,7 +117,7 @@ def main():
         logger.info("Server stopped by keyboard interrupt")
         return 0  # Clean exit
     except Exception as e:
-        logger.exception(
+        logger.exception("Fatal error: %s", e)
         return 1
 
 

base_server/server_async_secure.py
CHANGED
@@ -9,13 +9,14 @@ from pathlib import Path
 # Add parent directory to path to enable imports
 sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
 
-from digitalkin.grpc_servers._base_server import BaseServer
 from digitalkin.grpc_servers.utils.models import (
     SecurityMode,
     ServerConfig,
     ServerCredentials,
     ServerMode,
 )
+
+from digitalkin.grpc_servers._base_server import BaseServer
 from examples.base_server.mock.mock_pb2 import DESCRIPTOR, HelloReply  # type: ignore
 from examples.base_server.mock.mock_pb2_grpc import (
     Greeter,
@@ -35,7 +36,7 @@ class AsyncGreeterImpl(Greeter):
 
     async def SayHello(self, request, context):  # noqa: N802
         """Asynchronous implementation of SayHello method."""
-        logger.info(
+        logger.info("Received request object: %s", request)
         logger.info(f"Request attributes: {vars(request)}")
         logger.info(f"Received request with name: {request.name}")
 
@@ -45,7 +46,7 @@ class AsyncGreeterImpl(Greeter):
         name = "unknown"
         # Check context metadata
         for key, value in context.invocation_metadata():
-            logger.info(
+            logger.info("Metadata: %s=%s", key, value)
             if key.lower() == "name":
                 name = value
 
@@ -115,7 +116,7 @@ async def main_async() -> int:
         # as the KeyboardInterrupt usually breaks out of asyncio.run()
         logger.info("Server stopping due to keyboard interrupt...")
     except Exception as e:
-        logger.exception(
+        logger.exception("Error running server: %s", e)
         return 1
     finally:
         # Clean up resources if server was started
@@ -134,7 +135,7 @@ def main():
         logger.info("Server stopped by keyboard interrupt")
         return 0  # Clean exit
     except Exception as e:
-        logger.exception(
+        logger.exception("Fatal error: %s", e)
         return 1
 
 

base_server/server_sync_insecure.py
CHANGED
@@ -8,8 +8,9 @@ from pathlib import Path
 # Add parent directory to path to enable imports
 sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
 
-from digitalkin.grpc_servers._base_server import BaseServer
 from digitalkin.grpc_servers.utils.models import SecurityMode, ServerConfig, ServerMode
+
+from digitalkin.grpc_servers._base_server import BaseServer
 from examples.base_server.mock.mock_pb2 import DESCRIPTOR, HelloReply  # type: ignore
 from examples.base_server.mock.mock_pb2_grpc import (
     Greeter,
@@ -29,7 +30,7 @@ class SyncGreeterServicer(Greeter):
 
     def SayHello(self, request, context):  # noqa: N802
         """Implementation of SayHello method."""
-        logger.info(
+        logger.info("Received request object: %s", request)
         logger.info(f"Request attributes: {vars(request)}")
         logger.info(f"Received request with name: {request.name}")
 
@@ -39,7 +40,7 @@ class SyncGreeterServicer(Greeter):
         name = "unknown"
         # Check context metadata
         for key, value in context.invocation_metadata():
-            logger.info(
+            logger.info("Metadata: %s=%s", key, value)
             if key.lower() == "name":
                 name = value
 
@@ -92,7 +93,7 @@ def main() -> int:
         server.stop()
 
     except Exception as e:
-        logger.exception(
+        logger.exception("Error running server: %s", e)
         return 1
 
     return 0

base_server/server_sync_secure.py
CHANGED
@@ -8,13 +8,14 @@ from pathlib import Path
 # Add parent directory to path to enable imports
 sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
 
-from digitalkin.grpc_servers._base_server import BaseServer
 from digitalkin.grpc_servers.utils.models import (
     SecurityMode,
     ServerConfig,
     ServerCredentials,
     ServerMode,
 )
+
+from digitalkin.grpc_servers._base_server import BaseServer
 from examples.base_server.mock.mock_pb2 import DESCRIPTOR, HelloReply  # type: ignore
 from examples.base_server.mock.mock_pb2_grpc import (
     Greeter,
@@ -34,7 +35,7 @@ class SyncGreeterServicer(Greeter):
 
     def SayHello(self, request, context):  # noqa: N802
         """Implementation of SayHello method."""
-        logger.info(
+        logger.info("Received request object: %s", request)
         logger.info(f"Request attributes: {vars(request)}")
         logger.info(f"Received request with name: {request.name}")
 
@@ -44,7 +45,7 @@ class SyncGreeterServicer(Greeter):
         name = "unknown"
         # Check context metadata
         for key, value in context.invocation_metadata():
-            logger.info(
+            logger.info("Metadata: %s=%s", key, value)
             if key.lower() == "name":
                 name = value
 
@@ -111,7 +112,7 @@ def main() -> int:
         server.stop()
 
    except Exception as e:
-        logger.exception(
+        logger.exception("Error running server: %s", e)
        return 1
 
    return 0
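A note on the recurring logging change in these four example servers: each removed `logger.info(` / `logger.exception(` line (shown truncated by the diff viewer) is replaced by a single-line call with %-style arguments. Arguments passed to the logger are interpolated only when a handler actually emits the record, and `logger.exception` attaches the traceback automatically. A minimal, self-contained sketch of the difference (illustration only, not digitalkin code):

    import logging

    logger = logging.getLogger("greeter")

    def handle(name: str) -> None:
        # Lazy: "%s" is interpolated only if an INFO record is emitted.
        logger.info("Received request with name: %s", name)
        # Eager: the f-string is built even when INFO is disabled.
        logger.info(f"Received request with name: {name}")

The f-string calls left untouched in the diff still pay that eager-formatting cost; only the previously multi-line calls were rewritten.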
digitalkin/__version__.py
CHANGED (+1 -1: version bump, 0.3.1.dev1 → 0.3.2a2)
digitalkin/core/job_manager/base_job_manager.py
CHANGED
@@ -8,8 +8,8 @@ from typing import Any, Generic
 from digitalkin.core.task_manager.base_task_manager import BaseTaskManager
 from digitalkin.core.task_manager.task_session import TaskSession
 from digitalkin.models.core.task_monitor import TaskStatus
-from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
 from digitalkin.models.module.module import ModuleCodeModel
+from digitalkin.models.module.module_types import InputModelT, OutputModelT, SetupModelT
 from digitalkin.modules._base_module import BaseModule
 from digitalkin.services.services_config import ServicesConfig
 from digitalkin.services.services_models import ServicesMode
digitalkin/core/job_manager/single_job_manager.py
CHANGED
@@ -5,7 +5,7 @@ import datetime
 import uuid
 from collections.abc import AsyncGenerator, AsyncIterator
 from contextlib import asynccontextmanager
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import grpc
 
@@ -13,9 +13,12 @@ from digitalkin.core.common import ConnectionFactory, ModuleFactory
 from digitalkin.core.job_manager.base_job_manager import BaseJobManager
 from digitalkin.core.task_manager.local_task_manager import LocalTaskManager
 from digitalkin.core.task_manager.task_session import TaskSession
+
+if TYPE_CHECKING:
+    from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
 from digitalkin.logger import logger
 from digitalkin.models.core.task_monitor import TaskStatus
-from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
+from digitalkin.models.module.base_types import InputModelT, OutputModelT, SetupModelT
 from digitalkin.models.module.module import ModuleCodeModel
 from digitalkin.modules._base_module import BaseModule
 from digitalkin.services.services_models import ServicesMode
@@ -29,10 +32,6 @@ class SingleJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
     to handle their output data.
     """
 
-    async def start(self) -> None:
-        """Start manager."""
-        self.channel = await ConnectionFactory.create_surreal_connection("task_manager", datetime.timedelta(seconds=5))
-
     def __init__(
         self,
         module_class: type[BaseModule],
@@ -55,6 +54,11 @@ class SingleJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
         super().__init__(module_class, services_mode, task_manager)
 
         self._lock = asyncio.Lock()
+        self.channel: SurrealDBConnection | None = None
+
+    async def start(self) -> None:
+        """Start manager."""
+        self.channel = await ConnectionFactory.create_surreal_connection("task_manager", datetime.timedelta(seconds=5))
 
     async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT | ModuleCodeModel:
         """Generate a stream consumer for a module's output data.
@@ -86,7 +90,10 @@ class SingleJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
                 message=f"Module {job_id} did not respond within 30 seconds",
             )
         finally:
-            logger.
+            logger.debug(
+                "Config setup response retrieved",
+                extra={"job_id": job_id, "queue_empty": session.queue.empty()},
+            )
 
     async def create_config_setup_instance_job(
         self,
@@ -110,11 +117,14 @@
             str: The unique identifier (job ID) of the created job.
 
         Raises:
+            RuntimeError: If start() was not called before creating jobs.
             Exception: If the module fails to start.
         """
         job_id = str(uuid.uuid4())
-        # TODO: Ensure the job_id is unique.
         module = ModuleFactory.create_module_instance(self.module_class, job_id, mission_id, setup_id, setup_version_id)
+        if self.channel is None:
+            msg = "JobManager.start() must be called before creating jobs"
+            raise RuntimeError(msg)
         self.tasks_sessions[job_id] = TaskSession(job_id, mission_id, self.channel, module)
 
         try:
@@ -126,7 +136,7 @@
         except Exception:
             # Remove the module from the manager in case of an error.
             del self.tasks_sessions[job_id]
-            logger.exception("Failed to start module
+            logger.exception("Failed to start module", extra={"job_id": job_id})
             raise
         else:
             return job_id
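The hunks above convert SingleJobManager to two-phase initialization: `__init__` declares `self.channel: SurrealDBConnection | None = None`, `start()` opens the SurrealDB connection, and job creation raises `RuntimeError` when `start()` was skipped. A minimal sketch of the lifecycle, with hypothetical names standing in for the real factory and connection types:

    import asyncio


    class Manager:
        def __init__(self) -> None:
            self.channel: object | None = None  # populated by start()

        async def start(self) -> None:
            await asyncio.sleep(0)  # stand-in for opening a real connection
            self.channel = object()

        async def create_job(self) -> str:
            if self.channel is None:  # fail fast instead of failing deep in job setup
                msg = "Manager.start() must be called before creating jobs"
                raise RuntimeError(msg)
            return "job-id"


    async def main() -> None:
        manager = Manager()
        await manager.start()
        print(await manager.create_job())


    asyncio.run(main())

The explicit None check trades a late `AttributeError` inside TaskSession construction for an immediate, descriptive error at the call site.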
@@ -134,13 +144,33 @@
     async def add_to_queue(self, job_id: str, output_data: OutputModelT | ModuleCodeModel) -> None:
         """Add output data to the queue for a specific job.
 
-
+        Uses timeout-based backpressure: if the queue is full after 5s,
+        drops the oldest message to make room for the new one.
+        Rejects writes after stream is closed to prevent message loss.
 
         Args:
             job_id: The unique identifier of the job.
             output_data: The output data produced by the job.
         """
-
+        session = self.tasks_sessions.get(job_id)
+        if session is None:
+            logger.warning("Queue write rejected - session not found", extra={"job_id": job_id})
+            return
+
+        if session.stream_closed:
+            logger.debug("Queue write rejected - stream closed", extra={"job_id": job_id})
+            return
+
+        try:
+            await asyncio.wait_for(session.queue.put(output_data.model_dump()), timeout=5.0)
+        except asyncio.TimeoutError:
+            logger.warning("Queue full, dropping oldest message", extra={"job_id": job_id})
+            try:
+                session.queue.get_nowait()
+                session.queue.task_done()
+            except asyncio.QueueEmpty:
+                pass
+            session.queue.put_nowait(output_data.model_dump())
 
     @asynccontextmanager  # type: ignore
     async def generate_stream_consumer(self, job_id: str) -> AsyncIterator[AsyncGenerator[dict[str, Any], None]]:  # type: ignore
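The new `add_to_queue` applies a drop-oldest backpressure policy on a bounded `asyncio.Queue`: wait up to 5 s for space, then evict the head of the queue so the freshest message always lands. A standalone sketch of the same policy (safe under a single producer; with several producers the final `put_nowait` could still race and raise `QueueFull`):

    import asyncio


    async def put_drop_oldest(queue: asyncio.Queue, item: object, timeout: float = 5.0) -> None:
        """Enqueue item, evicting the oldest entry if the queue stays full past timeout."""
        try:
            await asyncio.wait_for(queue.put(item), timeout=timeout)
        except asyncio.TimeoutError:
            try:
                queue.get_nowait()      # drop the oldest message
                queue.task_done()
            except asyncio.QueueEmpty:  # a consumer drained it in the meantime
                pass
            queue.put_nowait(item)


    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue(maxsize=2)
        for i in range(4):
            await put_drop_oldest(queue, i, timeout=0.1)
        while not queue.empty():
            print(queue.get_nowait())  # prints 2 then 3: items 0 and 1 were evicted


    asyncio.run(main())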
@@ -177,42 +207,39 @@
         logger.debug("Session: %s with Module %s", job_id, session.module)
 
         async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
-            """Stream output data from the module with
+            """Stream output data from the module with bounded blocking.
 
-
-
-
-            2. Check termination conditions after each item
-            3. Clean shutdown when task completes
-
-            This pattern provides:
-            - Immediate termination when task completes
-            - Direct session status monitoring
-            - Simple, predictable behavior for local tasks
+            Uses a 1-second timeout on queue.get() to periodically re-check
+            termination flags, preventing indefinite hangs when the task crashes
+            without producing output.
 
             Yields:
                 dict: Output data generated by the module.
             """
             while True:
-
-
+                if session.stream_closed or session.is_cancelled.is_set():
+                    logger.debug("Stream ending for job %s (pre-check)", job_id)
+                    break
+
+                try:
+                    msg = await asyncio.wait_for(session.queue.get(), timeout=1.0)
+                except asyncio.TimeoutError:
+                    continue
+
                 try:
                     yield msg
                 finally:
-                    # Always mark task as done, even if consumer raises exception
                     session.queue.task_done()
 
-                # Check termination conditions after each message
-                # This allows immediate shutdown when the task completes
                 if (
-                    session.
+                    session.stream_closed
+                    or session.is_cancelled.is_set()
                     or (session.status is TaskStatus.COMPLETED and session.queue.empty())
                     or session.status is TaskStatus.FAILED
                 ):
                     logger.debug(
-                        "Stream ending for job %s:
+                        "Stream ending for job %s: status=%s, queue_empty=%s",
                         job_id,
-                        session.is_cancelled.is_set(),
                         session.status,
                         session.queue.empty(),
                     )
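The rewritten `_stream` generator swaps an unbounded `queue.get()` for a 1-second `asyncio.wait_for`, so termination flags are re-checked at least once per second even when the producer dies without emitting anything. The shape of that loop, reduced to its essentials (the `closed` event stands in for the session's `stream_closed`/`is_cancelled` flags):

    import asyncio
    from collections.abc import AsyncGenerator


    async def stream(queue: asyncio.Queue, closed: asyncio.Event) -> AsyncGenerator[object, None]:
        """Yield queue items without ever blocking more than 1 s between flag checks."""
        while True:
            if closed.is_set():
                break
            try:
                msg = await asyncio.wait_for(queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue  # nothing arrived; loop back and re-check the flag
            try:
                yield msg
            finally:
                queue.task_done()  # mark done even if the consumer raises

Without the timeout, a task that crashed before producing output would leave the consumer awaiting `queue.get()` indefinitely, which is exactly the hang the new docstring describes.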
@@ -259,6 +286,18 @@
         logger.info("Managed task started: '%s'", job_id, extra={"task_id": job_id})
         return job_id
 
+    async def clean_session(self, task_id: str, mission_id: str) -> bool:
+        """Clean a task's session.
+
+        Args:
+            task_id: Unique identifier for the task.
+            mission_id: Mission identifier.
+
+        Returns:
+            bool: True if the task was successfully cleaned, False otherwise.
+        """
+        return await self._task_manager.clean_session(task_id, mission_id)
+
     async def stop_module(self, job_id: str) -> bool:
         """Stop a running module job.
 
@@ -271,20 +310,23 @@
         Raises:
             Exception: If an error occurs while stopping the module.
         """
-        logger.info(
+        logger.info("Stop module requested", extra={"job_id": job_id})
 
         async with self._lock:
             session = self.tasks_sessions.get(job_id)
 
             if not session:
-                logger.warning(
+                logger.warning("Session not found", extra={"job_id": job_id})
                 return False
             try:
                 await session.module.stop()
                 await self.cancel_task(job_id, session.mission_id)
-                logger.debug(
-
-
+                logger.debug(
+                    "Module stopped successfully",
+                    extra={"job_id": job_id, "mission_id": session.mission_id},
+                )
+            except Exception:
+                logger.exception("Error stopping module", extra={"job_id": job_id})
                 raise
             else:
                 return True
@@ -331,7 +373,7 @@
         await asyncio.gather(*stop_tasks, return_exceptions=True)
 
         # Close SurrealDB connection after stopping all modules
-        if
+        if self.channel is not None:
             try:
                 await self.channel.close()
                 logger.info("SingleJobManager: SurrealDB connection closed")
digitalkin/core/job_manager/taskiq_broker.py
CHANGED
@@ -21,8 +21,9 @@ from digitalkin.core.job_manager.base_job_manager import BaseJobManager
 from digitalkin.core.task_manager.task_executor import TaskExecutor
 from digitalkin.core.task_manager.task_session import TaskSession
 from digitalkin.logger import logger
-from digitalkin.models.
-from digitalkin.models.module.module_types import OutputModelT
+from digitalkin.models.module.module import ModuleCodeModel
+from digitalkin.models.module.module_types import DataModel, OutputModelT
+from digitalkin.models.module.utility import EndOfStreamOutput
 from digitalkin.modules._base_module import BaseModule
 from digitalkin.services.services_config import ServicesConfig
 from digitalkin.services.services_models import ServicesMode
@@ -141,7 +142,7 @@ async def cleanup_global_resources() -> None:
         logger.warning("Failed to shutdown Taskiq broker: %s", e)
 
 
-async def send_message_to_stream(job_id: str, output_data: OutputModelT) -> None:  # type: ignore
+async def send_message_to_stream(job_id: str, output_data: OutputModelT | ModuleCodeModel) -> None:  # type: ignore[type-var]
     """Callback define to add a message frame to the Rstream.
 
     Args:
@@ -186,7 +187,7 @@ async def run_start_module(
     module_class.discover()
 
     job_id = context.message.task_id
-    callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
+    callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)  # type: ignore[type-var]
     module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)
 
     channel = None
@@ -201,14 +202,14 @@
     # Create a proper done callback that handles errors
     async def send_end_of_stream(_: Any) -> None:  # noqa: ANN401
         try:
-            await callback(
+            await callback(DataModel(root=EndOfStreamOutput()))
         except Exception as e:
             logger.error("Error sending end of stream: %s", e, exc_info=True)
 
     # Reconstruct Pydantic models from dicts for type safety
     try:
         input_model = module_class.create_input_model(input_data)
-        setup_model = module_class.create_setup_model(setup_data)
+        setup_model = await module_class.create_setup_model(setup_data)
     except Exception as e:
         logger.error("Failed to reconstruct models for job %s: %s", job_id, e, exc_info=True)
         raise
@@ -272,7 +273,7 @@ async def run_config_module(
     logger.debug("Services config: %s | Module config: %s", services_config, module_class.services_config)
 
     job_id = context.message.task_id
-    callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)
+    callback = await BaseJobManager.job_specific_callback(send_message_to_stream, job_id)  # type: ignore[type-var]
     module = ModuleFactory.create_module_instance(module_class, job_id, mission_id, setup_id, setup_version_id)
 
     # Override environment variables temporarily to use manager's SurrealDB
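`send_end_of_stream` now publishes a typed `EndOfStreamOutput` (wrapped in `DataModel`) through the same callback as ordinary output, so stream completion travels in-band with the data. A generic sketch of the sentinel pattern, using made-up names rather than digitalkin's models:

    import asyncio


    class EndOfStream:
        """Sentinel published on the same channel as regular messages."""


    async def consume(queue: asyncio.Queue) -> None:
        while True:
            msg = await queue.get()
            if isinstance(msg, EndOfStream):
                break  # producer is done; no out-of-band signal needed
            print("got:", msg)


    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        for item in ("a", "b"):
            queue.put_nowait(item)
        queue.put_nowait(EndOfStream())
        await consume(queue)


    asyncio.run(main())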
digitalkin/core/job_manager/taskiq_job_manager.py
CHANGED
@@ -22,9 +22,10 @@ from digitalkin.core.common import ConnectionFactory, QueueFactory
 from digitalkin.core.job_manager.base_job_manager import BaseJobManager
 from digitalkin.core.job_manager.taskiq_broker import STREAM, STREAM_RETENTION, TASKIQ_BROKER, cleanup_global_resources
 from digitalkin.core.task_manager.remote_task_manager import RemoteTaskManager
+from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
 from digitalkin.logger import logger
 from digitalkin.models.core.task_monitor import TaskStatus
-from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
+from digitalkin.models.module.module_types import InputModelT, OutputModelT, SetupModelT
 from digitalkin.modules._base_module import BaseModule
 from digitalkin.services.services_models import ServicesMode
 
@@ -36,6 +37,7 @@ class TaskiqJobManager(BaseJobManager[InputModelT, OutputModelT, SetupModelT]):
     """Taskiq job manager for running modules in Taskiq tasks."""
 
     services_mode: ServicesMode
+    channel: SurrealDBConnection | None
 
     @staticmethod
     def _define_consumer() -> Consumer:
@@ -113,7 +115,7 @@
     async def _stop(self) -> None:
         """Stop the TaskiqJobManager and clean up all resources."""
         # Close SurrealDB connection
-        if
+        if self.channel is not None:
             try:
                 await self.channel.close()
                 logger.info("TaskiqJobManager: SurrealDB connection closed")
@@ -128,8 +130,9 @@
             await self.stream_consumer_task
 
         # Clean up job queues
+        queue_count = len(self.job_queues)
         self.job_queues.clear()
-        logger.info("TaskiqJobManager: Cleared %d job queues",
+        logger.info("TaskiqJobManager: Cleared %d job queues", queue_count)
 
         # Call global cleanup for producer and broker
         await cleanup_global_resources()
@@ -161,6 +164,7 @@
         self.job_queues: dict[str, asyncio.Queue] = {}
         self.max_queue_size = 1000
         self.stream_timeout = stream_timeout
+        self.channel = None
 
     async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT:
         """Generate a stream consumer for a module's output data.
@@ -429,7 +433,7 @@
             return TaskStatus.FAILED
 
         # Safety check: if channel not initialized (start() wasn't called), return FAILED
-        if
+        if self.channel is None:
             logger.warning("Job %s status check failed - channel not initialized", job_id)
             return TaskStatus.FAILED
 
@@ -521,7 +525,7 @@
         for job_id in self.tasks_sessions:
             try:
                 status = await self.get_module_status(job_id)
-                task_record = await self.channel.select_by_task_id("tasks", job_id)
+                task_record = await self.channel.select_by_task_id("tasks", job_id)  # type: ignore
 
                 modules_info[job_id] = {
                     "name": self.module_class.__name__,
digitalkin/core/task_manager/base_task_manager.py
CHANGED
@@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
 from collections.abc import Coroutine
 from typing import Any
 
+from typing_extensions import Self
+
 from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
 from digitalkin.core.task_manager.task_session import TaskSession
 from digitalkin.logger import logger
@@ -507,7 +509,7 @@ class BaseTaskManager(ABC):
         },
     )
 
-    async def __aenter__(self) ->
+    async def __aenter__(self) -> Self:
         """Enter async context manager.
 
         Returns:
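Annotating `__aenter__` as returning `Self` (imported from `typing_extensions` here, so pre-3.11 interpreters are covered; Python 3.11+ also has `typing.Self`) lets every subclass entered via `async with` keep its own static type instead of degrading to the base class. A minimal illustration, assuming `typing_extensions` is installed:

    import asyncio

    from typing_extensions import Self


    class Resource:
        async def __aenter__(self) -> Self:
            return self  # subclasses resolve Self to their own type

        async def __aexit__(self, *exc: object) -> None:
            return None


    class Database(Resource):
        async def query(self) -> str:
            return "ok"


    async def main() -> None:
        async with Database() as db:  # type checkers infer db: Database, not Resource
            print(await db.query())


    asyncio.run(main())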
digitalkin/core/task_manager/surrealdb_repository.py
CHANGED
@@ -4,7 +4,7 @@ import asyncio
 import datetime
 import os
 from collections.abc import AsyncGenerator
-from typing import Any, Generic, TypeVar
+from typing import Any, Generic, TypeVar, cast
 from uuid import UUID
 
 from surrealdb import AsyncHttpSurrealConnection, AsyncSurreal, AsyncWsSurrealConnection, RecordID
@@ -40,6 +40,7 @@ class SurrealDBConnection(Generic[TSurreal]):
     db: TSurreal
     timeout: datetime.timedelta
     _live_queries: set[UUID]  # Track active live queries for cleanup
+    _closed: bool  # Flag to prevent operations on closed connection
 
     @staticmethod
     def _valid_id(raw_id: str, table_name: str) -> RecordID:
@@ -85,13 +86,14 @@
         self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
         self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")
         self._live_queries = set()  # Initialize live queries tracker
+        self._closed = False
 
     async def init_surreal_instance(self) -> None:
         """Init a SurrealDB connection instance."""
         logger.debug("Connecting to SurrealDB at %s", self.url)
         self.db = AsyncSurreal(self.url)  # type: ignore
         await self.db.signin({"username": self.username, "password": self.password})
-        await self.db.use(self.namespace, self.database)
+        await self.db.use(self.namespace, self.database)  # type: ignore[arg-type]
         logger.debug("Successfully connected to SurrealDB")
 
     async def close(self) -> None:
@@ -99,6 +101,7 @@
 
         This will also kill all active live queries to prevent memory leaks.
         """
+        self._closed = True
         # Kill all tracked live queries before closing connection
         if self._live_queries:
             logger.debug("Killing %d active live queries before closing", len(self._live_queries))
@@ -112,7 +115,7 @@
             # Process results and track failures
             failed_queries = []
             for live_id, result in zip(live_query_ids, results):
-                if isinstance(result,
+                if isinstance(result, ConnectionError | TimeoutError | Exception):
                     failed_queries.append((live_id, str(result)))
                 else:
                     self._live_queries.discard(live_id)
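The loop above pairs each live-query id with its result from `asyncio.gather(..., return_exceptions=True)` and treats exception instances as failures (note the union `ConnectionError | TimeoutError | Exception` is redundant as released: `Exception` already covers the other two). A small sketch of that result-pairing idiom:

    import asyncio


    async def kill(live_id: int) -> int:
        if live_id == 2:
            raise ConnectionError("socket closed")
        return live_id


    async def main() -> None:
        ids = [1, 2, 3]
        results = await asyncio.gather(*(kill(i) for i in ids), return_exceptions=True)
        failed = [(i, str(r)) for i, r in zip(ids, results) if isinstance(r, Exception)]
        print(failed)  # [(2, 'socket closed')]


    asyncio.run(main())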
@@ -146,7 +149,7 @@
         logger.debug("Creating record in %s with data: %s", table_name, data)
         result = await self.db.create(table_name, data)
         logger.debug("create result: %s", result)
-        return result
+        return cast("list[dict[str, Any]] | dict[str, Any]", result)
 
     async def merge(
         self,
@@ -170,7 +173,7 @@
         logger.debug("Updating record in %s with data: %s", record_id, data)
         result = await self.db.merge(record_id, data)
         logger.debug("update result: %s", result)
-        return result
+        return cast("list[dict[str, Any]] | dict[str, Any]", result)
 
     async def update(
         self,
@@ -194,7 +197,7 @@
         logger.debug("Updating record in %s with data: %s", record_id, data)
         result = await self.db.update(record_id, data)
         logger.debug("update result: %s", result)
-        return result
+        return cast("list[dict[str, Any]] | dict[str, Any]", result)
 
     async def execute_query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
         """Execute a custom SurrealQL query.
@@ -209,7 +212,7 @@
         logger.debug("execute_query: %s with params: %s", query, params)
         result = await self.db.query(query, params or {})
         logger.debug("execute_query result: %s", result)
-        return [result] if isinstance(result, dict) else result
+        return cast("list[dict[str, Any]]", [result] if isinstance(result, dict) else result)
 
     async def select_by_task_id(self, table: str, value: str) -> dict[str, Any]:
         """Fetch a record from a table by a unique field.
@@ -260,6 +263,9 @@
         Args:
             live_id: Live query ID to kill
         """
+        if self._closed:
+            self._live_queries.discard(live_id)
+            return
         logger.debug("Killing live query: %s", live_id)
         await self.db.kill(live_id)
         self._live_queries.discard(live_id)  # Remove from tracker