digitalkin 0.3.2.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- base_server/__init__.py +1 -0
- base_server/mock/__init__.py +5 -0
- base_server/mock/mock_pb2.py +39 -0
- base_server/mock/mock_pb2_grpc.py +102 -0
- base_server/server_async_insecure.py +125 -0
- base_server/server_async_secure.py +143 -0
- base_server/server_sync_insecure.py +103 -0
- base_server/server_sync_secure.py +122 -0
- digitalkin/__init__.py +8 -0
- digitalkin/__version__.py +8 -0
- digitalkin/core/__init__.py +1 -0
- digitalkin/core/common/__init__.py +9 -0
- digitalkin/core/common/factories.py +156 -0
- digitalkin/core/job_manager/__init__.py +1 -0
- digitalkin/core/job_manager/base_job_manager.py +288 -0
- digitalkin/core/job_manager/single_job_manager.py +354 -0
- digitalkin/core/job_manager/taskiq_broker.py +311 -0
- digitalkin/core/job_manager/taskiq_job_manager.py +541 -0
- digitalkin/core/task_manager/__init__.py +1 -0
- digitalkin/core/task_manager/base_task_manager.py +539 -0
- digitalkin/core/task_manager/local_task_manager.py +108 -0
- digitalkin/core/task_manager/remote_task_manager.py +87 -0
- digitalkin/core/task_manager/surrealdb_repository.py +266 -0
- digitalkin/core/task_manager/task_executor.py +249 -0
- digitalkin/core/task_manager/task_session.py +406 -0
- digitalkin/grpc_servers/__init__.py +1 -0
- digitalkin/grpc_servers/_base_server.py +486 -0
- digitalkin/grpc_servers/module_server.py +208 -0
- digitalkin/grpc_servers/module_servicer.py +516 -0
- digitalkin/grpc_servers/utils/__init__.py +1 -0
- digitalkin/grpc_servers/utils/exceptions.py +29 -0
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +88 -0
- digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
- digitalkin/grpc_servers/utils/utility_schema_extender.py +97 -0
- digitalkin/logger.py +157 -0
- digitalkin/mixins/__init__.py +19 -0
- digitalkin/mixins/base_mixin.py +10 -0
- digitalkin/mixins/callback_mixin.py +24 -0
- digitalkin/mixins/chat_history_mixin.py +110 -0
- digitalkin/mixins/cost_mixin.py +76 -0
- digitalkin/mixins/file_history_mixin.py +93 -0
- digitalkin/mixins/filesystem_mixin.py +46 -0
- digitalkin/mixins/logger_mixin.py +51 -0
- digitalkin/mixins/storage_mixin.py +79 -0
- digitalkin/models/__init__.py +8 -0
- digitalkin/models/core/__init__.py +1 -0
- digitalkin/models/core/job_manager_models.py +36 -0
- digitalkin/models/core/task_monitor.py +70 -0
- digitalkin/models/grpc_servers/__init__.py +1 -0
- digitalkin/models/grpc_servers/models.py +275 -0
- digitalkin/models/grpc_servers/types.py +24 -0
- digitalkin/models/module/__init__.py +25 -0
- digitalkin/models/module/module.py +40 -0
- digitalkin/models/module/module_context.py +149 -0
- digitalkin/models/module/module_types.py +393 -0
- digitalkin/models/module/utility.py +146 -0
- digitalkin/models/services/__init__.py +10 -0
- digitalkin/models/services/cost.py +54 -0
- digitalkin/models/services/registry.py +42 -0
- digitalkin/models/services/storage.py +44 -0
- digitalkin/modules/__init__.py +11 -0
- digitalkin/modules/_base_module.py +517 -0
- digitalkin/modules/archetype_module.py +23 -0
- digitalkin/modules/tool_module.py +23 -0
- digitalkin/modules/trigger_handler.py +48 -0
- digitalkin/modules/triggers/__init__.py +12 -0
- digitalkin/modules/triggers/healthcheck_ping_trigger.py +45 -0
- digitalkin/modules/triggers/healthcheck_services_trigger.py +63 -0
- digitalkin/modules/triggers/healthcheck_status_trigger.py +52 -0
- digitalkin/py.typed +0 -0
- digitalkin/services/__init__.py +30 -0
- digitalkin/services/agent/__init__.py +6 -0
- digitalkin/services/agent/agent_strategy.py +19 -0
- digitalkin/services/agent/default_agent.py +13 -0
- digitalkin/services/base_strategy.py +22 -0
- digitalkin/services/communication/__init__.py +7 -0
- digitalkin/services/communication/communication_strategy.py +76 -0
- digitalkin/services/communication/default_communication.py +101 -0
- digitalkin/services/communication/grpc_communication.py +223 -0
- digitalkin/services/cost/__init__.py +14 -0
- digitalkin/services/cost/cost_strategy.py +100 -0
- digitalkin/services/cost/default_cost.py +114 -0
- digitalkin/services/cost/grpc_cost.py +138 -0
- digitalkin/services/filesystem/__init__.py +7 -0
- digitalkin/services/filesystem/default_filesystem.py +417 -0
- digitalkin/services/filesystem/filesystem_strategy.py +252 -0
- digitalkin/services/filesystem/grpc_filesystem.py +317 -0
- digitalkin/services/identity/__init__.py +6 -0
- digitalkin/services/identity/default_identity.py +15 -0
- digitalkin/services/identity/identity_strategy.py +14 -0
- digitalkin/services/registry/__init__.py +27 -0
- digitalkin/services/registry/default_registry.py +141 -0
- digitalkin/services/registry/exceptions.py +47 -0
- digitalkin/services/registry/grpc_registry.py +306 -0
- digitalkin/services/registry/registry_models.py +43 -0
- digitalkin/services/registry/registry_strategy.py +98 -0
- digitalkin/services/services_config.py +200 -0
- digitalkin/services/services_models.py +65 -0
- digitalkin/services/setup/__init__.py +1 -0
- digitalkin/services/setup/default_setup.py +219 -0
- digitalkin/services/setup/grpc_setup.py +343 -0
- digitalkin/services/setup/setup_strategy.py +145 -0
- digitalkin/services/snapshot/__init__.py +6 -0
- digitalkin/services/snapshot/default_snapshot.py +39 -0
- digitalkin/services/snapshot/snapshot_strategy.py +30 -0
- digitalkin/services/storage/__init__.py +7 -0
- digitalkin/services/storage/default_storage.py +228 -0
- digitalkin/services/storage/grpc_storage.py +214 -0
- digitalkin/services/storage/storage_strategy.py +273 -0
- digitalkin/services/user_profile/__init__.py +12 -0
- digitalkin/services/user_profile/default_user_profile.py +55 -0
- digitalkin/services/user_profile/grpc_user_profile.py +69 -0
- digitalkin/services/user_profile/user_profile_strategy.py +40 -0
- digitalkin/utils/__init__.py +29 -0
- digitalkin/utils/arg_parser.py +92 -0
- digitalkin/utils/development_mode_action.py +51 -0
- digitalkin/utils/dynamic_schema.py +483 -0
- digitalkin/utils/llm_ready_schema.py +75 -0
- digitalkin/utils/package_discover.py +357 -0
- digitalkin-0.3.2.dev2.dist-info/METADATA +602 -0
- digitalkin-0.3.2.dev2.dist-info/RECORD +131 -0
- digitalkin-0.3.2.dev2.dist-info/WHEEL +5 -0
- digitalkin-0.3.2.dev2.dist-info/licenses/LICENSE +430 -0
- digitalkin-0.3.2.dev2.dist-info/top_level.txt +4 -0
- modules/__init__.py +0 -0
- modules/cpu_intensive_module.py +280 -0
- modules/dynamic_setup_module.py +338 -0
- modules/minimal_llm_module.py +347 -0
- modules/text_transform_module.py +203 -0
- services/filesystem_module.py +200 -0
- services/storage_module.py +206 -0
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Remote task manager for distributed execution."""
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
from collections.abc import Coroutine
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from digitalkin.core.task_manager.base_task_manager import BaseTaskManager
|
|
8
|
+
from digitalkin.logger import logger
|
|
9
|
+
from digitalkin.modules._base_module import BaseModule
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RemoteTaskManager(BaseTaskManager):
    """Task manager for distributed/remote execution.

    Only manages task metadata and signals - actual execution happens in remote workers.
    Suitable for horizontally scaled deployments with Taskiq/Celery workers.
    """

    async def create_task(
        self,
        task_id: str,
        mission_id: str,
        module: BaseModule,
        coro: Coroutine[Any, Any, None],
        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
        connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Register task for remote execution (metadata only).

        Creates TaskSession for signal handling and monitoring, but doesn't execute the coroutine.
        The coroutine will be recreated and executed by a remote worker.

        Args:
            task_id: Unique identifier for the task
            mission_id: Mission identifier
            module: Module instance for metadata (not executed here)
            coro: Coroutine (will be closed - execution happens in worker)
            heartbeat_interval: Interval between heartbeats
            connection_timeout: Connection timeout for SurrealDB

        Raises:
            ValueError: If task_id duplicated
            RuntimeError: If task overload
        """
        try:
            # Validation may raise before any session exists; the outer finally
            # still guarantees the coroutine is closed on that path.
            await self._validate_task_creation(task_id, mission_id, coro)

            logger.info(
                "Registering remote task: '%s'",
                task_id,
                extra={
                    "mission_id": mission_id,
                    "task_id": task_id,
                    "heartbeat_interval": heartbeat_interval,
                    "connection_timeout": connection_timeout,
                },
            )

            try:
                # Create session for metadata and signal handling
                _channel, _session = await self._create_session(
                    task_id, mission_id, module, heartbeat_interval, connection_timeout
                )

                logger.info(
                    "Remote task registered: '%s'",
                    task_id,
                    extra={
                        "mission_id": mission_id,
                        "task_id": task_id,
                        "total_sessions": len(self.tasks_sessions),
                    },
                )

            except Exception as e:
                logger.error(
                    "Failed to register remote task: '%s'",
                    task_id,
                    extra={"mission_id": mission_id, "task_id": task_id, "error": str(e)},
                    exc_info=True,
                )
                # Cleanup on failure
                await self._cleanup_task(task_id, mission_id=mission_id)
                raise
        finally:
            # The coroutine is never executed locally (a remote worker recreates it).
            # Close it on EVERY path - previously it was only closed on success, so
            # validation or session failures leaked it and triggered the
            # "coroutine was never awaited" RuntimeWarning. close() on an already
            # closed coroutine is a no-op, so this is safe in all cases.
            coro.close()
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
"""SurrealDB connection management."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import datetime
|
|
5
|
+
import os
|
|
6
|
+
from collections.abc import AsyncGenerator
|
|
7
|
+
from typing import Any, Generic, TypeVar, cast
|
|
8
|
+
from uuid import UUID
|
|
9
|
+
|
|
10
|
+
from surrealdb import AsyncHttpSurrealConnection, AsyncSurreal, AsyncWsSurrealConnection, RecordID
|
|
11
|
+
|
|
12
|
+
from digitalkin.logger import logger
|
|
13
|
+
|
|
14
|
+
# Connection flavour accepted by SurrealDBConnection: either the async HTTP or
# the async WebSocket client from the surrealdb SDK.
TSurreal = TypeVar("TSurreal", bound=AsyncHttpSurrealConnection | AsyncWsSurrealConnection)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SurrealDBSetupBadIDError(Exception):
    """Raised when a record ID fails validation in the SurrealDB repository.

    Signals that a raw ID string does not match the expected
    ``<table>:<id>`` format or names the wrong table.
    """
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SurrealDBSetupVersionBadIDError(Exception):
    """Raised when a setup-version record ID fails validation.

    Signals that the provided ID does not match the expected format
    or criteria for a valid SurrealDB setup version ID.
    """
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class SurrealDBConnection(Generic[TSurreal]):
    """Base repository for database operations.

    This class provides common database operations that can be used by
    specific table repositories.
    """

    db: TSurreal  # Underlying async SurrealDB client, set by init_surreal_instance()
    timeout: datetime.timedelta
    _live_queries: set[UUID]  # Track active live queries for cleanup

    @staticmethod
    def _valid_id(raw_id: str, table_name: str) -> RecordID:
        """Validate and parse a raw ID string into a RecordID.

        Args:
            raw_id: The raw ID string to validate, expected as '<table>:<id>'
            table_name: table name to enforce

        Raises:
            SurrealDBSetupBadIDError: If the raw ID string is malformed or
                names a table other than ``table_name``

        Returns:
            RecordID: Parsed RecordID object (this method never returns None)
        """
        try:
            split_id = raw_id.split(":")
            if split_id[0] != table_name:
                msg = f"Invalid table name for ID: {raw_id}"
                raise SurrealDBSetupBadIDError(msg)
            return RecordID(split_id[0], split_id[1])
        except IndexError as err:
            # Previously re-raised bare (no message, no cause); keep the chain
            # and tell the caller what was wrong with the input.
            msg = f"Malformed record ID (expected '<table>:<id>'): {raw_id}"
            raise SurrealDBSetupBadIDError(msg) from err

    def __init__(
        self,
        database: str | None = None,
        timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Initialize the repository.

        Connection settings are read from the environment
        (SURREALDB_URL, SURREALDB_PORT, SURREALDB_USERNAME, SURREALDB_PASSWORD,
        SURREALDB_NAMESPACE, SURREALDB_DATABASE) with local-dev defaults.

        Args:
            database: Name of the database to use; falls back to the
                SURREALDB_DATABASE environment variable, then "task_manager"
            timeout: Timeout for database operations
        """
        self.timeout = timeout
        base_url = os.getenv("SURREALDB_URL", "ws://localhost").strip()
        port = (os.getenv("SURREALDB_PORT") or "").strip()
        # Append the port only when one is configured; always target the /rpc endpoint.
        self.url = f"{base_url}{f':{port}' if port else ''}/rpc"

        self.username = os.getenv("SURREALDB_USERNAME", "root")
        self.password = os.getenv("SURREALDB_PASSWORD", "root")
        self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
        self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")
        self._live_queries = set()  # Initialize live queries tracker

    async def init_surreal_instance(self) -> None:
        """Init a SurrealDB connection instance (connect, sign in, select ns/db)."""
        logger.debug("Connecting to SurrealDB at %s", self.url)
        self.db = AsyncSurreal(self.url)  # type: ignore
        await self.db.signin({"username": self.username, "password": self.password})
        await self.db.use(self.namespace, self.database)  # type: ignore[arg-type]
        logger.debug("Successfully connected to SurrealDB")

    async def close(self) -> None:
        """Close the SurrealDB connection if it exists.

        This will also kill all active live queries to prevent memory leaks.
        """
        # Kill all tracked live queries before closing connection
        if self._live_queries:
            logger.debug("Killing %d active live queries before closing", len(self._live_queries))
            live_query_ids = list(self._live_queries)

            # Kill all queries concurrently, capturing any exceptions
            results = await asyncio.gather(
                *[self.db.kill(live_id) for live_id in live_query_ids], return_exceptions=True
            )

            # Process results and track failures. gather(return_exceptions=True)
            # returns BaseException instances (including CancelledError) in place
            # of results, so test against BaseException rather than the previous
            # redundant `ConnectionError | TimeoutError | Exception` union.
            failed_queries = []
            for live_id, result in zip(live_query_ids, results):
                if isinstance(result, BaseException):
                    failed_queries.append((live_id, str(result)))
                else:
                    self._live_queries.discard(live_id)

            # Log aggregated failures once instead of per-query
            if failed_queries:
                logger.warning(
                    "Failed to kill %d live queries: %s",
                    len(failed_queries),
                    failed_queries[:5],  # Only log first 5 to avoid log spam
                    extra={"total_failed": len(failed_queries)},
                )

        logger.debug("Closing SurrealDB connection")
        await self.db.close()

    async def create(
        self,
        table_name: str,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Create a new record.

        Args:
            table_name: Name of the table to insert into
            data: Data to insert

        Returns:
            Dict[str, Any]: The created record as returned by the database
        """
        logger.debug("Creating record in %s with data: %s", table_name, data)
        result = await self.db.create(table_name, data)
        logger.debug("create result: %s", result)
        return cast("list[dict[str, Any]] | dict[str, Any]", result)

    async def merge(
        self,
        table_name: str,
        record_id: str | RecordID,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Merge data into an existing record (partial update).

        Args:
            table_name: Name of the table containing the record
            record_id: record ID to update
            data: Fields to merge into the record

        Returns:
            Dict[str, Any]: The updated record as returned by the database
        """
        if isinstance(record_id, str):
            # validate surrealDB id if raw str
            record_id = self._valid_id(record_id, table_name)
        logger.debug("Updating record in %s with data: %s", record_id, data)
        result = await self.db.merge(record_id, data)
        logger.debug("update result: %s", result)
        return cast("list[dict[str, Any]] | dict[str, Any]", result)

    async def update(
        self,
        table_name: str,
        record_id: str | RecordID,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Replace the content of an existing record.

        Args:
            table_name: Name of the table containing the record
            record_id: record ID to update
            data: Full record content to store

        Returns:
            Dict[str, Any]: The updated record as returned by the database
        """
        if isinstance(record_id, str):
            # validate surrealDB id if raw str
            record_id = self._valid_id(record_id, table_name)
        logger.debug("Updating record in %s with data: %s", record_id, data)
        result = await self.db.update(record_id, data)
        logger.debug("update result: %s", result)
        return cast("list[dict[str, Any]] | dict[str, Any]", result)

    async def execute_query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
        """Execute a custom SurrealQL query.

        Args:
            query: SurrealQL query
            params: Query parameters

        Returns:
            List[Dict[str, Any]]: Query results (a dict result is wrapped in a list)
        """
        logger.debug("execute_query: %s with params: %s", query, params)
        result = await self.db.query(query, params or {})
        logger.debug("execute_query result: %s", result)
        return cast("list[dict[str, Any]]", [result] if isinstance(result, dict) else result)

    async def select_by_task_id(self, table: str, value: str) -> dict[str, Any]:
        """Fetch a record from a table by its task_id field.

        Args:
            table: Table name
            value: task_id value to match

        Raises:
            ValueError: If no records are found

        Returns:
            Dict with the first matching record (this method raises instead of
            returning None when nothing matches)
        """
        query = "SELECT * FROM type::table($table) WHERE task_id = $value;"
        params = {"table": table, "value": value}

        result = await self.execute_query(query, params)
        if not result:
            msg = f"No records found in table '{table}' with task_id '{value}'"
            logger.error(msg)
            raise ValueError(msg)

        return result[0]

    async def start_live(
        self,
        table_name: str,
    ) -> tuple[UUID, AsyncGenerator[dict[str, Any], None]]:
        """Create and subscribe to a live SurrealQL query.

        The live query ID is tracked to ensure proper cleanup on connection close.

        Args:
            table_name: Name of the table to watch

        Returns:
            tuple[UUID, AsyncGenerator]: Live query ID and subscription generator
        """
        live_id = await self.db.live(table_name, diff=False)
        self._live_queries.add(live_id)  # Track for cleanup
        logger.debug("Started live query %s for table %s (total: %d)", live_id, table_name, len(self._live_queries))
        return live_id, await self.db.subscribe_live(live_id)

    async def stop_live(self, live_id: UUID) -> None:
        """Kill a live SurrealQL query.

        Args:
            live_id: Live query ID to kill
        """
        logger.debug("Killing live query: %s", live_id)
        await self.db.kill(live_id)
        self._live_queries.discard(live_id)  # Remove from tracker
        logger.debug("Stopped live query %s (remaining: %d)", live_id, len(self._live_queries))
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
"""Task executor for running tasks with full lifecycle management."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import datetime
|
|
5
|
+
from collections.abc import Coroutine
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
|
|
9
|
+
from digitalkin.core.task_manager.task_session import TaskSession
|
|
10
|
+
from digitalkin.logger import logger
|
|
11
|
+
from digitalkin.models.core.task_monitor import (
|
|
12
|
+
CancellationReason,
|
|
13
|
+
SignalMessage,
|
|
14
|
+
SignalType,
|
|
15
|
+
TaskStatus,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TaskExecutor:
    """Executes tasks with the supervisor pattern (main + heartbeat + signal listener).

    Pure execution logic - no task registry or orchestration.
    Used by workers to run distributed tasks or by TaskManager for local execution.
    """

    @staticmethod
    async def execute_task(  # noqa: C901, PLR0915
        task_id: str,
        mission_id: str,
        coro: Coroutine[Any, Any, None],
        session: TaskSession,
        channel: SurrealDBConnection,
    ) -> asyncio.Task[None]:
        """Execute a task using the supervisor pattern.

        Runs three concurrent sub-tasks:
        - Main coroutine (the actual work)
        - Heartbeat generator (sends heartbeats to SurrealDB)
        - Signal listener (watches for stop/pause/resume signals)

        The first task to complete determines the outcome.

        Args:
            task_id: Unique identifier for the task
            mission_id: Mission identifier for the task
            coro: The coroutine to execute (module.start(...))
            session: TaskSession for state management
            channel: SurrealDB connection for signals

        Returns:
            asyncio.Task: The supervisor task managing the lifecycle
        """

        async def signal_wrapper() -> None:
            """Create initial signal record and listen for signals."""
            try:
                # Publish a START signal record so observers see the task begin.
                await channel.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        mission_id=mission_id,
                        status=session.status,
                        action=SignalType.START,
                    ).model_dump(),
                )
                await session.listen_signals()
            except asyncio.CancelledError:
                logger.debug("Signal listener cancelled", extra={"mission_id": mission_id, "task_id": task_id})
            finally:
                # NOTE(review): this await runs in `finally` even after cancellation;
                # a second cancellation arriving here could interrupt the STOP
                # record write - confirm whether that is acceptable.
                await channel.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        mission_id=mission_id,
                        status=session.status,
                        action=SignalType.STOP,
                    ).model_dump(),
                )
                logger.info("Signal listener ended", extra={"mission_id": mission_id, "task_id": task_id})

        async def heartbeat_wrapper() -> None:
            """Generate heartbeats for task health monitoring."""
            try:
                await session.generate_heartbeats()
            except asyncio.CancelledError:
                logger.debug("Heartbeat cancelled", extra={"mission_id": mission_id, "task_id": task_id})
            finally:
                logger.info("Heartbeat task ended", extra={"mission_id": mission_id, "task_id": task_id})

        async def supervisor() -> None:  # noqa: C901, PLR0912, PLR0915
            """Supervise the three concurrent tasks and handle outcomes.

            Raises:
                RuntimeError: If the heartbeat task stops unexpectedly.
                asyncio.CancelledError: If the supervisor task is cancelled.
            """
            session.started_at = datetime.datetime.now(datetime.timezone.utc)
            session.status = TaskStatus.RUNNING

            # Create tasks with proper exception handling.
            # Pre-set to None so the finally block can tell which tasks were
            # actually created if create_task itself fails partway through.
            main_task = None
            hb_task = None
            sig_task = None
            cleanup_reason = CancellationReason.UNKNOWN

            try:
                main_task = asyncio.create_task(coro, name=f"{task_id}_main")
                hb_task = asyncio.create_task(heartbeat_wrapper(), name=f"{task_id}_heartbeat")
                sig_task = asyncio.create_task(signal_wrapper(), name=f"{task_id}_listener")
                # Race all three; whichever finishes first decides the outcome.
                done, pending = await asyncio.wait(
                    [main_task, sig_task, hb_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )

                # Determine cleanup reason based on which task completed first.
                # NOTE(review): `done` may contain more than one task if several
                # finished in the same event-loop pass; only the first is inspected.
                completed = next(iter(done))

                if completed is main_task:
                    # Main task finished - cleanup is due to success
                    cleanup_reason = CancellationReason.SUCCESS_CLEANUP
                elif completed is sig_task or (completed is hb_task and sig_task.done()):
                    # Signal task finished - external cancellation
                    cleanup_reason = CancellationReason.SIGNAL
                elif completed is hb_task:
                    # Heartbeat stopped - failure cleanup
                    cleanup_reason = CancellationReason.FAILURE_CLEANUP

                # Cancel pending tasks with proper reason logging
                if pending:
                    pending_names = [t.get_name() for t in pending]
                    logger.debug(
                        "Cancelling pending tasks: %s, reason: %s",
                        pending_names,
                        cleanup_reason.value,
                        extra={
                            "mission_id": mission_id,
                            "task_id": task_id,
                            "pending_tasks": pending_names,
                            "cancellation_reason": cleanup_reason.value,
                        },
                    )
                    for t in pending:
                        t.cancel()

                # Propagate exception/result from the finished task
                await completed

                # Determine final status based on which task completed
                if completed is main_task:
                    session.status = TaskStatus.COMPLETED
                    logger.info(
                        "Main task completed successfully",
                        extra={"mission_id": mission_id, "task_id": task_id},
                    )
                elif completed is sig_task or (completed is hb_task and sig_task.done()):
                    session.status = TaskStatus.CANCELLED
                    session.cancellation_reason = CancellationReason.SIGNAL
                    logger.info(
                        "Task cancelled via external signal",
                        extra={
                            "mission_id": mission_id,
                            "task_id": task_id,
                            "cancellation_reason": CancellationReason.SIGNAL.value,
                        },
                    )
                elif completed is hb_task:
                    session.status = TaskStatus.FAILED
                    session.cancellation_reason = CancellationReason.HEARTBEAT_FAILURE
                    logger.error(
                        "Heartbeat stopped unexpectedly for task: '%s'",
                        task_id,
                        extra={
                            "mission_id": mission_id,
                            "task_id": task_id,
                            "cancellation_reason": CancellationReason.HEARTBEAT_FAILURE.value,
                        },
                    )
                    msg = f"Heartbeat stopped for {task_id}"
                    raise RuntimeError(msg)  # noqa: TRY301

            except asyncio.CancelledError:
                session.status = TaskStatus.CANCELLED
                # Only set reason if not already set (preserve original reason)
                logger.info(
                    "Task cancelled externally: '%s', reason: %s",
                    task_id,
                    session.cancellation_reason.value,
                    extra={
                        "mission_id": mission_id,
                        "task_id": task_id,
                        "cancellation_reason": session.cancellation_reason.value,
                    },
                )
                cleanup_reason = CancellationReason.FAILURE_CLEANUP
                raise
            except Exception:
                session.status = TaskStatus.FAILED
                cleanup_reason = CancellationReason.FAILURE_CLEANUP
                logger.exception(
                    "Task failed with exception: '%s'",
                    task_id,
                    extra={"mission_id": mission_id, "task_id": task_id},
                )
                raise
            finally:
                session.completed_at = datetime.datetime.now(datetime.timezone.utc)
                # Ensure all tasks are cleaned up with proper reason
                tasks_to_cleanup = [t for t in [main_task, hb_task, sig_task] if t is not None and not t.done()]
                if tasks_to_cleanup:
                    cleanup_names = [t.get_name() for t in tasks_to_cleanup]
                    logger.debug(
                        "Final cleanup of %d remaining tasks: %s, reason: %s",
                        len(tasks_to_cleanup),
                        cleanup_names,
                        cleanup_reason.value,
                        extra={
                            "mission_id": mission_id,
                            "task_id": task_id,
                            "cleanup_count": len(tasks_to_cleanup),
                            "cleanup_tasks": cleanup_names,
                            "cancellation_reason": cleanup_reason.value,
                        },
                    )
                    for t in tasks_to_cleanup:
                        t.cancel()
                    # return_exceptions=True: swallow the CancelledErrors from the
                    # children so cleanup itself cannot raise here.
                    await asyncio.gather(*tasks_to_cleanup, return_exceptions=True)

                duration = (
                    (session.completed_at - session.started_at).total_seconds()
                    if session.started_at and session.completed_at
                    else None
                )
                logger.info(
                    "Task execution completed: '%s', status: %s, reason: %s, duration: %.2fs",
                    task_id,
                    session.status.value,
                    session.cancellation_reason.value if session.status == TaskStatus.CANCELLED else "n/a",
                    duration or 0,
                    extra={
                        "mission_id": mission_id,
                        "task_id": task_id,
                        "status": session.status.value,
                        "cancellation_reason": session.cancellation_reason.value,
                        "duration": duration,
                    },
                )

        # Return the supervisor task to be awaited by caller
        return asyncio.create_task(supervisor(), name=f"{task_id}_supervisor")
|