digitalkin 0.2.26__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/grpc_servers/module_server.py +27 -44
  3. digitalkin/grpc_servers/module_servicer.py +27 -22
  4. digitalkin/grpc_servers/utils/models.py +1 -1
  5. digitalkin/logger.py +1 -9
  6. digitalkin/mixins/__init__.py +19 -0
  7. digitalkin/mixins/base_mixin.py +10 -0
  8. digitalkin/mixins/callback_mixin.py +24 -0
  9. digitalkin/mixins/chat_history_mixin.py +108 -0
  10. digitalkin/mixins/cost_mixin.py +76 -0
  11. digitalkin/mixins/file_history_mixin.py +99 -0
  12. digitalkin/mixins/filesystem_mixin.py +47 -0
  13. digitalkin/mixins/logger_mixin.py +59 -0
  14. digitalkin/mixins/storage_mixin.py +79 -0
  15. digitalkin/models/module/__init__.py +2 -0
  16. digitalkin/models/module/module.py +9 -1
  17. digitalkin/models/module/module_context.py +90 -6
  18. digitalkin/models/module/module_types.py +6 -6
  19. digitalkin/models/module/task_monitor.py +51 -0
  20. digitalkin/models/services/__init__.py +9 -0
  21. digitalkin/models/services/storage.py +39 -5
  22. digitalkin/modules/_base_module.py +47 -68
  23. digitalkin/modules/job_manager/base_job_manager.py +12 -8
  24. digitalkin/modules/job_manager/single_job_manager.py +84 -78
  25. digitalkin/modules/job_manager/surrealdb_repository.py +225 -0
  26. digitalkin/modules/job_manager/task_manager.py +391 -0
  27. digitalkin/modules/job_manager/task_session.py +276 -0
  28. digitalkin/modules/job_manager/taskiq_job_manager.py +2 -2
  29. digitalkin/modules/tool_module.py +10 -2
  30. digitalkin/modules/trigger_handler.py +7 -6
  31. digitalkin/services/cost/__init__.py +9 -2
  32. digitalkin/services/storage/grpc_storage.py +1 -1
  33. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dist-info}/METADATA +18 -18
  34. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dist-info}/RECORD +37 -24
  35. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dist-info}/WHEEL +0 -0
  36. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dist-info}/licenses/LICENSE +0 -0
  37. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,225 @@
1
+ """SurrealDB connection management."""
2
+
3
+ import datetime
4
+ import os
5
+ from collections.abc import AsyncGenerator
6
+ from typing import Any, Generic, TypeVar
7
+ from uuid import UUID
8
+
9
+ from surrealdb import AsyncHttpSurrealConnection, AsyncSurreal, AsyncWsSurrealConnection, RecordID
10
+
11
+ from digitalkin.logger import logger
12
+
13
+ TSurreal = TypeVar("TSurreal", bound=AsyncHttpSurrealConnection | AsyncWsSurrealConnection)
14
+
15
+
16
class SurrealDBSetupBadIDError(Exception):
    """Raised when the SurrealDB repository setup receives an invalid record ID.

    Signals that a provided ID string does not match the format or
    constraints expected by the repository layer.
    """
22
+
23
+
24
class SurrealDBSetupVersionBadIDError(Exception):
    """Raised when a SurrealDB setup-version record ID is invalid.

    Signals that a provided ID string does not match the format or
    constraints expected for a valid SurrealDB setup-version ID.
    """
