digitalkin 0.3.1.dev2__py3-none-any.whl → 0.3.2a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. base_server/server_async_insecure.py +6 -5
  2. base_server/server_async_secure.py +6 -5
  3. base_server/server_sync_insecure.py +5 -4
  4. base_server/server_sync_secure.py +5 -4
  5. digitalkin/__version__.py +1 -1
  6. digitalkin/core/job_manager/base_job_manager.py +1 -1
  7. digitalkin/core/job_manager/single_job_manager.py +78 -36
  8. digitalkin/core/job_manager/taskiq_broker.py +7 -6
  9. digitalkin/core/job_manager/taskiq_job_manager.py +9 -5
  10. digitalkin/core/task_manager/base_task_manager.py +3 -1
  11. digitalkin/core/task_manager/surrealdb_repository.py +29 -7
  12. digitalkin/core/task_manager/task_executor.py +46 -12
  13. digitalkin/core/task_manager/task_session.py +132 -102
  14. digitalkin/grpc_servers/module_server.py +95 -171
  15. digitalkin/grpc_servers/module_servicer.py +121 -19
  16. digitalkin/grpc_servers/utils/grpc_client_wrapper.py +36 -10
  17. digitalkin/grpc_servers/utils/utility_schema_extender.py +106 -0
  18. digitalkin/models/__init__.py +1 -1
  19. digitalkin/models/core/job_manager_models.py +0 -8
  20. digitalkin/models/core/task_monitor.py +23 -1
  21. digitalkin/models/grpc_servers/models.py +95 -8
  22. digitalkin/models/module/__init__.py +26 -13
  23. digitalkin/models/module/base_types.py +61 -0
  24. digitalkin/models/module/module_context.py +279 -13
  25. digitalkin/models/module/module_types.py +28 -392
  26. digitalkin/models/module/setup_types.py +547 -0
  27. digitalkin/models/module/tool_cache.py +230 -0
  28. digitalkin/models/module/tool_reference.py +160 -0
  29. digitalkin/models/module/utility.py +167 -0
  30. digitalkin/models/services/cost.py +22 -1
  31. digitalkin/models/services/registry.py +77 -0
  32. digitalkin/modules/__init__.py +5 -1
  33. digitalkin/modules/_base_module.py +188 -63
  34. digitalkin/modules/archetype_module.py +6 -1
  35. digitalkin/modules/tool_module.py +6 -1
  36. digitalkin/modules/triggers/__init__.py +8 -0
  37. digitalkin/modules/triggers/healthcheck_ping_trigger.py +45 -0
  38. digitalkin/modules/triggers/healthcheck_services_trigger.py +63 -0
  39. digitalkin/modules/triggers/healthcheck_status_trigger.py +52 -0
  40. digitalkin/services/__init__.py +4 -0
  41. digitalkin/services/communication/__init__.py +7 -0
  42. digitalkin/services/communication/communication_strategy.py +87 -0
  43. digitalkin/services/communication/default_communication.py +104 -0
  44. digitalkin/services/communication/grpc_communication.py +264 -0
  45. digitalkin/services/cost/cost_strategy.py +36 -14
  46. digitalkin/services/cost/default_cost.py +61 -1
  47. digitalkin/services/cost/grpc_cost.py +98 -2
  48. digitalkin/services/filesystem/grpc_filesystem.py +9 -2
  49. digitalkin/services/registry/__init__.py +22 -1
  50. digitalkin/services/registry/default_registry.py +156 -4
  51. digitalkin/services/registry/exceptions.py +47 -0
  52. digitalkin/services/registry/grpc_registry.py +382 -0
  53. digitalkin/services/registry/registry_models.py +15 -0
  54. digitalkin/services/registry/registry_strategy.py +106 -4
  55. digitalkin/services/services_config.py +25 -3
  56. digitalkin/services/services_models.py +5 -1
  57. digitalkin/services/setup/default_setup.py +1 -1
  58. digitalkin/services/setup/grpc_setup.py +1 -1
  59. digitalkin/services/storage/grpc_storage.py +1 -1
  60. digitalkin/services/user_profile/__init__.py +11 -0
  61. digitalkin/services/user_profile/grpc_user_profile.py +2 -2
  62. digitalkin/services/user_profile/user_profile_strategy.py +0 -15
  63. digitalkin/utils/__init__.py +15 -3
  64. digitalkin/utils/conditional_schema.py +260 -0
  65. digitalkin/utils/dynamic_schema.py +4 -0
  66. digitalkin/utils/schema_splitter.py +290 -0
  67. {digitalkin-0.3.1.dev2.dist-info → digitalkin-0.3.2a3.dist-info}/METADATA +12 -12
  68. digitalkin-0.3.2a3.dist-info/RECORD +144 -0
  69. {digitalkin-0.3.1.dev2.dist-info → digitalkin-0.3.2a3.dist-info}/WHEEL +1 -1
  70. {digitalkin-0.3.1.dev2.dist-info → digitalkin-0.3.2a3.dist-info}/top_level.txt +1 -0
  71. modules/archetype_with_tools_module.py +232 -0
  72. modules/cpu_intensive_module.py +1 -1
  73. modules/dynamic_setup_module.py +5 -29
  74. modules/minimal_llm_module.py +1 -1
  75. modules/text_transform_module.py +1 -1
  76. monitoring/digitalkin_observability/__init__.py +46 -0
  77. monitoring/digitalkin_observability/http_server.py +150 -0
  78. monitoring/digitalkin_observability/interceptors.py +176 -0
  79. monitoring/digitalkin_observability/metrics.py +201 -0
  80. monitoring/digitalkin_observability/prometheus.py +137 -0
  81. monitoring/tests/test_metrics.py +172 -0
  82. services/filesystem_module.py +7 -5
  83. services/storage_module.py +4 -2
  84. digitalkin/grpc_servers/registry_server.py +0 -65
  85. digitalkin/grpc_servers/registry_servicer.py +0 -456
  86. digitalkin-0.3.1.dev2.dist-info/RECORD +0 -119
  87. {digitalkin-0.3.1.dev2.dist-info → digitalkin-0.3.2a3.dist-info}/licenses/LICENSE +0 -0
