digitalkin 0.2.23__py3-none-any.whl → 0.3.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/core/__init__.py +1 -0
  3. digitalkin/core/common/__init__.py +9 -0
  4. digitalkin/core/common/factories.py +156 -0
  5. digitalkin/core/job_manager/__init__.py +1 -0
  6. digitalkin/{modules → core}/job_manager/base_job_manager.py +137 -31
  7. digitalkin/core/job_manager/single_job_manager.py +354 -0
  8. digitalkin/{modules → core}/job_manager/taskiq_broker.py +116 -22
  9. digitalkin/core/job_manager/taskiq_job_manager.py +541 -0
  10. digitalkin/core/task_manager/__init__.py +1 -0
  11. digitalkin/core/task_manager/base_task_manager.py +539 -0
  12. digitalkin/core/task_manager/local_task_manager.py +108 -0
  13. digitalkin/core/task_manager/remote_task_manager.py +87 -0
  14. digitalkin/core/task_manager/surrealdb_repository.py +266 -0
  15. digitalkin/core/task_manager/task_executor.py +249 -0
  16. digitalkin/core/task_manager/task_session.py +406 -0
  17. digitalkin/grpc_servers/__init__.py +1 -19
  18. digitalkin/grpc_servers/_base_server.py +3 -3
  19. digitalkin/grpc_servers/module_server.py +27 -43
  20. digitalkin/grpc_servers/module_servicer.py +51 -36
  21. digitalkin/grpc_servers/registry_server.py +2 -2
  22. digitalkin/grpc_servers/registry_servicer.py +4 -4
  23. digitalkin/grpc_servers/utils/__init__.py +1 -0
  24. digitalkin/grpc_servers/utils/exceptions.py +0 -8
  25. digitalkin/grpc_servers/utils/grpc_client_wrapper.py +4 -4
  26. digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
  27. digitalkin/logger.py +73 -24
  28. digitalkin/mixins/__init__.py +19 -0
  29. digitalkin/mixins/base_mixin.py +10 -0
  30. digitalkin/mixins/callback_mixin.py +24 -0
  31. digitalkin/mixins/chat_history_mixin.py +110 -0
  32. digitalkin/mixins/cost_mixin.py +76 -0
  33. digitalkin/mixins/file_history_mixin.py +93 -0
  34. digitalkin/mixins/filesystem_mixin.py +46 -0
  35. digitalkin/mixins/logger_mixin.py +51 -0
  36. digitalkin/mixins/storage_mixin.py +79 -0
  37. digitalkin/models/core/__init__.py +1 -0
  38. digitalkin/{modules/job_manager → models/core}/job_manager_models.py +3 -3
  39. digitalkin/models/core/task_monitor.py +70 -0
  40. digitalkin/models/grpc_servers/__init__.py +1 -0
  41. digitalkin/{grpc_servers/utils → models/grpc_servers}/models.py +5 -5
  42. digitalkin/models/module/__init__.py +2 -0
  43. digitalkin/models/module/module.py +9 -1
  44. digitalkin/models/module/module_context.py +122 -6
  45. digitalkin/models/module/module_types.py +307 -19
  46. digitalkin/models/services/__init__.py +9 -0
  47. digitalkin/models/services/cost.py +1 -0
  48. digitalkin/models/services/storage.py +39 -5
  49. digitalkin/modules/_base_module.py +123 -118
  50. digitalkin/modules/tool_module.py +10 -2
  51. digitalkin/modules/trigger_handler.py +7 -6
  52. digitalkin/services/cost/__init__.py +9 -2
  53. digitalkin/services/cost/grpc_cost.py +9 -42
  54. digitalkin/services/filesystem/default_filesystem.py +0 -2
  55. digitalkin/services/filesystem/grpc_filesystem.py +10 -39
  56. digitalkin/services/setup/default_setup.py +5 -6
  57. digitalkin/services/setup/grpc_setup.py +52 -15
  58. digitalkin/services/storage/grpc_storage.py +4 -4
  59. digitalkin/services/user_profile/__init__.py +1 -0
  60. digitalkin/services/user_profile/default_user_profile.py +55 -0
  61. digitalkin/services/user_profile/grpc_user_profile.py +69 -0
  62. digitalkin/services/user_profile/user_profile_strategy.py +40 -0
  63. digitalkin/utils/__init__.py +28 -0
  64. digitalkin/utils/arg_parser.py +1 -1
  65. digitalkin/utils/development_mode_action.py +2 -2
  66. digitalkin/utils/dynamic_schema.py +483 -0
  67. digitalkin/utils/package_discover.py +1 -2
  68. {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/METADATA +11 -30
  69. digitalkin-0.3.1.dev2.dist-info/RECORD +119 -0
  70. modules/dynamic_setup_module.py +362 -0
  71. digitalkin/grpc_servers/utils/factory.py +0 -180
  72. digitalkin/modules/job_manager/single_job_manager.py +0 -294
  73. digitalkin/modules/job_manager/taskiq_job_manager.py +0 -290
  74. digitalkin-0.2.23.dist-info/RECORD +0 -89
  75. /digitalkin/{grpc_servers/utils → models/grpc_servers}/types.py +0 -0
  76. {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/WHEEL +0 -0
  77. {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/licenses/LICENSE +0 -0
  78. {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,87 @@
1
+ """Remote task manager for distributed execution."""
2
+
3
+ import datetime
4
+ from collections.abc import Coroutine
5
+ from typing import Any
6
+
7
+ from digitalkin.core.task_manager.base_task_manager import BaseTaskManager
8
+ from digitalkin.logger import logger
9
+ from digitalkin.modules._base_module import BaseModule
10
+
11
+
12
+ class RemoteTaskManager(BaseTaskManager):
13
+ """Task manager for distributed/remote execution.
14
+
15
+ Only manages task metadata and signals - actual execution happens in remote workers.
16
+ Suitable for horizontally scaled deployments with Taskiq/Celery workers.
17
+ """
18
+
19
+ async def create_task(
20
+ self,
21
+ task_id: str,
22
+ mission_id: str,
23
+ module: BaseModule,
24
+ coro: Coroutine[Any, Any, None],
25
+ heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
26
+ connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
27
+ ) -> None:
28
+ """Register task for remote execution (metadata only).
29
+
30
+ Creates TaskSession for signal handling and monitoring, but doesn't execute the coroutine.
31
+ The coroutine will be recreated and executed by a remote worker.
32
+
33
+ Args:
34
+ task_id: Unique identifier for the task
35
+ mission_id: Mission identifier
36
+ module: Module instance for metadata (not executed here)
37
+ coro: Coroutine (will be closed - execution happens in worker)
38
+ heartbeat_interval: Interval between heartbeats
39
+ connection_timeout: Connection timeout for SurrealDB
40
+
41
+ Raises:
42
+ ValueError: If task_id duplicated
43
+ RuntimeError: If task overload
44
+ """
45
+ # Validation
46
+ await self._validate_task_creation(task_id, mission_id, coro)
47
+
48
+ logger.info(
49
+ "Registering remote task: '%s'",
50
+ task_id,
51
+ extra={
52
+ "mission_id": mission_id,
53
+ "task_id": task_id,
54
+ "heartbeat_interval": heartbeat_interval,
55
+ "connection_timeout": connection_timeout,
56
+ },
57
+ )
58
+
59
+ try:
60
+ # Create session for metadata and signal handling
61
+ _channel, _session = await self._create_session(
62
+ task_id, mission_id, module, heartbeat_interval, connection_timeout
63
+ )
64
+
65
+ # Close coroutine - worker will recreate and execute it
66
+ coro.close()
67
+
68
+ logger.info(
69
+ "Remote task registered: '%s'",
70
+ task_id,
71
+ extra={
72
+ "mission_id": mission_id,
73
+ "task_id": task_id,
74
+ "total_sessions": len(self.tasks_sessions),
75
+ },
76
+ )
77
+
78
+ except Exception as e:
79
+ logger.error(
80
+ "Failed to register remote task: '%s'",
81
+ task_id,
82
+ extra={"mission_id": mission_id, "task_id": task_id, "error": str(e)},
83
+ exc_info=True,
84
+ )
85
+ # Cleanup on failure
86
+ await self._cleanup_task(task_id, mission_id=mission_id)
87
+ raise
@@ -0,0 +1,266 @@
1
+ """SurrealDB connection management."""
2
+
3
+ import asyncio
4
+ import datetime
5
+ import os
6
+ from collections.abc import AsyncGenerator
7
+ from typing import Any, Generic, TypeVar
8
+ from uuid import UUID
9
+
10
+ from surrealdb import AsyncHttpSurrealConnection, AsyncSurreal, AsyncWsSurrealConnection, RecordID
11
+
12
+ from digitalkin.logger import logger
13
+
14
+ TSurreal = TypeVar("TSurreal", bound=AsyncHttpSurrealConnection | AsyncWsSurrealConnection)
15
+
16
+
17
+ class SurrealDBSetupBadIDError(Exception):
18
+ """Exception raised when an invalid ID is encountered during the setup process in the SurrealDB repository.
19
+
20
+ This error is used to indicate that the provided ID does not meet the
21
+ expected format or criteria.
22
+ """
23
+
24
+
25
+ class SurrealDBSetupVersionBadIDError(Exception):
26
+ """Exception raised when an invalid ID is encountered during the setup of a SurrealDB version.
27
+
28
+ This error is intended to signal that the provided ID does not meet
29
+ the expected format or criteria for a valid SurrealDB setup version ID.
30
+ """
31
+
32
+
33
+ class SurrealDBConnection(Generic[TSurreal]):
34
+ """Base repository for database operations.
35
+
36
+ This class provides common database operations that can be used by
37
+ specific table repositories.
38
+ """
39
+
40
+ db: TSurreal
41
+ timeout: datetime.timedelta
42
+ _live_queries: set[UUID] # Track active live queries for cleanup
43
+
44
+ @staticmethod
45
+ def _valid_id(raw_id: str, table_name: str) -> RecordID:
46
+ """Validate and parse a raw ID string into a RecordID.
47
+
48
+ Args:
49
+ raw_id: The raw ID string to validate
50
+ table_name: table name to enforce
51
+
52
+ Raises:
53
+ SurrealDBSetupBadIDError: If the raw ID string is not valid
54
+
55
+ Returns:
56
+ RecordID: Parsed RecordID object if valid, None otherwise
57
+ """
58
+ try:
59
+ split_id = raw_id.split(":")
60
+ if split_id[0] != table_name:
61
+ msg = f"Invalid table name for ID: {raw_id}"
62
+ raise SurrealDBSetupBadIDError(msg)
63
+ return RecordID(split_id[0], split_id[1])
64
+ except IndexError:
65
+ raise SurrealDBSetupBadIDError
66
+
67
+ def __init__(
68
+ self,
69
+ database: str | None = None,
70
+ timeout: datetime.timedelta = datetime.timedelta(seconds=5),
71
+ ) -> None:
72
+ """Initialize the repository.
73
+
74
+ Args:
75
+ database: AsyncSurrealDB connection to a specific database
76
+ timeout: Timeout for database operations
77
+ """
78
+ self.timeout = timeout
79
+ base_url = os.getenv("SURREALDB_URL", "ws://localhost").strip()
80
+ port = (os.getenv("SURREALDB_PORT") or "").strip()
81
+ self.url = f"{base_url}{f':{port}' if port else ''}/rpc"
82
+
83
+ self.username = os.getenv("SURREALDB_USERNAME", "root")
84
+ self.password = os.getenv("SURREALDB_PASSWORD", "root")
85
+ self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
86
+ self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")
87
+ self._live_queries = set() # Initialize live queries tracker
88
+
89
+ async def init_surreal_instance(self) -> None:
90
+ """Init a SurrealDB connection instance."""
91
+ logger.debug("Connecting to SurrealDB at %s", self.url)
92
+ self.db = AsyncSurreal(self.url) # type: ignore
93
+ await self.db.signin({"username": self.username, "password": self.password})
94
+ await self.db.use(self.namespace, self.database)
95
+ logger.debug("Successfully connected to SurrealDB")
96
+
97
+ async def close(self) -> None:
98
+ """Close the SurrealDB connection if it exists.
99
+
100
+ This will also kill all active live queries to prevent memory leaks.
101
+ """
102
+ # Kill all tracked live queries before closing connection
103
+ if self._live_queries:
104
+ logger.debug("Killing %d active live queries before closing", len(self._live_queries))
105
+ live_query_ids = list(self._live_queries)
106
+
107
+ # Kill all queries concurrently, capturing any exceptions
108
+ results = await asyncio.gather(
109
+ *[self.db.kill(live_id) for live_id in live_query_ids], return_exceptions=True
110
+ )
111
+
112
+ # Process results and track failures
113
+ failed_queries = []
114
+ for live_id, result in zip(live_query_ids, results):
115
+ if isinstance(result, (ConnectionError, TimeoutError, Exception)):
116
+ failed_queries.append((live_id, str(result)))
117
+ else:
118
+ self._live_queries.discard(live_id)
119
+
120
+ # Log aggregated failures once instead of per-query
121
+ if failed_queries:
122
+ logger.warning(
123
+ "Failed to kill %d live queries: %s",
124
+ len(failed_queries),
125
+ failed_queries[:5], # Only log first 5 to avoid log spam
126
+ extra={"total_failed": len(failed_queries)},
127
+ )
128
+
129
+ logger.debug("Closing SurrealDB connection")
130
+ await self.db.close()
131
+
132
+ async def create(
133
+ self,
134
+ table_name: str,
135
+ data: dict[str, Any],
136
+ ) -> list[dict[str, Any]] | dict[str, Any]:
137
+ """Create a new record.
138
+
139
+ Args:
140
+ table_name: Name of the table to insert into
141
+ data: Data to insert
142
+
143
+ Returns:
144
+ Dict[str, Any]: The created record as returned by the database
145
+ """
146
+ logger.debug("Creating record in %s with data: %s", table_name, data)
147
+ result = await self.db.create(table_name, data)
148
+ logger.debug("create result: %s", result)
149
+ return result
150
+
151
+ async def merge(
152
+ self,
153
+ table_name: str,
154
+ record_id: str | RecordID,
155
+ data: dict[str, Any],
156
+ ) -> list[dict[str, Any]] | dict[str, Any]:
157
+ """Update an existing record.
158
+
159
+ Args:
160
+ table_name: Name of the table to insert into
161
+ record_id: record ID to update
162
+ data: Data to insert
163
+
164
+ Returns:
165
+ Dict[str, Any]: The created record as returned by the database
166
+ """
167
+ if isinstance(record_id, str):
168
+ # validate surrealDB id if raw str
169
+ record_id = self._valid_id(record_id, table_name)
170
+ logger.debug("Updating record in %s with data: %s", record_id, data)
171
+ result = await self.db.merge(record_id, data)
172
+ logger.debug("update result: %s", result)
173
+ return result
174
+
175
+ async def update(
176
+ self,
177
+ table_name: str,
178
+ record_id: str | RecordID,
179
+ data: dict[str, Any],
180
+ ) -> list[dict[str, Any]] | dict[str, Any]:
181
+ """Update an existing record.
182
+
183
+ Args:
184
+ table_name: Name of the table to insert into
185
+ record_id: record ID to update
186
+ data: Data to insert
187
+
188
+ Returns:
189
+ Dict[str, Any]: The created record as returned by the database
190
+ """
191
+ if isinstance(record_id, str):
192
+ # validate surrealDB id if raw str
193
+ record_id = self._valid_id(record_id, table_name)
194
+ logger.debug("Updating record in %s with data: %s", record_id, data)
195
+ result = await self.db.update(record_id, data)
196
+ logger.debug("update result: %s", result)
197
+ return result
198
+
199
+ async def execute_query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
200
+ """Execute a custom SurrealQL query.
201
+
202
+ Args:
203
+ query: SurrealQL query
204
+ params: Query parameters
205
+
206
+ Returns:
207
+ List[Dict[str, Any]]: Query results
208
+ """
209
+ logger.debug("execute_query: %s with params: %s", query, params)
210
+ result = await self.db.query(query, params or {})
211
+ logger.debug("execute_query result: %s", result)
212
+ return [result] if isinstance(result, dict) else result
213
+
214
+ async def select_by_task_id(self, table: str, value: str) -> dict[str, Any]:
215
+ """Fetch a record from a table by a unique field.
216
+
217
+ Args:
218
+ table: Table name
219
+ value: Field value to match
220
+
221
+ Raises:
222
+ ValueError: If no records are found
223
+
224
+ Returns:
225
+ Dict with record data if found, else None
226
+ """
227
+ query = "SELECT * FROM type::table($table) WHERE task_id = $value;"
228
+ params = {"table": table, "value": value}
229
+
230
+ result = await self.execute_query(query, params)
231
+ if not result:
232
+ msg = f"No records found in table '{table}' with task_id '{value}'"
233
+ logger.error(msg)
234
+ raise ValueError(msg)
235
+
236
+ return result[0]
237
+
238
+ async def start_live(
239
+ self,
240
+ table_name: str,
241
+ ) -> tuple[UUID, AsyncGenerator[dict[str, Any], None]]:
242
+ """Create and subscribe to a live SurrealQL query.
243
+
244
+ The live query ID is tracked to ensure proper cleanup on connection close.
245
+
246
+ Args:
247
+ table_name: Name of the table to insert into
248
+
249
+ Returns:
250
+ tuple[UUID, AsyncGenerator]: Live query ID and subscription generator
251
+ """
252
+ live_id = await self.db.live(table_name, diff=False)
253
+ self._live_queries.add(live_id) # Track for cleanup
254
+ logger.debug("Started live query %s for table %s (total: %d)", live_id, table_name, len(self._live_queries))
255
+ return live_id, await self.db.subscribe_live(live_id)
256
+
257
+ async def stop_live(self, live_id: UUID) -> None:
258
+ """Kill a live SurrealQL query.
259
+
260
+ Args:
261
+ live_id: Live query ID to kill
262
+ """
263
+ logger.debug("Killing live query: %s", live_id)
264
+ await self.db.kill(live_id)
265
+ self._live_queries.discard(live_id) # Remove from tracker
266
+ logger.debug("Stopped live query %s (remaining: %d)", live_id, len(self._live_queries))
@@ -0,0 +1,249 @@
1
+ """Task executor for running tasks with full lifecycle management."""
2
+
3
+ import asyncio
4
+ import datetime
5
+ from collections.abc import Coroutine
6
+ from typing import Any
7
+
8
+ from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
9
+ from digitalkin.core.task_manager.task_session import TaskSession
10
+ from digitalkin.logger import logger
11
+ from digitalkin.models.core.task_monitor import (
12
+ CancellationReason,
13
+ SignalMessage,
14
+ SignalType,
15
+ TaskStatus,
16
+ )
17
+
18
+
19
+ class TaskExecutor:
20
+ """Executes tasks with the supervisor pattern (main + heartbeat + signal listener).
21
+
22
+ Pure execution logic - no task registry or orchestration.
23
+ Used by workers to run distributed tasks or by TaskManager for local execution.
24
+ """
25
+
26
+ @staticmethod
27
+ async def execute_task( # noqa: C901, PLR0915
28
+ task_id: str,
29
+ mission_id: str,
30
+ coro: Coroutine[Any, Any, None],
31
+ session: TaskSession,
32
+ channel: SurrealDBConnection,
33
+ ) -> asyncio.Task[None]:
34
+ """Execute a task using the supervisor pattern.
35
+
36
+ Runs three concurrent sub-tasks:
37
+ - Main coroutine (the actual work)
38
+ - Heartbeat generator (sends heartbeats to SurrealDB)
39
+ - Signal listener (watches for stop/pause/resume signals)
40
+
41
+ The first task to complete determines the outcome.
42
+
43
+ Args:
44
+ task_id: Unique identifier for the task
45
+ mission_id: Mission identifier for the task
46
+ coro: The coroutine to execute (module.start(...))
47
+ session: TaskSession for state management
48
+ channel: SurrealDB connection for signals
49
+
50
+ Returns:
51
+ asyncio.Task: The supervisor task managing the lifecycle
52
+ """
53
+
54
+ async def signal_wrapper() -> None:
55
+ """Create initial signal record and listen for signals."""
56
+ try:
57
+ await channel.create(
58
+ "tasks",
59
+ SignalMessage(
60
+ task_id=task_id,
61
+ mission_id=mission_id,
62
+ status=session.status,
63
+ action=SignalType.START,
64
+ ).model_dump(),
65
+ )
66
+ await session.listen_signals()
67
+ except asyncio.CancelledError:
68
+ logger.debug("Signal listener cancelled", extra={"mission_id": mission_id, "task_id": task_id})
69
+ finally:
70
+ await channel.create(
71
+ "tasks",
72
+ SignalMessage(
73
+ task_id=task_id,
74
+ mission_id=mission_id,
75
+ status=session.status,
76
+ action=SignalType.STOP,
77
+ ).model_dump(),
78
+ )
79
+ logger.info("Signal listener ended", extra={"mission_id": mission_id, "task_id": task_id})
80
+
81
+ async def heartbeat_wrapper() -> None:
82
+ """Generate heartbeats for task health monitoring."""
83
+ try:
84
+ await session.generate_heartbeats()
85
+ except asyncio.CancelledError:
86
+ logger.debug("Heartbeat cancelled", extra={"mission_id": mission_id, "task_id": task_id})
87
+ finally:
88
+ logger.info("Heartbeat task ended", extra={"mission_id": mission_id, "task_id": task_id})
89
+
90
+ async def supervisor() -> None: # noqa: C901, PLR0912, PLR0915
91
+ """Supervise the three concurrent tasks and handle outcomes.
92
+
93
+ Raises:
94
+ RuntimeError: If the heartbeat task stops unexpectedly.
95
+ asyncio.CancelledError: If the supervisor task is cancelled.
96
+ """
97
+ session.started_at = datetime.datetime.now(datetime.timezone.utc)
98
+ session.status = TaskStatus.RUNNING
99
+
100
+ # Create tasks with proper exception handling
101
+ main_task = None
102
+ hb_task = None
103
+ sig_task = None
104
+ cleanup_reason = CancellationReason.UNKNOWN
105
+
106
+ try:
107
+ main_task = asyncio.create_task(coro, name=f"{task_id}_main")
108
+ hb_task = asyncio.create_task(heartbeat_wrapper(), name=f"{task_id}_heartbeat")
109
+ sig_task = asyncio.create_task(signal_wrapper(), name=f"{task_id}_listener")
110
+ done, pending = await asyncio.wait(
111
+ [main_task, sig_task, hb_task],
112
+ return_when=asyncio.FIRST_COMPLETED,
113
+ )
114
+
115
+ # Determine cleanup reason based on which task completed first
116
+ completed = next(iter(done))
117
+
118
+ if completed is main_task:
119
+ # Main task finished - cleanup is due to success
120
+ cleanup_reason = CancellationReason.SUCCESS_CLEANUP
121
+ elif completed is sig_task or (completed is hb_task and sig_task.done()):
122
+ # Signal task finished - external cancellation
123
+ cleanup_reason = CancellationReason.SIGNAL
124
+ elif completed is hb_task:
125
+ # Heartbeat stopped - failure cleanup
126
+ cleanup_reason = CancellationReason.FAILURE_CLEANUP
127
+
128
+ # Cancel pending tasks with proper reason logging
129
+ if pending:
130
+ pending_names = [t.get_name() for t in pending]
131
+ logger.debug(
132
+ "Cancelling pending tasks: %s, reason: %s",
133
+ pending_names,
134
+ cleanup_reason.value,
135
+ extra={
136
+ "mission_id": mission_id,
137
+ "task_id": task_id,
138
+ "pending_tasks": pending_names,
139
+ "cancellation_reason": cleanup_reason.value,
140
+ },
141
+ )
142
+ for t in pending:
143
+ t.cancel()
144
+
145
+ # Propagate exception/result from the finished task
146
+ await completed
147
+
148
+ # Determine final status based on which task completed
149
+ if completed is main_task:
150
+ session.status = TaskStatus.COMPLETED
151
+ logger.info(
152
+ "Main task completed successfully",
153
+ extra={"mission_id": mission_id, "task_id": task_id},
154
+ )
155
+ elif completed is sig_task or (completed is hb_task and sig_task.done()):
156
+ session.status = TaskStatus.CANCELLED
157
+ session.cancellation_reason = CancellationReason.SIGNAL
158
+ logger.info(
159
+ "Task cancelled via external signal",
160
+ extra={
161
+ "mission_id": mission_id,
162
+ "task_id": task_id,
163
+ "cancellation_reason": CancellationReason.SIGNAL.value,
164
+ },
165
+ )
166
+ elif completed is hb_task:
167
+ session.status = TaskStatus.FAILED
168
+ session.cancellation_reason = CancellationReason.HEARTBEAT_FAILURE
169
+ logger.error(
170
+ "Heartbeat stopped unexpectedly for task: '%s'",
171
+ task_id,
172
+ extra={
173
+ "mission_id": mission_id,
174
+ "task_id": task_id,
175
+ "cancellation_reason": CancellationReason.HEARTBEAT_FAILURE.value,
176
+ },
177
+ )
178
+ msg = f"Heartbeat stopped for {task_id}"
179
+ raise RuntimeError(msg) # noqa: TRY301
180
+
181
+ except asyncio.CancelledError:
182
+ session.status = TaskStatus.CANCELLED
183
+ # Only set reason if not already set (preserve original reason)
184
+ logger.info(
185
+ "Task cancelled externally: '%s', reason: %s",
186
+ task_id,
187
+ session.cancellation_reason.value,
188
+ extra={
189
+ "mission_id": mission_id,
190
+ "task_id": task_id,
191
+ "cancellation_reason": session.cancellation_reason.value,
192
+ },
193
+ )
194
+ cleanup_reason = CancellationReason.FAILURE_CLEANUP
195
+ raise
196
+ except Exception:
197
+ session.status = TaskStatus.FAILED
198
+ cleanup_reason = CancellationReason.FAILURE_CLEANUP
199
+ logger.exception(
200
+ "Task failed with exception: '%s'",
201
+ task_id,
202
+ extra={"mission_id": mission_id, "task_id": task_id},
203
+ )
204
+ raise
205
+ finally:
206
+ session.completed_at = datetime.datetime.now(datetime.timezone.utc)
207
+ # Ensure all tasks are cleaned up with proper reason
208
+ tasks_to_cleanup = [t for t in [main_task, hb_task, sig_task] if t is not None and not t.done()]
209
+ if tasks_to_cleanup:
210
+ cleanup_names = [t.get_name() for t in tasks_to_cleanup]
211
+ logger.debug(
212
+ "Final cleanup of %d remaining tasks: %s, reason: %s",
213
+ len(tasks_to_cleanup),
214
+ cleanup_names,
215
+ cleanup_reason.value,
216
+ extra={
217
+ "mission_id": mission_id,
218
+ "task_id": task_id,
219
+ "cleanup_count": len(tasks_to_cleanup),
220
+ "cleanup_tasks": cleanup_names,
221
+ "cancellation_reason": cleanup_reason.value,
222
+ },
223
+ )
224
+ for t in tasks_to_cleanup:
225
+ t.cancel()
226
+ await asyncio.gather(*tasks_to_cleanup, return_exceptions=True)
227
+
228
+ duration = (
229
+ (session.completed_at - session.started_at).total_seconds()
230
+ if session.started_at and session.completed_at
231
+ else None
232
+ )
233
+ logger.info(
234
+ "Task execution completed: '%s', status: %s, reason: %s, duration: %.2fs",
235
+ task_id,
236
+ session.status.value,
237
+ session.cancellation_reason.value if session.status == TaskStatus.CANCELLED else "n/a",
238
+ duration or 0,
239
+ extra={
240
+ "mission_id": mission_id,
241
+ "task_id": task_id,
242
+ "status": session.status.value,
243
+ "cancellation_reason": session.cancellation_reason.value,
244
+ "duration": duration,
245
+ },
246
+ )
247
+
248
+ # Return the supervisor task to be awaited by caller
249
+ return asyncio.create_task(supervisor(), name=f"{task_id}_supervisor")