30
+
31
+
32
class SurrealDBConnection(Generic[TSurreal]):
    """Base repository for SurrealDB database operations.

    Provides the common create/merge/update/query and live-query operations
    that concrete table repositories build upon. Connection parameters are
    read from the SURREALDB_* environment variables, with local-development
    defaults.
    """

    # Live SurrealDB connection (HTTP or WebSocket flavour). Set by
    # init_surreal_instance(); other methods must not be called before it.
    db: TSurreal
    # Timeout intended for database operations.
    timeout: datetime.timedelta

    @staticmethod
    def _valid_id(raw_id: str, table_name: str) -> RecordID:
        """Validate and parse a raw ID string into a RecordID.

        Args:
            raw_id: The raw ID string to validate, expected as "<table>:<id>"
            table_name: Table name the ID must belong to

        Raises:
            SurrealDBSetupBadIDError: If the raw ID is malformed or references
                a different table

        Returns:
            RecordID: Parsed RecordID object
        """
        parts = raw_id.split(":")
        if len(parts) < 2:  # no ":" separator -> cannot be a record ID
            msg = f"Malformed record ID: {raw_id!r}"
            raise SurrealDBSetupBadIDError(msg)
        if parts[0] != table_name:
            msg = f"Invalid table name for ID: {raw_id}"
            raise SurrealDBSetupBadIDError(msg)
        return RecordID(parts[0], parts[1])

    def __init__(
        self,
        database: str | None = None,
        timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Initialize the repository from environment configuration.

        Args:
            database: Database name; falls back to SURREALDB_DATABASE
                (default "task_manager") when not provided
            timeout: Timeout for database operations
        """
        self.timeout = timeout
        self.url = f"{os.getenv('SURREALDB_URL', 'ws://localhost')}:{os.getenv('SURREALDB_PORT', '8000')}/rpc"
        self.username = os.getenv("SURREALDB_USERNAME", "root")
        self.password = os.getenv("SURREALDB_PASSWORD", "root")
        self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
        self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")

    async def init_surreal_instance(self) -> None:
        """Init a SurrealDB connection instance."""
        logger.debug("Connecting to SurrealDB at %s", self.url)
        self.db = AsyncSurreal(self.url)  # type: ignore
        await self.db.signin({"username": self.username, "password": self.password})
        await self.db.use(self.namespace, self.database)
        logger.debug("Successfully connected to SurrealDB")

    async def close(self) -> None:
        """Close the SurrealDB connection if it exists."""
        logger.debug("Closing SurrealDB connection")
        await self.db.close()

    async def create(
        self,
        table_name: str,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Create a new record.

        Args:
            table_name: Name of the table to insert into
            data: Data to insert

        Returns:
            Dict[str, Any]: The created record as returned by the database
        """
        logger.debug("Creating record in %s with data: %s", table_name, data)
        result = await self.db.create(table_name, data)
        logger.debug("create result: %s", result)
        return result

    async def merge(
        self,
        table_name: str,
        record_id: str | RecordID,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Merge data into an existing record (partial update).

        Args:
            table_name: Name of the table containing the record
            record_id: Record ID to update
            data: Fields to merge into the record

        Returns:
            Dict[str, Any]: The updated record as returned by the database
        """
        if isinstance(record_id, str):
            # validate surrealDB id if raw str
            record_id = self._valid_id(record_id, table_name)
        logger.debug("Updating record in %s with data: %s", record_id, data)
        result = await self.db.merge(record_id, data)
        logger.debug("update result: %s", result)
        return result

    async def update(
        self,
        table_name: str,
        record_id: str | RecordID,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Replace the content of an existing record.

        Args:
            table_name: Name of the table containing the record
            record_id: Record ID to update
            data: New content for the record

        Returns:
            Dict[str, Any]: The updated record as returned by the database
        """
        if isinstance(record_id, str):
            # validate surrealDB id if raw str
            record_id = self._valid_id(record_id, table_name)
        logger.debug("Updating record in %s with data: %s", record_id, data)
        result = await self.db.update(record_id, data)
        logger.debug("update result: %s", result)
        return result

    async def execute_query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
        """Execute a custom SurrealQL query.

        Args:
            query: SurrealQL query
            params: Query parameters

        Returns:
            List[Dict[str, Any]]: Query results (a single dict result is
            wrapped in a one-element list)
        """
        logger.debug("execute_query: %s with params: %s", query, params)
        result = await self.db.query(query, params or {})
        logger.debug("execute_query result: %s", result)
        return [result] if isinstance(result, dict) else result

    async def select_by_task_id(self, table: str, value: str) -> dict[str, Any]:
        """Fetch the first record of a table matching a task_id.

        Args:
            table: Table name
            value: task_id value to match

        Raises:
            ValueError: If no records are found

        Returns:
            Dict with the first matching record's data
        """
        query = "SELECT * FROM type::table($table) WHERE task_id = $value;"
        params = {"table": table, "value": value}

        result = await self.execute_query(query, params)
        if not result:
            msg = f"No records found in table '{table}' with task_id '{value}'"
            logger.error(msg)
            raise ValueError(msg)

        return result[0]

    async def start_live(
        self,
        table_name: str,
    ) -> tuple[UUID, AsyncGenerator[dict[str, Any], None]]:
        """Create and subscribe to a live SurrealQL query.

        Args:
            table_name: Name of the table to watch

        Returns:
            tuple[UUID, AsyncGenerator]: The live query ID and the async
            generator yielding change notifications
        """
        live_id = await self.db.live(table_name, diff=False)
        return live_id, await self.db.subscribe_live(live_id)

    async def stop_live(self, live_id: UUID) -> None:
        """Kill a live SurrealQL query.

        Args:
            live_id: Live query ID to terminate
        """
        logger.debug("KILL Subscribe live for: %s", live_id)
        await self.db.kill(live_id)
@@ -0,0 +1,391 @@
1
+ """Task manager with comprehensive lifecycle management."""
2
+
3
+ import asyncio
4
+ import contextlib
5
+ import datetime
6
+ from collections.abc import Coroutine
7
+ from typing import Any
8
+
9
+ from digitalkin.logger import logger
10
+ from digitalkin.models.module.task_monitor import SignalMessage, SignalType, TaskStatus
11
+ from digitalkin.modules._base_module import BaseModule
12
+ from digitalkin.modules.job_manager.task_session import SurrealDBConnection, TaskSession
13
+
14
+
15
class TaskManager:
    """Task manager with comprehensive lifecycle management.

    Each managed task runs three cooperating coroutines — the module
    coroutine itself, a heartbeat generator, and a SurrealDB-backed signal
    listener — supervised so that the first one to finish determines the
    task outcome.
    """

    # Supervisor asyncio.Task per task_id.
    tasks: dict[str, asyncio.Task]
    # Per-task session state (status, timestamps, module) per task_id.
    tasks_sessions: dict[str, TaskSession]
    # SurrealDB channel used to publish lifecycle signals; assigned on the
    # first successful create_task() call.
    channel: SurrealDBConnection
    default_timeout: float
    max_concurrent_tasks: int
    _shutdown_event: asyncio.Event

    def __init__(self, default_timeout: float = 10.0, max_concurrent_tasks: int = 100) -> None:
        """Initialize the task manager.

        Args:
            default_timeout: Fallback timeout (seconds) for cancellation waits.
            max_concurrent_tasks: Upper bound on simultaneously managed tasks.
        """
        self.tasks = {}
        self.tasks_sessions = {}
        self.default_timeout = default_timeout
        self.max_concurrent_tasks = max_concurrent_tasks
        self._shutdown_event = asyncio.Event()

        logger.info(
            "TaskManager initialized with max_concurrent_tasks: %d, default_timeout: %.1f",
            max_concurrent_tasks,
            default_timeout,
            extra={"max_concurrent_tasks": max_concurrent_tasks, "default_timeout": default_timeout},
        )

    @property
    def task_count(self) -> int:
        """Number of tracked task sessions (running or not)."""
        return len(self.tasks_sessions)

    @property
    def running_tasks(self) -> set[str]:
        """IDs of tasks whose supervisor has not finished yet."""
        return {task_id for task_id, task in self.tasks.items() if not task.done()}

    async def _cleanup_task(self, task_id: str) -> None:
        """Clean up resources of a task whose creation failed."""
        logger.debug("Cleaning up resources for task: '%s'", task_id, extra={"task_id": task_id})
        if task_id in self.tasks_sessions:
            # NOTE(review): assumes TaskSession exposes its connection as `.db` — confirm.
            await self.tasks_sessions[task_id].db.close()
        # Remove from collections
        self.tasks_sessions.pop(task_id, None)
        self.tasks.pop(task_id, None)

    async def _task_wrapper(  # noqa: C901, PLR0915
        self,
        task_id: str,
        coro: Coroutine[Any, Any, None],
        session: TaskSession,
    ) -> asyncio.Task[None]:
        """Task wrapper that runs main, heartbeat, and listener concurrently.

        The first to finish determines the outcome. Returns a Task that the
        caller can await externally.

        Returns:
            asyncio.Task[None]: The supervisor task managing the lifecycle.
        """

        async def signal_wrapper() -> None:
            # Announce START, then block on the signal stream until cancelled;
            # always announce STOP on the way out.
            try:
                await self.channel.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        status=session.status,
                        action=SignalType.START,
                    ).model_dump(),
                )
                await session.listen_signals()
            except asyncio.CancelledError:
                logger.debug("Signal listener cancelled", extra={"task_id": task_id})
            finally:
                await self.channel.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        status=session.status,
                        action=SignalType.STOP,
                    ).model_dump(),
                )
                logger.info("Signal listener ended", extra={"task_id": task_id})

        async def heartbeat_wrapper() -> None:
            try:
                await session.generate_heartbeats()
            except asyncio.CancelledError:
                logger.debug("Heartbeat task cancelled", extra={"task_id": task_id})
            finally:
                logger.info("Heartbeat task ended", extra={"task_id": task_id})

        async def supervisor() -> None:
            session.started_at = datetime.datetime.now(datetime.timezone.utc)
            session.status = TaskStatus.RUNNING

            main_task = asyncio.create_task(coro, name=f"{task_id}_main")
            hb_task = asyncio.create_task(heartbeat_wrapper(), name=f"{task_id}_heartbeat")
            sig_task = asyncio.create_task(signal_wrapper(), name=f"{task_id}_listener")

            try:
                done, pending = await asyncio.wait(
                    [main_task, sig_task, hb_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )

                # One task completed -> cancel the others
                for pending_task in pending:
                    pending_task.cancel()

                # Propagate exception/result from the finished task
                completed = next(iter(done))
                await completed

                logger.debug(
                    "Supervisor wake-up: completed=%r main=%r hb=%r sig=%r",
                    completed,
                    main_task,
                    hb_task,
                    sig_task,
                    extra={"task_id": task_id},
                )

                if completed is main_task:
                    session.status = TaskStatus.COMPLETED
                elif completed is sig_task or (completed is hb_task and sig_task.done()):
                    # The listener finishing first means a cancel signal came in.
                    session.status = TaskStatus.CANCELLED
                elif completed is hb_task:
                    # Heartbeat died on its own: treat as a task failure.
                    session.status = TaskStatus.FAILED
                    msg = f"Heartbeat stopped for {task_id}"
                    raise RuntimeError(msg)  # noqa: TRY301

            except asyncio.CancelledError:
                session.status = TaskStatus.CANCELLED
                raise
            except Exception:
                session.status = TaskStatus.FAILED
                raise
            finally:
                session.completed_at = datetime.datetime.now(datetime.timezone.utc)
                # Ensure all tasks are cleaned up
                for child in (main_task, hb_task, sig_task):
                    if not child.done():
                        child.cancel()
                await asyncio.gather(main_task, hb_task, sig_task, return_exceptions=True)

        # Return the supervisor task to be awaited outside
        return asyncio.create_task(supervisor(), name=f"{task_id}_supervisor")

    async def create_task(
        self,
        task_id: str,
        module: BaseModule,
        coro: Coroutine[Any, Any, None],
        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
        connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Create and start a new managed task.

        Args:
            task_id: Unique identifier for the task.
            module: Module instance owning the coroutine.
            coro: The coroutine to supervise.
            heartbeat_interval: Interval between heartbeats.
            connection_timeout: Timeout for the SurrealDB connection.

        Raises:
            ValueError: task_id duplicated
            RuntimeError: task overload
        """
        if task_id in self.tasks:
            # close Coroutine during runtime
            coro.close()
            logger.warning("Task creation failed - task already exists: '%s'", task_id, extra={"task_id": task_id})
            msg = f"Task {task_id} already exists"
            raise ValueError(msg)

        if len(self.tasks) >= self.max_concurrent_tasks:
            coro.close()
            logger.error(
                "Task creation failed - max concurrent tasks reached: %d",
                self.max_concurrent_tasks,
                extra={
                    "task_id": task_id,
                    "current_count": len(self.tasks),
                    "max_concurrent": self.max_concurrent_tasks,
                },
            )
            msg = f"Maximum concurrent tasks ({self.max_concurrent_tasks}) reached"
            raise RuntimeError(msg)

        logger.info(
            "Creating new task: '%s'",
            task_id,
            extra={
                "task_id": task_id,
                "heartbeat_interval": heartbeat_interval,
                "connection_timeout": connection_timeout,
            },
        )

        try:
            # Initialize components
            channel: SurrealDBConnection = SurrealDBConnection("task_manager", connection_timeout)
            await channel.init_surreal_instance()
            # Keep the channel on the manager: signal_wrapper and send_signal
            # publish through self.channel (it was previously never assigned,
            # which raised AttributeError on first use).
            self.channel = channel
            session = TaskSession(task_id, channel, module, heartbeat_interval)

            self.tasks_sessions[task_id] = session

            # Store the supervisor task itself so awaiting / cancelling an
            # entry of self.tasks really tracks the task lifecycle (wrapping
            # _task_wrapper in another create_task stored a task that
            # completed immediately with the supervisor as its result).
            self.tasks[task_id] = await self._task_wrapper(task_id, coro, session)

            logger.info(
                "Task created successfully: '%s'",
                task_id,
                extra={
                    "task_id": task_id,
                    "total_tasks": len(self.tasks),
                },
            )

        except Exception as e:
            logger.error(
                "Failed to create task: '%s'", task_id, extra={"task_id": task_id, "error": str(e)}, exc_info=True
            )
            # Cleanup on failure
            await self._cleanup_task(task_id)
            raise

    async def send_signal(self, task_id: str, signal_type: str, payload: dict[str, Any]) -> bool:
        """Send signal to a specific task.

        Args:
            task_id: Target task identifier.
            signal_type: Signal name (e.g. "pause", "resume", "status").
            payload: Extra signal data.

        Returns:
            bool: True if the task sent successfully the given signal, False otherwise.
        """
        if task_id not in self.tasks_sessions:
            logger.warning(
                "Cannot send signal - task not found: '%s'",
                task_id,
                extra={"task_id": task_id, "signal_type": signal_type},
            )
            return False

        logger.info(
            "Sending signal '%s' to task: '%s'",
            signal_type,
            task_id,
            extra={"task_id": task_id, "signal_type": signal_type, "payload": payload},
        )

        await self.channel.update("tasks", signal_type, payload)
        return True

    async def cancel_task(self, task_id: str, timeout: float | None = None) -> bool:
        """Cancel a task with graceful shutdown and fallback.

        Args:
            task_id: Task identifier to cancel.
            timeout: Seconds to wait for graceful shutdown before forcing;
                defaults to self.default_timeout.

        Returns:
            bool: True if the task was cancelled successfully (or did not
            exist), False on unexpected errors.
        """
        if task_id not in self.tasks:
            logger.warning("Cannot cancel - task not found: '%s'", task_id, extra={"task_id": task_id})
            return True

        timeout = timeout or self.default_timeout
        task = self.tasks[task_id]

        logger.info(
            "Initiating task cancellation: '%s', timeout: %.1fs",
            task_id,
            timeout,
            extra={"task_id": task_id, "timeout": timeout},
        )

        try:
            # Phase 1: Cooperative cancellation
            # await self.send_signal(task_id, "cancel") # noqa: ERA001

            # Wait for graceful shutdown
            await asyncio.wait_for(task, timeout=timeout)

            logger.info("Task cancelled gracefully: '%s'", task_id, extra={"task_id": task_id})

        except asyncio.TimeoutError:
            logger.warning(
                "Graceful cancellation timed out for task: '%s', forcing cancellation",
                task_id,
                extra={"task_id": task_id, "timeout": timeout},
            )

            # Phase 2: Force cancellation
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

            logger.warning("Task force-cancelled: '%s'", task_id, extra={"task_id": task_id})
            return True

        except Exception as e:
            logger.error(
                "Error during task cancellation: '%s'",
                task_id,
                extra={"task_id": task_id, "error": str(e)},
                exc_info=True,
            )
            return False
        return True

    async def clean_session(self, task_id: str) -> bool:
        """Stop the module, cancel the task, and drop its session.

        Returns:
            bool: True if the task was cleaned successfully, False otherwise.
        """
        if task_id not in self.tasks_sessions:
            logger.warning("Cannot clean session - task not found: '%s'", task_id, extra={"task_id": task_id})
            return False

        await self.tasks_sessions[task_id].module.stop()
        await self.cancel_task(task_id)

        logger.info("Cleaning up session for task: '%s'", task_id, extra={"task_id": task_id})
        self.tasks_sessions.pop(task_id, None)
        return True

    async def pause_task(self, task_id: str) -> bool:
        """Pause a running task.

        Returns:
            bool: True if the task was paused successfully, False otherwise.
        """
        return await self.send_signal(task_id, "pause", {})

    async def resume_task(self, task_id: str) -> bool:
        """Resume a paused task.

        Returns:
            bool: True if the task was resumed successfully, False otherwise.
        """
        return await self.send_signal(task_id, "resume", {})

    async def get_task_status(self, task_id: str) -> bool:
        """Request status from a task.

        Returns:
            bool: True if the status request was sent successfully, False otherwise.
        """
        return await self.send_signal(task_id, "status", {})

    async def cancel_all_tasks(self, timeout: float | None = None) -> dict[str, bool]:
        """Cancel all running tasks.

        Returns:
            dict[str, bool]: Per-task cancellation success flags.
        """
        timeout = timeout or self.default_timeout
        task_ids = list(self.running_tasks)

        logger.info(
            "Cancelling all tasks: %d tasks", len(task_ids), extra={"task_count": len(task_ids), "timeout": timeout}
        )

        results = {}
        for task_id in task_ids:
            results[task_id] = await self.cancel_task(task_id, timeout)

        return results

    async def shutdown(self, timeout: float = 30.0) -> None:
        """Graceful shutdown of all tasks."""
        logger.info(
            "TaskManager shutdown initiated, timeout: %.1fs",
            timeout,
            extra={"timeout": timeout, "active_tasks": len(self.running_tasks)},
        )

        self._shutdown_event.set()
        results = await self.cancel_all_tasks(timeout)

        failed_tasks = [task_id for task_id, success in results.items() if not success]
        if failed_tasks:
            logger.error(
                "Failed to cancel %d tasks during shutdown: %s",
                len(failed_tasks),
                failed_tasks,
                extra={"failed_tasks": failed_tasks, "failed_count": len(failed_tasks)},
            )

        logger.info(
            "TaskManager shutdown completed, cancelled: %d, failed: %d",
            len(results) - len(failed_tasks),
            len(failed_tasks),
            extra={"cancelled_count": len(results) - len(failed_tasks), "failed_count": len(failed_tasks)},
        )