@@ -4,7 +4,7 @@ import asyncio
4
4
  import datetime
5
5
  import os
6
6
  from collections.abc import AsyncGenerator
7
- from typing import Any, Generic, TypeVar
7
+ from typing import Any, Generic, TypeVar, cast
8
8
  from uuid import UUID
9
9
 
10
10
  from surrealdb import AsyncHttpSurrealConnection, AsyncSurreal, AsyncWsSurrealConnection, RecordID
@@ -40,6 +40,7 @@ class SurrealDBConnection(Generic[TSurreal]):
40
40
  db: TSurreal
41
41
  timeout: datetime.timedelta
42
42
  _live_queries: set[UUID] # Track active live queries for cleanup
43
+ _closed: bool # Flag to prevent operations on closed connection
43
44
 
44
45
  @staticmethod
45
46
  def _valid_id(raw_id: str, table_name: str) -> RecordID:
@@ -85,13 +86,14 @@ class SurrealDBConnection(Generic[TSurreal]):
85
86
  self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
86
87
  self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")
87
88
  self._live_queries = set() # Initialize live queries tracker
89
+ self._closed = False
88
90
 
89
91
  async def init_surreal_instance(self) -> None:
90
92
  """Init a SurrealDB connection instance."""
91
93
  logger.debug("Connecting to SurrealDB at %s", self.url)
92
94
  self.db = AsyncSurreal(self.url) # type: ignore
93
95
  await self.db.signin({"username": self.username, "password": self.password})
94
- await self.db.use(self.namespace, self.database)
96
+ await self.db.use(self.namespace, self.database) # type: ignore[arg-type]
95
97
  logger.debug("Successfully connected to SurrealDB")
96
98
 
97
99
  async def close(self) -> None:
@@ -99,6 +101,7 @@ class SurrealDBConnection(Generic[TSurreal]):
99
101
 
100
102
  This will also kill all active live queries to prevent memory leaks.
