digitalkin 0.2.26__py3-none-any.whl → 0.3.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/grpc_servers/module_server.py +27 -44
  3. digitalkin/grpc_servers/module_servicer.py +27 -22
  4. digitalkin/grpc_servers/utils/models.py +1 -1
  5. digitalkin/logger.py +1 -9
  6. digitalkin/mixins/__init__.py +19 -0
  7. digitalkin/mixins/base_mixin.py +10 -0
  8. digitalkin/mixins/callback_mixin.py +24 -0
  9. digitalkin/mixins/chat_history_mixin.py +108 -0
  10. digitalkin/mixins/cost_mixin.py +76 -0
  11. digitalkin/mixins/file_history_mixin.py +99 -0
  12. digitalkin/mixins/filesystem_mixin.py +47 -0
  13. digitalkin/mixins/logger_mixin.py +59 -0
  14. digitalkin/mixins/storage_mixin.py +79 -0
  15. digitalkin/models/module/__init__.py +2 -0
  16. digitalkin/models/module/module.py +9 -1
  17. digitalkin/models/module/module_context.py +90 -6
  18. digitalkin/models/module/module_types.py +6 -6
  19. digitalkin/models/module/task_monitor.py +51 -0
  20. digitalkin/models/services/__init__.py +9 -0
  21. digitalkin/models/services/storage.py +39 -5
  22. digitalkin/modules/_base_module.py +47 -68
  23. digitalkin/modules/job_manager/base_job_manager.py +12 -8
  24. digitalkin/modules/job_manager/single_job_manager.py +84 -78
  25. digitalkin/modules/job_manager/surrealdb_repository.py +228 -0
  26. digitalkin/modules/job_manager/task_manager.py +389 -0
  27. digitalkin/modules/job_manager/task_session.py +275 -0
  28. digitalkin/modules/job_manager/taskiq_job_manager.py +2 -2
  29. digitalkin/modules/tool_module.py +10 -2
  30. digitalkin/modules/trigger_handler.py +7 -6
  31. digitalkin/services/cost/__init__.py +9 -2
  32. digitalkin/services/filesystem/default_filesystem.py +0 -2
  33. digitalkin/services/storage/grpc_storage.py +1 -1
  34. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dev1.dist-info}/METADATA +20 -19
  35. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dev1.dist-info}/RECORD +38 -25
  36. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dev1.dist-info}/WHEEL +0 -0
  37. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dev1.dist-info}/licenses/LICENSE +0 -0
  38. {digitalkin-0.2.26.dist-info → digitalkin-0.3.0.dev1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,228 @@
1
+ """SurrealDB connection management."""
2
+
3
+ import datetime
4
+ import os
5
+ from collections.abc import AsyncGenerator
6
+ from typing import Any, Generic, TypeVar
7
+ from uuid import UUID
8
+
9
+ from surrealdb import AsyncHttpSurrealConnection, AsyncSurreal, AsyncWsSurrealConnection, RecordID
10
+
11
+ from digitalkin.logger import logger
12
+
13
+ TSurreal = TypeVar("TSurreal", bound=AsyncHttpSurrealConnection | AsyncWsSurrealConnection)
14
+
15
+
16
class SurrealDBSetupBadIDError(Exception):
    """Invalid ID encountered while setting up the SurrealDB repository.

    Signals that a provided ID does not satisfy the expected format or
    criteria.
    """
22
+
23
+
24
class SurrealDBSetupVersionBadIDError(Exception):
    """Invalid ID encountered while setting up a SurrealDB version.

    Signals that a provided ID does not satisfy the expected format or
    criteria for a valid SurrealDB setup version ID.
    """
30
+
31
+
32
class SurrealDBConnection(Generic[TSurreal]):
    """Base repository for database operations.

    This class provides common database operations that can be used by
    specific table repositories. Connection settings are read from the
    SURREALDB_* environment variables, with local defaults.
    """

    # Live connection handle; populated by init_surreal_instance().
    db: TSurreal
    # Timeout budget for database operations.
    timeout: datetime.timedelta

    @staticmethod
    def _valid_id(raw_id: str, table_name: str) -> RecordID:
        """Validate and parse a raw ID string into a RecordID.

        Args:
            raw_id: The raw ID string to validate, expected as "table:identifier"
            table_name: Table name the ID must belong to

        Raises:
            SurrealDBSetupBadIDError: If the raw ID string is malformed or
                references a different table

        Returns:
            RecordID: Parsed RecordID object
        """
        # Split only on the first ":" so identifiers containing ":" are kept intact.
        table, sep, identifier = raw_id.partition(":")
        if not sep or not identifier:
            msg = f"Malformed record ID (expected 'table:identifier'): {raw_id}"
            raise SurrealDBSetupBadIDError(msg)
        if table != table_name:
            msg = f"Invalid table name for ID: {raw_id}"
            raise SurrealDBSetupBadIDError(msg)
        return RecordID(table, identifier)

    def __init__(
        self,
        database: str | None = None,
        timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Initialize the repository.

        Args:
            database: Database name to use; defaults to the SURREALDB_DATABASE
                environment variable (or "task_manager")
            timeout: Timeout for database operations
        """
        self.timeout = timeout
        base_url = os.getenv("SURREALDB_URL", "ws://localhost").strip()
        port = (os.getenv("SURREALDB_PORT") or "").strip()
        # Only append ":<port>" when a port is actually configured.
        self.url = f"{base_url}{f':{port}' if port else ''}/rpc"

        self.username = os.getenv("SURREALDB_USERNAME", "root")
        self.password = os.getenv("SURREALDB_PASSWORD", "root")
        self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
        self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")

    async def init_surreal_instance(self) -> None:
        """Open, authenticate and scope a SurrealDB connection instance."""
        logger.debug("Connecting to SurrealDB at %s", self.url)
        self.db = AsyncSurreal(self.url)  # type: ignore
        await self.db.signin({"username": self.username, "password": self.password})
        await self.db.use(self.namespace, self.database)
        logger.debug("Successfully connected to SurrealDB")

    async def close(self) -> None:
        """Close the SurrealDB connection if it exists."""
        logger.debug("Closing SurrealDB connection")
        await self.db.close()

    async def create(
        self,
        table_name: str,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Create a new record.

        Args:
            table_name: Name of the table to insert into
            data: Data to insert

        Returns:
            The created record(s) as returned by the database
        """
        logger.debug("Creating record in %s with data: %s", table_name, data)
        result = await self.db.create(table_name, data)
        logger.debug("create result: %s", result)
        return result

    async def merge(
        self,
        table_name: str,
        record_id: str | RecordID,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Merge data into an existing record (partial update).

        Args:
            table_name: Name of the table the record belongs to
            record_id: Record ID to update
            data: Fields to merge into the record

        Returns:
            The updated record(s) as returned by the database
        """
        if isinstance(record_id, str):
            # Validate raw string IDs before hitting the database.
            record_id = self._valid_id(record_id, table_name)
        logger.debug("Updating record in %s with data: %s", record_id, data)
        result = await self.db.merge(record_id, data)
        logger.debug("update result: %s", result)
        return result

    async def update(
        self,
        table_name: str,
        record_id: str | RecordID,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Replace the content of an existing record.

        Args:
            table_name: Name of the table the record belongs to
            record_id: Record ID to update
            data: Data to store on the record

        Returns:
            The updated record(s) as returned by the database
        """
        if isinstance(record_id, str):
            # Validate raw string IDs before hitting the database.
            record_id = self._valid_id(record_id, table_name)
        logger.debug("Updating record in %s with data: %s", record_id, data)
        result = await self.db.update(record_id, data)
        logger.debug("update result: %s", result)
        return result

    async def execute_query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
        """Execute a custom SurrealQL query.

        Args:
            query: SurrealQL query
            params: Query parameters

        Returns:
            List[Dict[str, Any]]: Query results, normalized to a list
        """
        logger.debug("execute_query: %s with params: %s", query, params)
        result = await self.db.query(query, params or {})
        logger.debug("execute_query result: %s", result)
        # Single-record results come back as a plain dict; normalize to a list.
        return [result] if isinstance(result, dict) else result

    async def select_by_task_id(self, table: str, value: str) -> dict[str, Any]:
        """Fetch the first record of a table matching a task_id.

        Args:
            table: Table name
            value: task_id value to match

        Raises:
            ValueError: If no records are found

        Returns:
            Dict with the first matching record
        """
        query = "SELECT * FROM type::table($table) WHERE task_id = $value;"
        params = {"table": table, "value": value}

        result = await self.execute_query(query, params)
        if not result:
            msg = f"No records found in table '{table}' with task_id '{value}'"
            logger.error(msg)
            raise ValueError(msg)

        return result[0]

    async def start_live(
        self,
        table_name: str,
    ) -> tuple[UUID, AsyncGenerator[dict[str, Any], None]]:
        """Create and subscribe to a live SurrealQL query.

        Args:
            table_name: Name of the table to watch

        Returns:
            The live query UUID and the async generator of change events
        """
        live_id = await self.db.live(table_name, diff=False)
        return live_id, await self.db.subscribe_live(live_id)

    async def stop_live(self, live_id: UUID) -> None:
        """Kill a live SurrealQL query.

        Args:
            live_id: UUID of the live query to kill
        """
        logger.debug("KILL Subscribe live for: %s", live_id)
        await self.db.kill(live_id)
@@ -0,0 +1,389 @@
1
+ """Task manager with comprehensive lifecycle management."""
2
+
3
+ import asyncio
4
+ import contextlib
5
+ import datetime
6
+ from collections.abc import Coroutine
7
+ from typing import Any
8
+
9
+ from digitalkin.logger import logger
10
+ from digitalkin.models.module.task_monitor import SignalMessage, SignalType, TaskStatus
11
+ from digitalkin.modules._base_module import BaseModule
12
+ from digitalkin.modules.job_manager.task_session import SurrealDBConnection, TaskSession
13
+
14
+
15
class TaskManager:
    """Task manager with comprehensive lifecycle management.

    Tracks one supervisor asyncio.Task and one TaskSession per task id and
    coordinates heartbeats, signal listening, cancellation and shutdown.
    """

    # Supervisor task per task id.
    tasks: dict[str, asyncio.Task]
    # Session bookkeeping per task id.
    tasks_sessions: dict[str, TaskSession]
    # Signalling channel; set by the most recent successful create_task().
    channel: SurrealDBConnection
    # Default grace period (seconds) used by cancel_task().
    default_timeout: float
    # Hard cap on simultaneously managed tasks.
    max_concurrent_tasks: int
    _shutdown_event: asyncio.Event

    def __init__(self, default_timeout: float = 10.0, max_concurrent_tasks: int = 100) -> None:
        """Initialize the manager.

        Args:
            default_timeout: Default grace period (seconds) for cancellations.
            max_concurrent_tasks: Maximum number of concurrently managed tasks.
        """
        self.tasks = {}
        self.tasks_sessions = {}
        self.default_timeout = default_timeout
        self.max_concurrent_tasks = max_concurrent_tasks
        self._shutdown_event = asyncio.Event()

        logger.info(
            "TaskManager initialized with max_concurrent_tasks: %d, default_timeout: %.1f",
            max_concurrent_tasks,
            default_timeout,
            extra={"max_concurrent_tasks": max_concurrent_tasks, "default_timeout": default_timeout},
        )

    @property
    def task_count(self) -> int:
        """Number of sessions currently tracked (running or not)."""
        return len(self.tasks_sessions)

    @property
    def running_tasks(self) -> set[str]:
        """IDs of tasks whose supervisor has not finished yet."""
        return {task_id for task_id, task in self.tasks.items() if not task.done()}

    async def _cleanup_task(self, task_id: str) -> None:
        """Release task resources and drop the task from the registries."""
        logger.debug("Cleaning up resources for task: '%s'", task_id, extra={"task_id": task_id})
        session = self.tasks_sessions.pop(task_id, None)
        if session is not None:
            # Close the per-session DB connection to avoid leaking sockets.
            await session.db.close()
        self.tasks.pop(task_id, None)

    async def _task_wrapper(  # noqa: C901, PLR0915
        self,
        task_id: str,
        coro: Coroutine[Any, Any, None],
        session: TaskSession,
    ) -> asyncio.Task[None]:
        """Task wrapper that runs main, heartbeat, and listener concurrently.

        The first to finish determines the outcome. Returns a Task that the
        caller can await externally.

        Returns:
            asyncio.Task[None]: The supervisor task managing the lifecycle.
        """

        async def signal_wrapper() -> None:
            # Publish a START marker, listen for signals, and always publish
            # a STOP marker on the way out (even on cancellation).
            try:
                await self.channel.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        status=session.status,
                        action=SignalType.START,
                    ).model_dump(),
                )
                await session.listen_signals()
            except asyncio.CancelledError:
                logger.debug("Signal listener cancelled", extra={"task_id": task_id})
            finally:
                await self.channel.create(
                    "tasks",
                    SignalMessage(
                        task_id=task_id,
                        status=session.status,
                        action=SignalType.STOP,
                    ).model_dump(),
                )
                logger.info("Signal listener ended", extra={"task_id": task_id})

        async def heartbeat_wrapper() -> None:
            try:
                await session.generate_heartbeats()
            except asyncio.CancelledError:
                logger.debug("Heartbeat task cancelled", extra={"task_id": task_id})
            finally:
                logger.info("Heartbeat task ended", extra={"task_id": task_id})

        async def supervisor() -> None:
            session.started_at = datetime.datetime.now(datetime.timezone.utc)
            session.status = TaskStatus.RUNNING

            main_task = asyncio.create_task(coro, name=f"{task_id}_main")
            hb_task = asyncio.create_task(heartbeat_wrapper(), name=f"{task_id}_heartbeat")
            sig_task = asyncio.create_task(signal_wrapper(), name=f"{task_id}_listener")

            try:
                done, pending = await asyncio.wait(
                    [main_task, sig_task, hb_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )

                # One task completed -> cancel the others
                for t in pending:
                    t.cancel()

                # Propagate exception/result from the finished task
                completed = next(iter(done))
                await completed

                if completed is main_task:
                    session.status = TaskStatus.COMPLETED
                elif completed is sig_task or (completed is hb_task and sig_task.done()):
                    logger.debug(f"Task cancelled due to signal {sig_task=}")
                    session.status = TaskStatus.CANCELLED
                elif completed is hb_task:
                    # Heartbeats stopping on their own means the task is stuck/dead.
                    session.status = TaskStatus.FAILED
                    msg = f"Heartbeat stopped for {task_id}"
                    raise RuntimeError(msg)  # noqa: TRY301

            except asyncio.CancelledError:
                session.status = TaskStatus.CANCELLED
                raise
            except Exception:
                session.status = TaskStatus.FAILED
                raise
            finally:
                session.completed_at = datetime.datetime.now(datetime.timezone.utc)
                # Ensure all tasks are cleaned up
                for t in [main_task, hb_task, sig_task]:
                    if not t.done():
                        t.cancel()
                await asyncio.gather(main_task, hb_task, sig_task, return_exceptions=True)

        # Return the supervisor task to be awaited outside
        return asyncio.create_task(supervisor(), name=f"{task_id}_supervisor")

    async def create_task(
        self,
        task_id: str,
        module: BaseModule,
        coro: Coroutine[Any, Any, None],
        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
        connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Create and start a new managed task.

        Args:
            task_id: Unique identifier for the task.
            module: Module whose lifecycle is bound to the task.
            coro: Main coroutine to execute.
            heartbeat_interval: Interval between heartbeat messages.
            connection_timeout: Timeout for the SurrealDB connection.

        Raises:
            ValueError: task_id duplicated
            RuntimeError: task overload
        """
        if task_id in self.tasks:
            # Close the never-started coroutine so Python does not warn about it.
            coro.close()
            logger.warning("Task creation failed - task already exists: '%s'", task_id, extra={"task_id": task_id})
            msg = f"Task {task_id} already exists"
            raise ValueError(msg)

        if len(self.tasks) >= self.max_concurrent_tasks:
            coro.close()
            logger.error(
                "Task creation failed - max concurrent tasks reached: %d",
                self.max_concurrent_tasks,
                extra={
                    "task_id": task_id,
                    "current_count": len(self.tasks),
                    "max_concurrent": self.max_concurrent_tasks,
                },
            )
            msg = f"Maximum concurrent tasks ({self.max_concurrent_tasks}) reached"
            raise RuntimeError(msg)

        logger.info(
            "Creating new task: '%s'",
            task_id,
            extra={
                "task_id": task_id,
                "heartbeat_interval": heartbeat_interval,
                "connection_timeout": connection_timeout,
            },
        )

        try:
            # Initialize components
            channel: SurrealDBConnection = SurrealDBConnection("task_manager", connection_timeout)
            await channel.init_surreal_instance()
            # Keep a manager-level handle: send_signal() and the wrappers in
            # _task_wrapper rely on self.channel being set.
            self.channel = channel
            session = TaskSession(task_id, channel, module, heartbeat_interval)

            self.tasks_sessions[task_id] = session

            # _task_wrapper returns the supervisor task; await it so we store
            # the real lifecycle task (wrapping the call in create_task would
            # store a short-lived task that finishes as soon as the supervisor
            # is spawned, breaking cancel_task/await semantics).
            self.tasks[task_id] = await self._task_wrapper(task_id, coro, session)

            logger.info(
                "Task created successfully: '%s'",
                task_id,
                extra={
                    "task_id": task_id,
                    "total_tasks": len(self.tasks),
                },
            )

        except Exception as e:
            logger.error(
                "Failed to create task: '%s'", task_id, extra={"task_id": task_id, "error": str(e)}, exc_info=True
            )
            # The coroutine may not have been scheduled yet; closing a running
            # one raises RuntimeError, which is deliberately ignored here.
            with contextlib.suppress(RuntimeError):
                coro.close()
            # Cleanup on failure
            await self._cleanup_task(task_id)
            raise

    async def send_signal(self, task_id: str, signal_type: str, payload: dict) -> bool:
        """Send signal to a specific task.

        Returns:
            bool: True if the signal was sent, False if the task is unknown.
        """
        if task_id not in self.tasks_sessions:
            logger.warning(
                "Cannot send signal - task not found: '%s'",
                task_id,
                extra={"task_id": task_id, "signal_type": signal_type},
            )
            return False

        logger.info(
            "Sending signal '%s' to task: '%s'",
            signal_type,
            task_id,
            extra={"task_id": task_id, "signal_type": signal_type, "payload": payload},
        )

        # NOTE(review): SurrealDBConnection.update takes (table, record_id,
        # data); passing the signal type as the record id looks suspicious —
        # confirm against the connection's contract.
        await self.channel.update("tasks", signal_type, payload)
        return True

    async def cancel_task(self, task_id: str, timeout: float | None = None) -> bool:
        """Cancel a task with graceful shutdown and fallback.

        Args:
            task_id: Identifier of the task to cancel.
            timeout: Grace period in seconds; defaults to default_timeout.

        Returns:
            bool: True if the task was cancelled successfully, False otherwise.
        """
        if task_id not in self.tasks:
            # Nothing to cancel; treated as success.
            logger.warning("Cannot cancel - task not found: '%s'", task_id, extra={"task_id": task_id})
            return True

        timeout = timeout or self.default_timeout
        task = self.tasks[task_id]

        logger.info(
            "Initiating task cancellation: '%s', timeout: %.1fs",
            task_id,
            timeout,
            extra={"task_id": task_id, "timeout": timeout},
        )

        try:
            # Phase 1: Cooperative cancellation
            # await self.send_signal(task_id, "cancel") # noqa: ERA001

            # Wait for graceful shutdown
            await asyncio.wait_for(task, timeout=timeout)

            logger.info("Task cancelled gracefully: '%s'", task_id, extra={"task_id": task_id})

        except asyncio.TimeoutError:
            logger.warning(
                "Graceful cancellation timed out for task: '%s', forcing cancellation",
                task_id,
                extra={"task_id": task_id, "timeout": timeout},
            )

            # Phase 2: Force cancellation
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

            logger.warning("Task force-cancelled: '%s'", task_id, extra={"task_id": task_id})
            return True

        except Exception as e:
            logger.error(
                "Error during task cancellation: '%s'",
                task_id,
                extra={"task_id": task_id, "error": str(e)},
                exc_info=True,
            )
            return False
        return True

    async def clean_session(self, task_id: str) -> bool:
        """Stop the module, cancel the task, and release its session.

        Returns:
            bool: True if the task was cleaned successfully, False otherwise.
        """
        if task_id not in self.tasks_sessions:
            logger.warning("Cannot clean session - task not found: '%s'", task_id, extra={"task_id": task_id})
            return False

        await self.tasks_sessions[task_id].module.stop()
        await self.cancel_task(task_id)

        logger.info("Cleaning up session for task: '%s'", task_id, extra={"task_id": task_id})
        # Close the session's DB connection and drop registry entries
        # (a bare pop would leak the connection).
        await self._cleanup_task(task_id)
        return True

    async def pause_task(self, task_id: str) -> bool:
        """Pause a running task.

        Returns:
            bool: True if the pause signal was sent successfully, False otherwise.
        """
        return await self.send_signal(task_id, "pause", {})

    async def resume_task(self, task_id: str) -> bool:
        """Resume a paused task.

        Returns:
            bool: True if the resume signal was sent successfully, False otherwise.
        """
        return await self.send_signal(task_id, "resume", {})

    async def get_task_status(self, task_id: str) -> bool:
        """Request status from a task.

        Returns:
            bool: True if the status request was sent successfully, False otherwise.
        """
        return await self.send_signal(task_id, "status", {})

    async def cancel_all_tasks(self, timeout: float | None = None) -> dict[str, bool]:
        """Cancel all running tasks.

        Args:
            timeout: Grace period per task; defaults to default_timeout.

        Returns:
            dict[str, bool]: Per-task flag, True when cancellation succeeded.
        """
        timeout = timeout or self.default_timeout
        task_ids = list(self.running_tasks)

        logger.info(
            "Cancelling all tasks: %d tasks", len(task_ids), extra={"task_count": len(task_ids), "timeout": timeout}
        )

        results = {}
        for task_id in task_ids:
            results[task_id] = await self.cancel_task(task_id, timeout)

        return results

    async def shutdown(self, timeout: float = 30.0) -> None:
        """Graceful shutdown of all tasks."""
        logger.info(
            "TaskManager shutdown initiated, timeout: %.1fs",
            timeout,
            extra={"timeout": timeout, "active_tasks": len(self.running_tasks)},
        )

        self._shutdown_event.set()
        results = await self.cancel_all_tasks(timeout)

        failed_tasks = [task_id for task_id, success in results.items() if not success]
        if failed_tasks:
            logger.error(
                "Failed to cancel %d tasks during shutdown: %s",
                len(failed_tasks),
                failed_tasks,
                extra={"failed_tasks": failed_tasks, "failed_count": len(failed_tasks)},
            )

        logger.info(
            "TaskManager shutdown completed, cancelled: %d, failed: %d",
            len(results) - len(failed_tasks),
            len(failed_tasks),
            extra={"cancelled_count": len(results) - len(failed_tasks), "failed_count": len(failed_tasks)},
        )