digitalkin 0.3.1.dev2__py3-none-any.whl → 0.3.2a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- base_server/server_async_insecure.py +6 -5
- base_server/server_async_secure.py +6 -5
- base_server/server_sync_insecure.py +5 -4
- base_server/server_sync_secure.py +5 -4
- digitalkin/__version__.py +1 -1
- digitalkin/core/job_manager/base_job_manager.py +1 -1
- digitalkin/core/job_manager/single_job_manager.py +78 -36
- digitalkin/core/job_manager/taskiq_broker.py +7 -6
- digitalkin/core/job_manager/taskiq_job_manager.py +9 -5
- digitalkin/core/task_manager/base_task_manager.py +3 -1
- digitalkin/core/task_manager/surrealdb_repository.py +29 -7
- digitalkin/core/task_manager/task_executor.py +46 -12
- digitalkin/core/task_manager/task_session.py +132 -102
- digitalkin/grpc_servers/module_server.py +95 -171
- digitalkin/grpc_servers/module_servicer.py +121 -19
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +36 -10
- digitalkin/grpc_servers/utils/utility_schema_extender.py +106 -0
- digitalkin/models/__init__.py +1 -1
- digitalkin/models/core/job_manager_models.py +0 -8
- digitalkin/models/core/task_monitor.py +23 -1
- digitalkin/models/grpc_servers/models.py +95 -8
- digitalkin/models/module/__init__.py +26 -13
- digitalkin/models/module/base_types.py +61 -0
- digitalkin/models/module/module_context.py +279 -13
- digitalkin/models/module/module_types.py +28 -392
- digitalkin/models/module/setup_types.py +547 -0
- digitalkin/models/module/tool_cache.py +230 -0
- digitalkin/models/module/tool_reference.py +160 -0
- digitalkin/models/module/utility.py +167 -0
- digitalkin/models/services/cost.py +22 -1
- digitalkin/models/services/registry.py +77 -0
- digitalkin/modules/__init__.py +5 -1
- digitalkin/modules/_base_module.py +188 -63
- digitalkin/modules/archetype_module.py +6 -1
- digitalkin/modules/tool_module.py +6 -1
- digitalkin/modules/triggers/__init__.py +8 -0
- digitalkin/modules/triggers/healthcheck_ping_trigger.py +45 -0
- digitalkin/modules/triggers/healthcheck_services_trigger.py +63 -0
- digitalkin/modules/triggers/healthcheck_status_trigger.py +52 -0
- digitalkin/services/__init__.py +4 -0
- digitalkin/services/communication/__init__.py +7 -0
- digitalkin/services/communication/communication_strategy.py +87 -0
- digitalkin/services/communication/default_communication.py +104 -0
- digitalkin/services/communication/grpc_communication.py +264 -0
- digitalkin/services/cost/cost_strategy.py +36 -14
- digitalkin/services/cost/default_cost.py +61 -1
- digitalkin/services/cost/grpc_cost.py +98 -2
- digitalkin/services/filesystem/grpc_filesystem.py +9 -2
- digitalkin/services/registry/__init__.py +22 -1
- digitalkin/services/registry/default_registry.py +156 -4
- digitalkin/services/registry/exceptions.py +47 -0
- digitalkin/services/registry/grpc_registry.py +382 -0
- digitalkin/services/registry/registry_models.py +15 -0
- digitalkin/services/registry/registry_strategy.py +106 -4
- digitalkin/services/services_config.py +25 -3
- digitalkin/services/services_models.py +5 -1
- digitalkin/services/setup/default_setup.py +1 -1
- digitalkin/services/setup/grpc_setup.py +1 -1
- digitalkin/services/storage/grpc_storage.py +1 -1
- digitalkin/services/user_profile/__init__.py +11 -0
- digitalkin/services/user_profile/grpc_user_profile.py +2 -2
- digitalkin/services/user_profile/user_profile_strategy.py +0 -15
- digitalkin/utils/__init__.py +15 -3
- digitalkin/utils/conditional_schema.py +260 -0
- digitalkin/utils/dynamic_schema.py +4 -0
- digitalkin/utils/schema_splitter.py +290 -0
- {digitalkin-0.3.1.dev2.dist-info → digitalkin-0.3.2a3.dist-info}/METADATA +12 -12
- digitalkin-0.3.2a3.dist-info/RECORD +144 -0
- {digitalkin-0.3.1.dev2.dist-info → digitalkin-0.3.2a3.dist-info}/WHEEL +1 -1
- {digitalkin-0.3.1.dev2.dist-info → digitalkin-0.3.2a3.dist-info}/top_level.txt +1 -0
- modules/archetype_with_tools_module.py +232 -0
- modules/cpu_intensive_module.py +1 -1
- modules/dynamic_setup_module.py +5 -29
- modules/minimal_llm_module.py +1 -1
- modules/text_transform_module.py +1 -1
- monitoring/digitalkin_observability/__init__.py +46 -0
- monitoring/digitalkin_observability/http_server.py +150 -0
- monitoring/digitalkin_observability/interceptors.py +176 -0
- monitoring/digitalkin_observability/metrics.py +201 -0
- monitoring/digitalkin_observability/prometheus.py +137 -0
- monitoring/tests/test_metrics.py +172 -0
- services/filesystem_module.py +7 -5
- services/storage_module.py +4 -2
- digitalkin/grpc_servers/registry_server.py +0 -65
- digitalkin/grpc_servers/registry_servicer.py +0 -456
- digitalkin-0.3.1.dev2.dist-info/RECORD +0 -119
- {digitalkin-0.3.1.dev2.dist-info → digitalkin-0.3.2a3.dist-info}/licenses/LICENSE +0 -0
|
@@ -4,7 +4,7 @@ import asyncio
|
|
|
4
4
|
import datetime
|
|
5
5
|
import os
|
|
6
6
|
from collections.abc import AsyncGenerator
|
|
7
|
-
from typing import Any, Generic, TypeVar
|
|
7
|
+
from typing import Any, Generic, TypeVar, cast
|
|
8
8
|
from uuid import UUID
|
|
9
9
|
|
|
10
10
|
from surrealdb import AsyncHttpSurrealConnection, AsyncSurreal, AsyncWsSurrealConnection, RecordID
|
|
@@ -40,6 +40,7 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
40
40
|
db: TSurreal
|
|
41
41
|
timeout: datetime.timedelta
|
|
42
42
|
_live_queries: set[UUID] # Track active live queries for cleanup
|
|
43
|
+
_closed: bool # Flag to prevent operations on closed connection
|
|
43
44
|
|
|
44
45
|
@staticmethod
|
|
45
46
|
def _valid_id(raw_id: str, table_name: str) -> RecordID:
|
|
@@ -85,13 +86,14 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
85
86
|
self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
|
|
86
87
|
self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")
|
|
87
88
|
self._live_queries = set() # Initialize live queries tracker
|
|
89
|
+
self._closed = False
|
|
88
90
|
|
|
89
91
|
async def init_surreal_instance(self) -> None:
|
|
90
92
|
"""Init a SurrealDB connection instance."""
|
|
91
93
|
logger.debug("Connecting to SurrealDB at %s", self.url)
|
|
92
94
|
self.db = AsyncSurreal(self.url) # type: ignore
|
|
93
95
|
await self.db.signin({"username": self.username, "password": self.password})
|
|
94
|
-
await self.db.use(self.namespace, self.database)
|
|
96
|
+
await self.db.use(self.namespace, self.database) # type: ignore[arg-type]
|
|
95
97
|
logger.debug("Successfully connected to SurrealDB")
|
|
96
98
|
|
|
97
99
|
async def close(self) -> None:
|
|
@@ -99,6 +101,7 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
99
101
|
|
|
100
102
|
This will also kill all active live queries to prevent memory leaks.
|
|
101
103
|
"""
|
|
104
|
+
self._closed = True
|
|
102
105
|
# Kill all tracked live queries before closing connection
|
|
103
106
|
if self._live_queries:
|
|
104
107
|
logger.debug("Killing %d active live queries before closing", len(self._live_queries))
|
|
@@ -112,7 +115,7 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
112
115
|
# Process results and track failures
|
|
113
116
|
failed_queries = []
|
|
114
117
|
for live_id, result in zip(live_query_ids, results):
|
|
115
|
-
if isinstance(result,
|
|
118
|
+
if isinstance(result, ConnectionError | TimeoutError | Exception):
|
|
116
119
|
failed_queries.append((live_id, str(result)))
|
|
117
120
|
else:
|
|
118
121
|
self._live_queries.discard(live_id)
|
|
@@ -142,11 +145,27 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
142
145
|
|
|
143
146
|
Returns:
|
|
144
147
|
Dict[str, Any]: The created record as returned by the database
|
|
148
|
+
|
|
149
|
+
Raises:
|
|
150
|
+
RuntimeError: If the database returns an error response
|
|
145
151
|
"""
|
|
146
152
|
logger.debug("Creating record in %s with data: %s", table_name, data)
|
|
147
153
|
result = await self.db.create(table_name, data)
|
|
148
154
|
logger.debug("create result: %s", result)
|
|
149
|
-
|
|
155
|
+
|
|
156
|
+
# Check for error response from SurrealDB
|
|
157
|
+
if isinstance(result, dict) and "code" in result:
|
|
158
|
+
error_msg = result.get("message", result.get("information", "Unknown error"))
|
|
159
|
+
logger.error(
|
|
160
|
+
"SurrealDB create failed: %s (code: %s)",
|
|
161
|
+
error_msg,
|
|
162
|
+
result.get("code"),
|
|
163
|
+
extra={"table": table_name, "error": result},
|
|
164
|
+
)
|
|
165
|
+
msg = f"SurrealDB create failed in '{table_name}': {error_msg}"
|
|
166
|
+
raise RuntimeError(msg)
|
|
167
|
+
|
|
168
|
+
return cast("list[dict[str, Any]] | dict[str, Any]", result)
|
|
150
169
|
|
|
151
170
|
async def merge(
|
|
152
171
|
self,
|
|
@@ -170,7 +189,7 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
170
189
|
logger.debug("Updating record in %s with data: %s", record_id, data)
|
|
171
190
|
result = await self.db.merge(record_id, data)
|
|
172
191
|
logger.debug("update result: %s", result)
|
|
173
|
-
return result
|
|
192
|
+
return cast("list[dict[str, Any]] | dict[str, Any]", result)
|
|
174
193
|
|
|
175
194
|
async def update(
|
|
176
195
|
self,
|
|
@@ -194,7 +213,7 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
194
213
|
logger.debug("Updating record in %s with data: %s", record_id, data)
|
|
195
214
|
result = await self.db.update(record_id, data)
|
|
196
215
|
logger.debug("update result: %s", result)
|
|
197
|
-
return result
|
|
216
|
+
return cast("list[dict[str, Any]] | dict[str, Any]", result)
|
|
198
217
|
|
|
199
218
|
async def execute_query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
|
|
200
219
|
"""Execute a custom SurrealQL query.
|
|
@@ -209,7 +228,7 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
209
228
|
logger.debug("execute_query: %s with params: %s", query, params)
|
|
210
229
|
result = await self.db.query(query, params or {})
|
|
211
230
|
logger.debug("execute_query result: %s", result)
|
|
212
|
-
return [result] if isinstance(result, dict) else result
|
|
231
|
+
return cast("list[dict[str, Any]]", [result] if isinstance(result, dict) else result)
|
|
213
232
|
|
|
214
233
|
async def select_by_task_id(self, table: str, value: str) -> dict[str, Any]:
|
|
215
234
|
"""Fetch a record from a table by a unique field.
|
|
@@ -260,6 +279,9 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
260
279
|
Args:
|
|
261
280
|
live_id: Live query ID to kill
|
|
262
281
|
"""
|
|
282
|
+
if self._closed:
|
|
283
|
+
self._live_queries.discard(live_id)
|
|
284
|
+
return
|
|
263
285
|
logger.debug("Killing live query: %s", live_id)
|
|
264
286
|
await self.db.kill(live_id)
|
|
265
287
|
self._live_queries.discard(live_id) # Remove from tracker
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Task executor for running tasks with full lifecycle management."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import contextlib
|
|
4
5
|
import datetime
|
|
5
6
|
from collections.abc import Coroutine
|
|
6
7
|
from typing import Any
|
|
@@ -54,28 +55,53 @@ class TaskExecutor:
|
|
|
54
55
|
async def signal_wrapper() -> None:
|
|
55
56
|
"""Create initial signal record and listen for signals."""
|
|
56
57
|
try:
|
|
57
|
-
|
|
58
|
+
# Create task record and capture the record ID directly
|
|
59
|
+
# This avoids a race condition where SELECT might run before CREATE completes
|
|
60
|
+
result = await channel.create(
|
|
58
61
|
"tasks",
|
|
59
62
|
SignalMessage(
|
|
60
63
|
task_id=task_id,
|
|
61
64
|
mission_id=mission_id,
|
|
65
|
+
setup_id=session.setup_id,
|
|
66
|
+
setup_version_id=session.setup_version_id,
|
|
62
67
|
status=session.status,
|
|
63
68
|
action=SignalType.START,
|
|
64
69
|
).model_dump(),
|
|
65
70
|
)
|
|
66
|
-
|
|
71
|
+
# Store the record ID in session - required before starting live query
|
|
72
|
+
if isinstance(result, dict) and "id" in result:
|
|
73
|
+
session.signal_record_id = result["id"]
|
|
74
|
+
logger.debug(
|
|
75
|
+
"Task signal record created",
|
|
76
|
+
extra={"mission_id": mission_id, "task_id": task_id, "record_id": result["id"]},
|
|
77
|
+
)
|
|
78
|
+
# Only start listening if we have a valid record ID
|
|
79
|
+
await session.listen_signals()
|
|
80
|
+
else:
|
|
81
|
+
# Create failed - wait for cancellation instead of listening
|
|
82
|
+
logger.error(
|
|
83
|
+
"Failed to get record ID from task creation, waiting for cancellation",
|
|
84
|
+
extra={"mission_id": mission_id, "task_id": task_id, "result": result},
|
|
85
|
+
)
|
|
86
|
+
await session.is_cancelled.wait()
|
|
67
87
|
except asyncio.CancelledError:
|
|
68
88
|
logger.debug("Signal listener cancelled", extra={"mission_id": mission_id, "task_id": task_id})
|
|
69
89
|
finally:
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
90
|
+
with contextlib.suppress(Exception): # Connection may already be closed
|
|
91
|
+
await channel.create(
|
|
92
|
+
"tasks",
|
|
93
|
+
SignalMessage(
|
|
94
|
+
task_id=task_id,
|
|
95
|
+
mission_id=mission_id,
|
|
96
|
+
setup_id=session.setup_id,
|
|
97
|
+
setup_version_id=session.setup_version_id,
|
|
98
|
+
status=session.status,
|
|
99
|
+
action=SignalType.STOP,
|
|
100
|
+
cancellation_reason=session.cancellation_reason,
|
|
101
|
+
error_message=session._last_exception, # noqa: SLF001
|
|
102
|
+
exception_traceback=session._last_traceback, # noqa: SLF001
|
|
103
|
+
).model_dump(),
|
|
104
|
+
)
|
|
79
105
|
logger.info("Signal listener ended", extra={"mission_id": mission_id, "task_id": task_id})
|
|
80
106
|
|
|
81
107
|
async def heartbeat_wrapper() -> None:
|
|
@@ -125,8 +151,14 @@ class TaskExecutor:
|
|
|
125
151
|
# Heartbeat stopped - failure cleanup
|
|
126
152
|
cleanup_reason = CancellationReason.FAILURE_CLEANUP
|
|
127
153
|
|
|
154
|
+
# Signal stream to close FIRST before any cleanup
|
|
155
|
+
session.close_stream()
|
|
156
|
+
|
|
128
157
|
# Cancel pending tasks with proper reason logging
|
|
129
158
|
if pending:
|
|
159
|
+
# Give stream time to see the signal and exit gracefully
|
|
160
|
+
await asyncio.sleep(0.01) # Allow one event loop cycle
|
|
161
|
+
|
|
130
162
|
pending_names = [t.get_name() for t in pending]
|
|
131
163
|
logger.debug(
|
|
132
164
|
"Cancelling pending tasks: %s, reason: %s",
|
|
@@ -148,6 +180,7 @@ class TaskExecutor:
|
|
|
148
180
|
# Determine final status based on which task completed
|
|
149
181
|
if completed is main_task:
|
|
150
182
|
session.status = TaskStatus.COMPLETED
|
|
183
|
+
session.cancellation_reason = CancellationReason.COMPLETED
|
|
151
184
|
logger.info(
|
|
152
185
|
"Main task completed successfully",
|
|
153
186
|
extra={"mission_id": mission_id, "task_id": task_id},
|
|
@@ -193,9 +226,10 @@ class TaskExecutor:
|
|
|
193
226
|
)
|
|
194
227
|
cleanup_reason = CancellationReason.FAILURE_CLEANUP
|
|
195
228
|
raise
|
|
196
|
-
except Exception:
|
|
229
|
+
except Exception as e:
|
|
197
230
|
session.status = TaskStatus.FAILED
|
|
198
231
|
cleanup_reason = CancellationReason.FAILURE_CLEANUP
|
|
232
|
+
session.record_exception(e)
|
|
199
233
|
logger.exception(
|
|
200
234
|
"Task failed with exception: '%s'",
|
|
201
235
|
task_id,
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
"""Task session easing task lifecycle management."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import contextlib
|
|
4
5
|
import datetime
|
|
6
|
+
import traceback
|
|
5
7
|
from collections.abc import AsyncGenerator
|
|
6
8
|
|
|
7
9
|
from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
|
|
@@ -39,9 +41,17 @@ class TaskSession:
|
|
|
39
41
|
is_cancelled: asyncio.Event
|
|
40
42
|
cancellation_reason: CancellationReason
|
|
41
43
|
_paused: asyncio.Event
|
|
44
|
+
_stream_closed: asyncio.Event
|
|
42
45
|
_heartbeat_interval: datetime.timedelta
|
|
43
46
|
_last_heartbeat: datetime.datetime
|
|
44
47
|
|
|
48
|
+
# Exception tracking for enhanced DB logging
|
|
49
|
+
_last_exception: str | None
|
|
50
|
+
_last_traceback: str | None
|
|
51
|
+
|
|
52
|
+
# Cleanup guard for idempotent cleanup
|
|
53
|
+
_cleanup_done: bool
|
|
54
|
+
|
|
45
55
|
def __init__(
|
|
46
56
|
self,
|
|
47
57
|
task_id: str,
|
|
@@ -81,12 +91,23 @@ class TaskSession:
|
|
|
81
91
|
self.is_cancelled = asyncio.Event()
|
|
82
92
|
self.cancellation_reason = CancellationReason.UNKNOWN
|
|
83
93
|
self._paused = asyncio.Event()
|
|
94
|
+
self._stream_closed = asyncio.Event()
|
|
84
95
|
self._heartbeat_interval = heartbeat_interval
|
|
85
96
|
|
|
97
|
+
# Exception tracking
|
|
98
|
+
self._last_exception = None
|
|
99
|
+
self._last_traceback = None
|
|
100
|
+
|
|
101
|
+
# Cleanup guard
|
|
102
|
+
self._cleanup_done = False
|
|
103
|
+
|
|
86
104
|
logger.info(
|
|
87
|
-
"
|
|
88
|
-
|
|
89
|
-
|
|
105
|
+
"TaskSession initialized",
|
|
106
|
+
extra={
|
|
107
|
+
"task_id": task_id,
|
|
108
|
+
"mission_id": mission_id,
|
|
109
|
+
"heartbeat_interval": str(heartbeat_interval),
|
|
110
|
+
},
|
|
90
111
|
)
|
|
91
112
|
|
|
92
113
|
@property
|
|
@@ -99,6 +120,39 @@ class TaskSession:
|
|
|
99
120
|
"""Task paused status."""
|
|
100
121
|
return self._paused.is_set()
|
|
101
122
|
|
|
123
|
+
@property
|
|
124
|
+
def stream_closed(self) -> bool:
|
|
125
|
+
"""Check if stream termination was signaled."""
|
|
126
|
+
return self._stream_closed.is_set()
|
|
127
|
+
|
|
128
|
+
def close_stream(self) -> None:
|
|
129
|
+
"""Signal that the stream should terminate."""
|
|
130
|
+
self._stream_closed.set()
|
|
131
|
+
|
|
132
|
+
@property
|
|
133
|
+
def setup_id(self) -> str:
|
|
134
|
+
"""Get setup_id from module context."""
|
|
135
|
+
return self.module.context.session.setup_id
|
|
136
|
+
|
|
137
|
+
@property
|
|
138
|
+
def setup_version_id(self) -> str:
|
|
139
|
+
"""Get setup_version_id from module context."""
|
|
140
|
+
return self.module.context.session.setup_version_id
|
|
141
|
+
|
|
142
|
+
@property
|
|
143
|
+
def session_ids(self) -> dict[str, str]:
|
|
144
|
+
"""Get all session IDs from module context for structured logging."""
|
|
145
|
+
return self.module.context.session.current_ids()
|
|
146
|
+
|
|
147
|
+
def record_exception(self, exc: Exception) -> None:
|
|
148
|
+
"""Record exception details for DB logging.
|
|
149
|
+
|
|
150
|
+
Args:
|
|
151
|
+
exc: The exception that caused the task to fail.
|
|
152
|
+
"""
|
|
153
|
+
self._last_exception = str(exc)
|
|
154
|
+
self._last_traceback = traceback.format_exc()
|
|
155
|
+
|
|
102
156
|
async def send_heartbeat(self) -> bool:
|
|
103
157
|
"""Rate-limited heartbeat with connection resilience.
|
|
104
158
|
|
|
@@ -108,6 +162,8 @@ class TaskSession:
|
|
|
108
162
|
heartbeat = HeartbeatMessage(
|
|
109
163
|
task_id=self.task_id,
|
|
110
164
|
mission_id=self.mission_id,
|
|
165
|
+
setup_id=self.setup_id,
|
|
166
|
+
setup_version_id=self.setup_version_id,
|
|
111
167
|
timestamp=datetime.datetime.now(datetime.timezone.utc),
|
|
112
168
|
)
|
|
113
169
|
|
|
@@ -120,23 +176,17 @@ class TaskSession:
|
|
|
120
176
|
return True
|
|
121
177
|
except Exception as e:
|
|
122
178
|
logger.error(
|
|
123
|
-
"Heartbeat exception
|
|
124
|
-
self.
|
|
125
|
-
extra={"task_id": self.task_id, "error": str(e)},
|
|
179
|
+
"Heartbeat exception",
|
|
180
|
+
extra={**self.session_ids, "error": str(e)},
|
|
126
181
|
exc_info=True,
|
|
127
182
|
)
|
|
128
|
-
logger.error(
|
|
129
|
-
"Initial heartbeat failed for task: '%s'",
|
|
130
|
-
self.task_id,
|
|
131
|
-
extra={"task_id": self.task_id},
|
|
132
|
-
)
|
|
183
|
+
logger.error("Initial heartbeat failed", extra=self.session_ids)
|
|
133
184
|
return False
|
|
134
185
|
|
|
135
186
|
if (heartbeat.timestamp - self._last_heartbeat) < self._heartbeat_interval:
|
|
136
187
|
logger.debug(
|
|
137
|
-
"Heartbeat skipped due to rate limiting
|
|
138
|
-
self.
|
|
139
|
-
heartbeat.timestamp - self._last_heartbeat,
|
|
188
|
+
"Heartbeat skipped due to rate limiting",
|
|
189
|
+
extra={**self.session_ids, "delta": str(heartbeat.timestamp - self._last_heartbeat)},
|
|
140
190
|
)
|
|
141
191
|
return True
|
|
142
192
|
|
|
@@ -147,39 +197,24 @@ class TaskSession:
|
|
|
147
197
|
return True
|
|
148
198
|
except Exception as e:
|
|
149
199
|
logger.error(
|
|
150
|
-
"Heartbeat exception
|
|
151
|
-
self.
|
|
152
|
-
extra={"task_id": self.task_id, "error": str(e)},
|
|
200
|
+
"Heartbeat exception",
|
|
201
|
+
extra={**self.session_ids, "error": str(e)},
|
|
153
202
|
exc_info=True,
|
|
154
203
|
)
|
|
155
|
-
logger.warning(
|
|
156
|
-
"Heartbeat failed for task: '%s'",
|
|
157
|
-
self.task_id,
|
|
158
|
-
extra={"task_id": self.task_id},
|
|
159
|
-
)
|
|
204
|
+
logger.warning("Heartbeat failed", extra=self.session_ids)
|
|
160
205
|
return False
|
|
161
206
|
|
|
162
207
|
async def generate_heartbeats(self) -> None:
|
|
163
208
|
"""Periodic heartbeat generator with cancellation support."""
|
|
164
|
-
logger.debug(
|
|
165
|
-
"Heartbeat generator started for task: '%s'",
|
|
166
|
-
self.task_id,
|
|
167
|
-
extra={"task_id": self.task_id, "mission_id": self.mission_id},
|
|
168
|
-
)
|
|
209
|
+
logger.debug("Heartbeat generator started", extra=self.session_ids)
|
|
169
210
|
while not self.cancelled:
|
|
170
211
|
logger.debug(
|
|
171
|
-
"Heartbeat tick
|
|
172
|
-
self.
|
|
173
|
-
self.cancelled,
|
|
174
|
-
extra={"task_id": self.task_id, "mission_id": self.mission_id},
|
|
212
|
+
"Heartbeat tick",
|
|
213
|
+
extra={**self.session_ids, "cancelled": self.cancelled},
|
|
175
214
|
)
|
|
176
215
|
success = await self.send_heartbeat()
|
|
177
216
|
if not success:
|
|
178
|
-
logger.error(
|
|
179
|
-
"Heartbeat failed, cancelling task: '%s'",
|
|
180
|
-
self.task_id,
|
|
181
|
-
extra={"task_id": self.task_id, "mission_id": self.mission_id},
|
|
182
|
-
)
|
|
217
|
+
logger.error("Heartbeat failed, cancelling task", extra=self.session_ids)
|
|
183
218
|
await self._handle_cancel(CancellationReason.HEARTBEAT_FAILURE)
|
|
184
219
|
break
|
|
185
220
|
await asyncio.sleep(self._heartbeat_interval.total_seconds())
|
|
@@ -187,32 +222,32 @@ class TaskSession:
|
|
|
187
222
|
async def wait_if_paused(self) -> None:
|
|
188
223
|
"""Block execution if task is paused."""
|
|
189
224
|
if self._paused.is_set():
|
|
190
|
-
logger.info(
|
|
191
|
-
"Task paused, waiting for resume: '%s'",
|
|
192
|
-
self.task_id,
|
|
193
|
-
extra={"task_id": self.task_id},
|
|
194
|
-
)
|
|
225
|
+
logger.info("Task paused, waiting for resume", extra=self.session_ids)
|
|
195
226
|
await self._paused.wait()
|
|
196
227
|
|
|
197
228
|
async def listen_signals(self) -> None: # noqa: C901
|
|
198
229
|
"""Enhanced signal listener with comprehensive handling.
|
|
199
230
|
|
|
200
231
|
Raises:
|
|
201
|
-
CancelledError:
|
|
232
|
+
CancelledError: If task is cancelled during signal listening.
|
|
202
233
|
"""
|
|
203
|
-
logger.info(
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
)
|
|
234
|
+
logger.info("Signal listener started", extra=self.session_ids)
|
|
235
|
+
|
|
236
|
+
# signal_record_id must be set by TaskExecutor before calling this method.
|
|
237
|
+
# If not set, we cannot filter signals correctly - abort early.
|
|
208
238
|
if self.signal_record_id is None:
|
|
209
|
-
|
|
239
|
+
logger.error(
|
|
240
|
+
"signal_record_id not set - cannot start signal listener without valid record ID",
|
|
241
|
+
extra=self.session_ids,
|
|
242
|
+
)
|
|
243
|
+
return
|
|
210
244
|
|
|
211
245
|
live_id, live_signals = await self.db.start_live("tasks")
|
|
212
246
|
try:
|
|
213
247
|
async for signal in live_signals:
|
|
214
|
-
logger.debug("Signal received
|
|
215
|
-
|
|
248
|
+
logger.debug("Signal received", extra={**self.session_ids, "signal": signal})
|
|
249
|
+
# Check both cancelled and stream_closed to ensure clean shutdown
|
|
250
|
+
if self.cancelled or self.stream_closed:
|
|
216
251
|
break
|
|
217
252
|
|
|
218
253
|
if signal is None or signal["id"] == self.signal_record_id or "payload" not in signal:
|
|
@@ -228,26 +263,18 @@ class TaskSession:
|
|
|
228
263
|
await self._handle_status_request()
|
|
229
264
|
|
|
230
265
|
except asyncio.CancelledError:
|
|
231
|
-
logger.debug(
|
|
232
|
-
"Signal listener cancelled for task: '%s'",
|
|
233
|
-
self.task_id,
|
|
234
|
-
extra={"task_id": self.task_id},
|
|
235
|
-
)
|
|
266
|
+
logger.debug("Signal listener cancelled", extra=self.session_ids)
|
|
236
267
|
raise
|
|
237
268
|
except Exception as e:
|
|
238
269
|
logger.error(
|
|
239
|
-
"Signal listener fatal error
|
|
240
|
-
self.
|
|
241
|
-
extra={"task_id": self.task_id, "error": str(e)},
|
|
270
|
+
"Signal listener fatal error",
|
|
271
|
+
extra={**self.session_ids, "error": str(e)},
|
|
242
272
|
exc_info=True,
|
|
243
273
|
)
|
|
244
274
|
finally:
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
self.task_id,
|
|
249
|
-
extra={"task_id": self.task_id},
|
|
250
|
-
)
|
|
275
|
+
with contextlib.suppress(Exception): # Connection may already be closed
|
|
276
|
+
await self.db.stop_live(live_id)
|
|
277
|
+
logger.info("Signal listener stopped", extra=self.session_ids)
|
|
251
278
|
|
|
252
279
|
async def _handle_cancel(self, reason: CancellationReason = CancellationReason.UNKNOWN) -> None:
|
|
253
280
|
"""Idempotent cancellation with acknowledgment and reason tracking.
|
|
@@ -257,13 +284,9 @@ class TaskSession:
|
|
|
257
284
|
"""
|
|
258
285
|
if self.is_cancelled.is_set():
|
|
259
286
|
logger.debug(
|
|
260
|
-
"Cancel ignored -
|
|
261
|
-
self.task_id,
|
|
262
|
-
self.cancellation_reason.value,
|
|
263
|
-
reason.value,
|
|
287
|
+
"Cancel ignored - already cancelled",
|
|
264
288
|
extra={
|
|
265
|
-
|
|
266
|
-
"mission_id": self.mission_id,
|
|
289
|
+
**self.session_ids,
|
|
267
290
|
"existing_reason": self.cancellation_reason.value,
|
|
268
291
|
"new_reason": reason.value,
|
|
269
292
|
},
|
|
@@ -277,25 +300,13 @@ class TaskSession:
|
|
|
277
300
|
# Log with appropriate level based on reason
|
|
278
301
|
if reason in {CancellationReason.SUCCESS_CLEANUP, CancellationReason.FAILURE_CLEANUP}:
|
|
279
302
|
logger.debug(
|
|
280
|
-
"Task cancelled (cleanup)
|
|
281
|
-
self.
|
|
282
|
-
reason.value,
|
|
283
|
-
extra={
|
|
284
|
-
"task_id": self.task_id,
|
|
285
|
-
"mission_id": self.mission_id,
|
|
286
|
-
"cancellation_reason": reason.value,
|
|
287
|
-
},
|
|
303
|
+
"Task cancelled (cleanup)",
|
|
304
|
+
extra={**self.session_ids, "cancellation_reason": reason.value},
|
|
288
305
|
)
|
|
289
306
|
else:
|
|
290
307
|
logger.info(
|
|
291
|
-
"Task cancelled
|
|
292
|
-
self.
|
|
293
|
-
reason.value,
|
|
294
|
-
extra={
|
|
295
|
-
"task_id": self.task_id,
|
|
296
|
-
"mission_id": self.mission_id,
|
|
297
|
-
"cancellation_reason": reason.value,
|
|
298
|
-
},
|
|
308
|
+
"Task cancelled",
|
|
309
|
+
extra={**self.session_ids, "cancellation_reason": reason.value},
|
|
299
310
|
)
|
|
300
311
|
|
|
301
312
|
# Resume if paused so cancellation can proceed
|
|
@@ -308,19 +319,18 @@ class TaskSession:
|
|
|
308
319
|
SignalMessage(
|
|
309
320
|
task_id=self.task_id,
|
|
310
321
|
mission_id=self.mission_id,
|
|
322
|
+
setup_id=self.setup_id,
|
|
323
|
+
setup_version_id=self.setup_version_id,
|
|
311
324
|
action=SignalType.ACK_CANCEL,
|
|
312
325
|
status=self.status,
|
|
326
|
+
cancellation_reason=reason,
|
|
313
327
|
).model_dump(),
|
|
314
328
|
)
|
|
315
329
|
|
|
316
330
|
async def _handle_pause(self) -> None:
|
|
317
331
|
"""Pause task execution."""
|
|
318
332
|
if not self._paused.is_set():
|
|
319
|
-
logger.info(
|
|
320
|
-
"Pausing task: '%s'",
|
|
321
|
-
self.task_id,
|
|
322
|
-
extra={"task_id": self.task_id},
|
|
323
|
-
)
|
|
333
|
+
logger.info("Task paused", extra=self.session_ids)
|
|
324
334
|
self._paused.set()
|
|
325
335
|
|
|
326
336
|
await self.db.update(
|
|
@@ -329,6 +339,8 @@ class TaskSession:
|
|
|
329
339
|
SignalMessage(
|
|
330
340
|
task_id=self.task_id,
|
|
331
341
|
mission_id=self.mission_id,
|
|
342
|
+
setup_id=self.setup_id,
|
|
343
|
+
setup_version_id=self.setup_version_id,
|
|
332
344
|
action=SignalType.ACK_PAUSE,
|
|
333
345
|
status=self.status,
|
|
334
346
|
).model_dump(),
|
|
@@ -337,11 +349,7 @@ class TaskSession:
|
|
|
337
349
|
async def _handle_resume(self) -> None:
|
|
338
350
|
"""Resume paused task."""
|
|
339
351
|
if self._paused.is_set():
|
|
340
|
-
logger.info(
|
|
341
|
-
"Resuming task: '%s'",
|
|
342
|
-
self.task_id,
|
|
343
|
-
extra={"task_id": self.task_id},
|
|
344
|
-
)
|
|
352
|
+
logger.info("Task resumed", extra=self.session_ids)
|
|
345
353
|
self._paused.clear()
|
|
346
354
|
|
|
347
355
|
await self.db.update(
|
|
@@ -350,6 +358,8 @@ class TaskSession:
|
|
|
350
358
|
SignalMessage(
|
|
351
359
|
task_id=self.task_id,
|
|
352
360
|
mission_id=self.mission_id,
|
|
361
|
+
setup_id=self.setup_id,
|
|
362
|
+
setup_version_id=self.setup_version_id,
|
|
353
363
|
action=SignalType.ACK_RESUME,
|
|
354
364
|
status=self.status,
|
|
355
365
|
).model_dump(),
|
|
@@ -361,28 +371,38 @@ class TaskSession:
|
|
|
361
371
|
"tasks",
|
|
362
372
|
self.signal_record_id, # type: ignore
|
|
363
373
|
SignalMessage(
|
|
364
|
-
mission_id=self.mission_id,
|
|
365
374
|
task_id=self.task_id,
|
|
375
|
+
mission_id=self.mission_id,
|
|
376
|
+
setup_id=self.setup_id,
|
|
377
|
+
setup_version_id=self.setup_version_id,
|
|
366
378
|
status=self.status,
|
|
367
379
|
action=SignalType.ACK_STATUS,
|
|
368
380
|
).model_dump(),
|
|
369
381
|
)
|
|
370
382
|
|
|
371
|
-
logger.debug(
|
|
372
|
-
"Status report sent for task: '%s'",
|
|
373
|
-
self.task_id,
|
|
374
|
-
extra={"task_id": self.task_id},
|
|
375
|
-
)
|
|
383
|
+
logger.debug("Status report sent", extra=self.session_ids)
|
|
376
384
|
|
|
377
385
|
async def cleanup(self) -> None:
|
|
378
386
|
"""Clean up task session resources.
|
|
379
387
|
|
|
388
|
+
This method is idempotent - safe to call multiple times.
|
|
389
|
+
Second and subsequent calls are no-ops.
|
|
390
|
+
|
|
380
391
|
This includes:
|
|
381
392
|
- Clearing queue to free memory
|
|
393
|
+
- Cleaning up module context services
|
|
382
394
|
- Stopping module
|
|
383
395
|
- Closing database connection
|
|
384
396
|
- Clearing module reference
|
|
385
397
|
"""
|
|
398
|
+
if self._cleanup_done:
|
|
399
|
+
logger.debug(
|
|
400
|
+
"Cleanup already done, skipping",
|
|
401
|
+
extra={"task_id": self.task_id, "mission_id": self.mission_id},
|
|
402
|
+
)
|
|
403
|
+
return
|
|
404
|
+
self._cleanup_done = True
|
|
405
|
+
|
|
386
406
|
# Clear queue to free memory
|
|
387
407
|
try:
|
|
388
408
|
while not self.queue.empty():
|
|
@@ -390,6 +410,16 @@ class TaskSession:
|
|
|
390
410
|
except asyncio.QueueEmpty:
|
|
391
411
|
pass
|
|
392
412
|
|
|
413
|
+
# Clean up module context services (e.g., gRPC channel pool)
|
|
414
|
+
if self.module is not None and self.module.context is not None:
|
|
415
|
+
try:
|
|
416
|
+
await self.module.context.cleanup()
|
|
417
|
+
except Exception:
|
|
418
|
+
logger.exception(
|
|
419
|
+
"Error cleaning up module context",
|
|
420
|
+
extra={"mission_id": self.mission_id, "task_id": self.task_id},
|
|
421
|
+
)
|
|
422
|
+
|
|
393
423
|
# Stop module
|
|
394
424
|
try:
|
|
395
425
|
await self.module.stop()
|