101
103
  """
104
+ self._closed = True
102
105
  # Kill all tracked live queries before closing connection
103
106
  if self._live_queries:
104
107
  logger.debug("Killing %d active live queries before closing", len(self._live_queries))
@@ -112,7 +115,7 @@ class SurrealDBConnection(Generic[TSurreal]):
112
115
  # Process results and track failures
113
116
  failed_queries = []
114
117
  for live_id, result in zip(live_query_ids, results):
115
- if isinstance(result, (ConnectionError, TimeoutError, Exception)):
118
+ if isinstance(result, ConnectionError | TimeoutError | Exception):
116
119
  failed_queries.append((live_id, str(result)))
117
120
  else:
118
121
  self._live_queries.discard(live_id)
@@ -142,11 +145,27 @@ class SurrealDBConnection(Generic[TSurreal]):
142
145
 
143
146
  Returns:
144
147
  Dict[str, Any]: The created record as returned by the database
148
+
149
+ Raises:
150
+ RuntimeError: If the database returns an error response
145
151
  """
146
152
  logger.debug("Creating record in %s with data: %s", table_name, data)
147
153
  result = await self.db.create(table_name, data)
148
154
  logger.debug("create result: %s", result)
149
- return result
155
+
156
+ # Check for error response from SurrealDB
157
+ if isinstance(result, dict) and "code" in result:
158
+ error_msg = result.get("message", result.get("information", "Unknown error"))
159
+ logger.error(
160
+ "SurrealDB create failed: %s (code: %s)",
161
+ error_msg,
162
+ result.get("code"),
163
+ extra={"table": table_name, "error": result},
164
+ )
165
+ msg = f"SurrealDB create failed in '{table_name}': {error_msg}"
166
+ raise RuntimeError(msg)
167
+
168
+ return cast("list[dict[str, Any]] | dict[str, Any]", result)
150
169
 
151
170
  async def merge(
152
171
  self,
@@ -170,7 +189,7 @@ class SurrealDBConnection(Generic[TSurreal]):
170
189
  logger.debug("Updating record in %s with data: %s", record_id, data)
171
190
  result = await self.db.merge(record_id, data)
172
191
  logger.debug("update result: %s", result)
173
- return result
192
+ return cast("list[dict[str, Any]] | dict[str, Any]", result)
174
193
 
175
194
  async def update(
176
195
  self,
@@ -194,7 +213,7 @@ class SurrealDBConnection(Generic[TSurreal]):
194
213
  logger.debug("Updating record in %s with data: %s", record_id, data)
195
214
  result = await self.db.update(record_id, data)
196
215
  logger.debug("update result: %s", result)
197
- return result
216
+ return cast("list[dict[str, Any]] | dict[str, Any]", result)
198
217
 
199
218
  async def execute_query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
200
219
  """Execute a custom SurrealQL query.
@@ -209,7 +228,7 @@ class SurrealDBConnection(Generic[TSurreal]):
209
228
  logger.debug("execute_query: %s with params: %s", query, params)
210
229
  result = await self.db.query(query, params or {})
211
230
  logger.debug("execute_query result: %s", result)
212
- return [result] if isinstance(result, dict) else result
231
+ return cast("list[dict[str, Any]]", [result] if isinstance(result, dict) else result)
213
232
 
214
233
  async def select_by_task_id(self, table: str, value: str) -> dict[str, Any]:
215
234
  """Fetch a record from a table by a unique field.
@@ -260,6 +279,9 @@ class SurrealDBConnection(Generic[TSurreal]):
260
279
  Args:
261
280
  live_id: Live query ID to kill
262
281
  """
282
+ if self._closed:
283
+ self._live_queries.discard(live_id)
284
+ return
263
285
  logger.debug("Killing live query: %s", live_id)
264
286
  await self.db.kill(live_id)
265
287
  self._live_queries.discard(live_id) # Remove from tracker
@@ -1,6 +1,7 @@
1
1
  """Task executor for running tasks with full lifecycle management."""
2
2
 
3
3
  import asyncio
4
+ import contextlib
4
5
  import datetime
5
6
  from collections.abc import Coroutine
6
7
  from typing import Any
@@ -54,28 +55,53 @@ class TaskExecutor:
54
55
  async def signal_wrapper() -> None:
55
56
  """Create initial signal record and listen for signals."""
56
57
  try:
57
- await channel.create(
58
+ # Create task record and capture the record ID directly
59
+ # This avoids a race condition where SELECT might run before CREATE completes
60
+ result = await channel.create(
58
61
  "tasks",
59
62
  SignalMessage(
60
63
  task_id=task_id,
61
64
  mission_id=mission_id,
65
+ setup_id=session.setup_id,
66
+ setup_version_id=session.setup_version_id,
62
67
  status=session.status,
63
68
  action=SignalType.START,
64
69
  ).model_dump(),
65
70
  )
66
- await session.listen_signals()
71
+ # Store the record ID in session - required before starting live query
72
+ if isinstance(result, dict) and "id" in result:
73
+ session.signal_record_id = result["id"]
74
+ logger.debug(
75
+ "Task signal record created",
76
+ extra={"mission_id": mission_id, "task_id": task_id, "record_id": result["id"]},
77
+ )
78
+ # Only start listening if we have a valid record ID
79
+ await session.listen_signals()
80
+ else:
81
+ # Create failed - wait for cancellation instead of listening
82
+ logger.error(
83
+ "Failed to get record ID from task creation, waiting for cancellation",
84
+ extra={"mission_id": mission_id, "task_id": task_id, "result": result},
85
+ )
86
+ await session.is_cancelled.wait()
67
87
  except asyncio.CancelledError:
68
88
  logger.debug("Signal listener cancelled", extra={"mission_id": mission_id, "task_id": task_id})
69
89
  finally:
70
- await channel.create(
71
- "tasks",
72
- SignalMessage(
73
- task_id=task_id,
74
- mission_id=mission_id,
75
- status=session.status,
76
- action=SignalType.STOP,
77
- ).model_dump(),
78
- )
90
+ with contextlib.suppress(Exception): # Connection may already be closed
91
+ await channel.create(
92
+ "tasks",
93
+ SignalMessage(
94
+ task_id=task_id,
95
+ mission_id=mission_id,
96
+ setup_id=session.setup_id,
97
+ setup_version_id=session.setup_version_id,
98
+ status=session.status,
99
+ action=SignalType.STOP,
100
+ cancellation_reason=session.cancellation_reason,
101
+ error_message=session._last_exception, # noqa: SLF001
102
+ exception_traceback=session._last_traceback, # noqa: SLF001
103
+ ).model_dump(),
104
+ )
79
105
  logger.info("Signal listener ended", extra={"mission_id": mission_id, "task_id": task_id})
80
106
 
81
107
  async def heartbeat_wrapper() -> None:
@@ -125,8 +151,14 @@ class TaskExecutor:
125
151
  # Heartbeat stopped - failure cleanup
126
152
  cleanup_reason = CancellationReason.FAILURE_CLEANUP
127
153
 
154
+ # Signal stream to close FIRST before any cleanup
155
+ session.close_stream()
156
+
128
157
  # Cancel pending tasks with proper reason logging
129
158
  if pending:
159
+ # Give stream time to see the signal and exit gracefully
160
+ await asyncio.sleep(0.01) # Allow one event loop cycle
161
+
130
162
  pending_names = [t.get_name() for t in pending]
131
163
  logger.debug(
132
164
  "Cancelling pending tasks: %s, reason: %s",
@@ -148,6 +180,7 @@ class TaskExecutor:
148
180
  # Determine final status based on which task completed
149
181
  if completed is main_task:
150
182
  session.status = TaskStatus.COMPLETED
183
+ session.cancellation_reason = CancellationReason.COMPLETED
151
184
  logger.info(
152
185
  "Main task completed successfully",
153
186
  extra={"mission_id": mission_id, "task_id": task_id},
@@ -193,9 +226,10 @@ class TaskExecutor:
193
226
  )
194
227
  cleanup_reason = CancellationReason.FAILURE_CLEANUP
195
228
  raise
196
- except Exception:
229
+ except Exception as e:
197
230
  session.status = TaskStatus.FAILED
198
231
  cleanup_reason = CancellationReason.FAILURE_CLEANUP
232
+ session.record_exception(e)
199
233
  logger.exception(
200
234
  "Task failed with exception: '%s'",
201
235
  task_id,
@@ -1,7 +1,9 @@
1
1
  """Task session easing task lifecycle management."""
2
2
 
3
3
  import asyncio
4
+ import contextlib
4
5
  import datetime
6
+ import traceback
5
7
  from collections.abc import AsyncGenerator
6
8
 
7
9
  from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
@@ -39,9 +41,17 @@ class TaskSession:
39
41
  is_cancelled: asyncio.Event
40
42
  cancellation_reason: CancellationReason
41
43
  _paused: asyncio.Event
44
+ _stream_closed: asyncio.Event
42
45
  _heartbeat_interval: datetime.timedelta
43
46
  _last_heartbeat: datetime.datetime
44
47
 
48
+ # Exception tracking for enhanced DB logging
49
+ _last_exception: str | None
50
+ _last_traceback: str | None
51
+
52
+ # Cleanup guard for idempotent cleanup
53
+ _cleanup_done: bool
54
+
45
55
  def __init__(
46
56
  self,
47
57
  task_id: str,
@@ -81,12 +91,23 @@ class TaskSession:
81
91
  self.is_cancelled = asyncio.Event()
82
92
  self.cancellation_reason = CancellationReason.UNKNOWN
83
93
  self._paused = asyncio.Event()
94
+ self._stream_closed = asyncio.Event()
84
95
  self._heartbeat_interval = heartbeat_interval
85
96
 
97
+ # Exception tracking
98
+ self._last_exception = None
99
+ self._last_traceback = None
100
+
101
+ # Cleanup guard
102
+ self._cleanup_done = False
103
+
86
104
  logger.info(
87
- "TaskContext initialized for task: '%s'",
88
- task_id,
89
- extra={"task_id": task_id, "mission_id": mission_id, "heartbeat_interval": heartbeat_interval},
105
+ "TaskSession initialized",
106
+ extra={
107
+ "task_id": task_id,
108
+ "mission_id": mission_id,
109
+ "heartbeat_interval": str(heartbeat_interval),
110
+ },
90
111
  )
91
112
 
92
113
  @property
@@ -99,6 +120,39 @@ class TaskSession:
99
120
  """Task paused status."""
100
121
  return self._paused.is_set()
101
122
 
123
+ @property
124
+ def stream_closed(self) -> bool:
125
+ """Check if stream termination was signaled."""
126
+ return self._stream_closed.is_set()
127
+
128
+ def close_stream(self) -> None:
129
+ """Signal that the stream should terminate."""
130
+ self._stream_closed.set()
131
+
132
+ @property
133
+ def setup_id(self) -> str:
134
+ """Get setup_id from module context."""
135
+ return self.module.context.session.setup_id
136
+
137
+ @property
138
+ def setup_version_id(self) -> str:
139
+ """Get setup_version_id from module context."""
140
+ return self.module.context.session.setup_version_id
141
+
142
+ @property
143
+ def session_ids(self) -> dict[str, str]:
144
+ """Get all session IDs from module context for structured logging."""
145
+ return self.module.context.session.current_ids()
146
+
147
+ def record_exception(self, exc: Exception) -> None:
148
+ """Record exception details for DB logging.
149
+
150
+ Args:
151
+ exc: The exception that caused the task to fail.
152
+ """
153
+ self._last_exception = str(exc)
154
+ self._last_traceback = traceback.format_exc()
155
+
102
156
  async def send_heartbeat(self) -> bool:
103
157
  """Rate-limited heartbeat with connection resilience.
104
158
 
@@ -108,6 +162,8 @@ class TaskSession:
108
162
  heartbeat = HeartbeatMessage(
109
163
  task_id=self.task_id,
110
164
  mission_id=self.mission_id,
165
+ setup_id=self.setup_id,
166
+ setup_version_id=self.setup_version_id,
111
167
  timestamp=datetime.datetime.now(datetime.timezone.utc),
112
168
  )
113
169
 
@@ -120,23 +176,17 @@ class TaskSession:
120
176
  return True
121
177
  except Exception as e:
122
178
  logger.error(
123
- "Heartbeat exception for task: '%s'",
124
- self.task_id,
125
- extra={"task_id": self.task_id, "error": str(e)},
179
+ "Heartbeat exception",
180
+ extra={**self.session_ids, "error": str(e)},
126
181
  exc_info=True,
127
182
  )
128
- logger.error(
129
- "Initial heartbeat failed for task: '%s'",
130
- self.task_id,
131
- extra={"task_id": self.task_id},
132
- )
183
+ logger.error("Initial heartbeat failed", extra=self.session_ids)
133
184
  return False
134
185
 
135
186
  if (heartbeat.timestamp - self._last_heartbeat) < self._heartbeat_interval:
136
187
  logger.debug(
137
- "Heartbeat skipped due to rate limiting for task: '%s' | delta=%s",
138
- self.task_id,
139
- heartbeat.timestamp - self._last_heartbeat,
188
+ "Heartbeat skipped due to rate limiting",
189
+ extra={**self.session_ids, "delta": str(heartbeat.timestamp - self._last_heartbeat)},
140
190
  )
141
191
  return True
142
192
 
@@ -147,39 +197,24 @@ class TaskSession:
147
197
  return True
148
198
  except Exception as e:
149
199
  logger.error(
150
- "Heartbeat exception for task: '%s'",
151
- self.task_id,
152
- extra={"task_id": self.task_id, "error": str(e)},
200
+ "Heartbeat exception",
201
+ extra={**self.session_ids, "error": str(e)},
153
202
  exc_info=True,
154
203
  )
155
- logger.warning(
156
- "Heartbeat failed for task: '%s'",
157
- self.task_id,
158
- extra={"task_id": self.task_id},
159
- )
204
+ logger.warning("Heartbeat failed", extra=self.session_ids)
160
205
  return False
161
206
 
162
207
  async def generate_heartbeats(self) -> None:
163
208
  """Periodic heartbeat generator with cancellation support."""
164
- logger.debug(
165
- "Heartbeat generator started for task: '%s'",
166
- self.task_id,
167
- extra={"task_id": self.task_id, "mission_id": self.mission_id},
168
- )
209
+ logger.debug("Heartbeat generator started", extra=self.session_ids)
169
210
  while not self.cancelled:
170
211
  logger.debug(
171
- "Heartbeat tick for task: '%s', cancelled=%s",
172
- self.task_id,
173
- self.cancelled,
174
- extra={"task_id": self.task_id, "mission_id": self.mission_id},
212
+ "Heartbeat tick",
213
+ extra={**self.session_ids, "cancelled": self.cancelled},
175
214
  )
176
215
  success = await self.send_heartbeat()
177
216
  if not success:
178
- logger.error(
179
- "Heartbeat failed, cancelling task: '%s'",
180
- self.task_id,
181
- extra={"task_id": self.task_id, "mission_id": self.mission_id},
182
- )
217
+ logger.error("Heartbeat failed, cancelling task", extra=self.session_ids)
183
218
  await self._handle_cancel(CancellationReason.HEARTBEAT_FAILURE)
184
219
  break
185
220
  await asyncio.sleep(self._heartbeat_interval.total_seconds())
@@ -187,32 +222,32 @@ class TaskSession:
187
222
  async def wait_if_paused(self) -> None:
188
223
  """Block execution if task is paused."""
189
224
  if self._paused.is_set():
190
- logger.info(
191
- "Task paused, waiting for resume: '%s'",
192
- self.task_id,
193
- extra={"task_id": self.task_id},
194
- )
225
+ logger.info("Task paused, waiting for resume", extra=self.session_ids)
195
226
  await self._paused.wait()
196
227
 
197
228
  async def listen_signals(self) -> None: # noqa: C901
198
229
  """Enhanced signal listener with comprehensive handling.
199
230
 
200
231
  Raises:
201
- CancelledError: Asyncio when task cancelling
232
+ CancelledError: If task is cancelled during signal listening.
202
233
  """
203
- logger.info(
204
- "Signal listener started for task: '%s'",
205
- self.task_id,
206
- extra={"task_id": self.task_id},
207
- )
234
+ logger.info("Signal listener started", extra=self.session_ids)
235
+
236
+ # signal_record_id must be set by TaskExecutor before calling this method.
237
+ # If not set, we cannot filter signals correctly - abort early.
208
238
  if self.signal_record_id is None:
209
- self.signal_record_id = (await self.db.select_by_task_id("tasks", self.task_id)).get("id")
239
+ logger.error(
240
+ "signal_record_id not set - cannot start signal listener without valid record ID",
241
+ extra=self.session_ids,
242
+ )
243
+ return
210
244
 
211
245
  live_id, live_signals = await self.db.start_live("tasks")
212
246
  try:
213
247
  async for signal in live_signals:
214
- logger.debug("Signal received for task '%s': %s", self.task_id, signal)
215
- if self.cancelled:
248
+ logger.debug("Signal received", extra={**self.session_ids, "signal": signal})
249
+ # Check both cancelled and stream_closed to ensure clean shutdown
250
+ if self.cancelled or self.stream_closed:
216
251
  break
217
252
 
218
253
  if signal is None or signal["id"] == self.signal_record_id or "payload" not in signal:
@@ -228,26 +263,18 @@ class TaskSession:
228
263
  await self._handle_status_request()
229
264
 
230
265
  except asyncio.CancelledError:
231
- logger.debug(
232
- "Signal listener cancelled for task: '%s'",
233
- self.task_id,
234
- extra={"task_id": self.task_id},
235
- )
266
+ logger.debug("Signal listener cancelled", extra=self.session_ids)
236
267
  raise
237
268
  except Exception as e:
238
269
  logger.error(
239
- "Signal listener fatal error for task: '%s'",
240
- self.task_id,
241
- extra={"task_id": self.task_id, "error": str(e)},
270
+ "Signal listener fatal error",
271
+ extra={**self.session_ids, "error": str(e)},
242
272
  exc_info=True,
243
273
  )
244
274
  finally:
245
- await self.db.stop_live(live_id)
246
- logger.info(
247
- "Signal listener stopped for task: '%s'",
248
- self.task_id,
249
- extra={"task_id": self.task_id},
250
- )
275
+ with contextlib.suppress(Exception): # Connection may already be closed
276
+ await self.db.stop_live(live_id)
277
+ logger.info("Signal listener stopped", extra=self.session_ids)
251
278
 
252
279
  async def _handle_cancel(self, reason: CancellationReason = CancellationReason.UNKNOWN) -> None:
253
280
  """Idempotent cancellation with acknowledgment and reason tracking.
@@ -257,13 +284,9 @@ class TaskSession:
257
284
  """
258
285
  if self.is_cancelled.is_set():
259
286
  logger.debug(
260
- "Cancel ignored - task already cancelled: '%s' (existing reason: %s, new reason: %s)",
261
- self.task_id,
262
- self.cancellation_reason.value,
263
- reason.value,
287
+ "Cancel ignored - already cancelled",
264
288
  extra={
265
- "task_id": self.task_id,
266
- "mission_id": self.mission_id,
289
+ **self.session_ids,
267
290
  "existing_reason": self.cancellation_reason.value,
268
291
  "new_reason": reason.value,
269
292
  },
@@ -277,25 +300,13 @@ class TaskSession:
277
300
  # Log with appropriate level based on reason
278
301
  if reason in {CancellationReason.SUCCESS_CLEANUP, CancellationReason.FAILURE_CLEANUP}:
279
302
  logger.debug(
280
- "Task cancelled (cleanup): '%s', reason: %s",
281
- self.task_id,
282
- reason.value,
283
- extra={
284
- "task_id": self.task_id,
285
- "mission_id": self.mission_id,
286
- "cancellation_reason": reason.value,
287
- },
303
+ "Task cancelled (cleanup)",
304
+ extra={**self.session_ids, "cancellation_reason": reason.value},
288
305
  )
289
306
  else:
290
307
  logger.info(
291
- "Task cancelled: '%s', reason: %s",
292
- self.task_id,
293
- reason.value,
294
- extra={
295
- "task_id": self.task_id,
296
- "mission_id": self.mission_id,
297
- "cancellation_reason": reason.value,
298
- },
308
+ "Task cancelled",
309
+ extra={**self.session_ids, "cancellation_reason": reason.value},
299
310
  )
300
311
 
301
312
  # Resume if paused so cancellation can proceed
@@ -308,19 +319,18 @@ class TaskSession:
308
319
  SignalMessage(
309
320
  task_id=self.task_id,
310
321
  mission_id=self.mission_id,
322
+ setup_id=self.setup_id,
323
+ setup_version_id=self.setup_version_id,
311
324
  action=SignalType.ACK_CANCEL,
312
325
  status=self.status,
326
+ cancellation_reason=reason,
313
327
  ).model_dump(),
314
328
  )
315
329
 
316
330
  async def _handle_pause(self) -> None:
317
331
  """Pause task execution."""
318
332
  if not self._paused.is_set():
319
- logger.info(
320
- "Pausing task: '%s'",
321
- self.task_id,
322
- extra={"task_id": self.task_id},
323
- )
333
+ logger.info("Task paused", extra=self.session_ids)
324
334
  self._paused.set()
325
335
 
326
336
  await self.db.update(
@@ -329,6 +339,8 @@ class TaskSession:
329
339
  SignalMessage(
330
340
  task_id=self.task_id,
331
341
  mission_id=self.mission_id,
342
+ setup_id=self.setup_id,
343
+ setup_version_id=self.setup_version_id,
332
344
  action=SignalType.ACK_PAUSE,
333
345
  status=self.status,
334
346
  ).model_dump(),
@@ -337,11 +349,7 @@ class TaskSession:
337
349
  async def _handle_resume(self) -> None:
338
350
  """Resume paused task."""
339
351
  if self._paused.is_set():
340
- logger.info(
341
- "Resuming task: '%s'",
342
- self.task_id,
343
- extra={"task_id": self.task_id},
344
- )
352
+ logger.info("Task resumed", extra=self.session_ids)
345
353
  self._paused.clear()
346
354
 
347
355
  await self.db.update(
@@ -350,6 +358,8 @@ class TaskSession:
350
358
  SignalMessage(
351
359
  task_id=self.task_id,
352
360
  mission_id=self.mission_id,
361
+ setup_id=self.setup_id,
362
+ setup_version_id=self.setup_version_id,
353
363
  action=SignalType.ACK_RESUME,
354
364
  status=self.status,
355
365
  ).model_dump(),
@@ -361,28 +371,38 @@ class TaskSession:
361
371
  "tasks",
362
372
  self.signal_record_id, # type: ignore
363
373
  SignalMessage(
364
- mission_id=self.mission_id,
365
374
  task_id=self.task_id,
375
+ mission_id=self.mission_id,
376
+ setup_id=self.setup_id,
377
+ setup_version_id=self.setup_version_id,
366
378
  status=self.status,
367
379
  action=SignalType.ACK_STATUS,
368
380
  ).model_dump(),
369
381
  )
370
382
 
371
- logger.debug(
372
- "Status report sent for task: '%s'",
373
- self.task_id,
374
- extra={"task_id": self.task_id},
375
- )
383
+ logger.debug("Status report sent", extra=self.session_ids)
376
384
 
377
385
  async def cleanup(self) -> None:
378
386
  """Clean up task session resources.
379
387
 
388
+ This method is idempotent - safe to call multiple times.
389
+ Second and subsequent calls are no-ops.
390
+
380
391
  This includes:
381
392
  - Clearing queue to free memory
393
+ - Cleaning up module context services
382
394
  - Stopping module
383
395
  - Closing database connection
384
396
  - Clearing module reference
385
397
  """
398
+ if self._cleanup_done:
399
+ logger.debug(
400
+ "Cleanup already done, skipping",
401
+ extra={"task_id": self.task_id, "mission_id": self.mission_id},
402
+ )
403
+ return
404
+ self._cleanup_done = True
405
+
386
406
  # Clear queue to free memory
387
407
  try:
388
408
  while not self.queue.empty():
@@ -390,6 +410,16 @@ class TaskSession:
390
410
  except asyncio.QueueEmpty:
391
411
  pass
392
412
 
413
+ # Clean up module context services (e.g., gRPC channel pool)
414
+ if self.module is not None and self.module.context is not None:
415
+ try:
416
+ await self.module.context.cleanup()
417
+ except Exception:
418
+ logger.exception(
419
+ "Error cleaning up module context",
420
+ extra={"mission_id": self.mission_id, "task_id": self.task_id},
421
+ )
422
+
393
423
  # Stop module
394
424
  try:
395
425
  await self.module.stop()