digitalkin 0.2.23__py3-none-any.whl → 0.3.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/core/__init__.py +1 -0
  3. digitalkin/core/common/__init__.py +9 -0
  4. digitalkin/core/common/factories.py +156 -0
  5. digitalkin/core/job_manager/__init__.py +1 -0
  6. digitalkin/{modules → core}/job_manager/base_job_manager.py +137 -31
  7. digitalkin/core/job_manager/single_job_manager.py +354 -0
  8. digitalkin/{modules → core}/job_manager/taskiq_broker.py +116 -22
  9. digitalkin/core/job_manager/taskiq_job_manager.py +541 -0
  10. digitalkin/core/task_manager/__init__.py +1 -0
  11. digitalkin/core/task_manager/base_task_manager.py +539 -0
  12. digitalkin/core/task_manager/local_task_manager.py +108 -0
  13. digitalkin/core/task_manager/remote_task_manager.py +87 -0
  14. digitalkin/core/task_manager/surrealdb_repository.py +266 -0
  15. digitalkin/core/task_manager/task_executor.py +249 -0
  16. digitalkin/core/task_manager/task_session.py +406 -0
  17. digitalkin/grpc_servers/__init__.py +1 -19
  18. digitalkin/grpc_servers/_base_server.py +3 -3
  19. digitalkin/grpc_servers/module_server.py +27 -43
  20. digitalkin/grpc_servers/module_servicer.py +51 -36
  21. digitalkin/grpc_servers/registry_server.py +2 -2
  22. digitalkin/grpc_servers/registry_servicer.py +4 -4
  23. digitalkin/grpc_servers/utils/__init__.py +1 -0
  24. digitalkin/grpc_servers/utils/exceptions.py +0 -8
  25. digitalkin/grpc_servers/utils/grpc_client_wrapper.py +4 -4
  26. digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
  27. digitalkin/logger.py +73 -24
  28. digitalkin/mixins/__init__.py +19 -0
  29. digitalkin/mixins/base_mixin.py +10 -0
  30. digitalkin/mixins/callback_mixin.py +24 -0
  31. digitalkin/mixins/chat_history_mixin.py +110 -0
  32. digitalkin/mixins/cost_mixin.py +76 -0
  33. digitalkin/mixins/file_history_mixin.py +93 -0
  34. digitalkin/mixins/filesystem_mixin.py +46 -0
  35. digitalkin/mixins/logger_mixin.py +51 -0
  36. digitalkin/mixins/storage_mixin.py +79 -0
  37. digitalkin/models/core/__init__.py +1 -0
  38. digitalkin/{modules/job_manager → models/core}/job_manager_models.py +3 -3
  39. digitalkin/models/core/task_monitor.py +70 -0
  40. digitalkin/models/grpc_servers/__init__.py +1 -0
  41. digitalkin/{grpc_servers/utils → models/grpc_servers}/models.py +5 -5
  42. digitalkin/models/module/__init__.py +2 -0
  43. digitalkin/models/module/module.py +9 -1
  44. digitalkin/models/module/module_context.py +122 -6
  45. digitalkin/models/module/module_types.py +307 -19
  46. digitalkin/models/services/__init__.py +9 -0
  47. digitalkin/models/services/cost.py +1 -0
  48. digitalkin/models/services/storage.py +39 -5
  49. digitalkin/modules/_base_module.py +123 -118
  50. digitalkin/modules/tool_module.py +10 -2
  51. digitalkin/modules/trigger_handler.py +7 -6
  52. digitalkin/services/cost/__init__.py +9 -2
  53. digitalkin/services/cost/grpc_cost.py +9 -42
  54. digitalkin/services/filesystem/default_filesystem.py +0 -2
  55. digitalkin/services/filesystem/grpc_filesystem.py +10 -39
  56. digitalkin/services/setup/default_setup.py +5 -6
  57. digitalkin/services/setup/grpc_setup.py +52 -15
  58. digitalkin/services/storage/grpc_storage.py +4 -4
  59. digitalkin/services/user_profile/__init__.py +1 -0
  60. digitalkin/services/user_profile/default_user_profile.py +55 -0
  61. digitalkin/services/user_profile/grpc_user_profile.py +69 -0
  62. digitalkin/services/user_profile/user_profile_strategy.py +40 -0
  63. digitalkin/utils/__init__.py +28 -0
  64. digitalkin/utils/arg_parser.py +1 -1
  65. digitalkin/utils/development_mode_action.py +2 -2
  66. digitalkin/utils/dynamic_schema.py +483 -0
  67. digitalkin/utils/package_discover.py +1 -2
  68. {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/METADATA +11 -30
  69. digitalkin-0.3.1.dev2.dist-info/RECORD +119 -0
  70. modules/dynamic_setup_module.py +362 -0
  71. digitalkin/grpc_servers/utils/factory.py +0 -180
  72. digitalkin/modules/job_manager/single_job_manager.py +0 -294
  73. digitalkin/modules/job_manager/taskiq_job_manager.py +0 -290
  74. digitalkin-0.2.23.dist-info/RECORD +0 -89
  75. /digitalkin/{grpc_servers/utils → models/grpc_servers}/types.py +0 -0
  76. {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/WHEEL +0 -0
  77. {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/licenses/LICENSE +0 -0
  78. {digitalkin-0.2.23.dist-info → digitalkin-0.3.1.dev2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,539 @@
1
+ """Base task manager with common lifecycle management."""
2
+
3
+ import asyncio
4
+ import contextlib
5
+ import datetime
6
+ import types
7
+ from abc import ABC, abstractmethod
8
+ from collections.abc import Coroutine
9
+ from typing import Any
10
+
11
+ from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
12
+ from digitalkin.core.task_manager.task_session import TaskSession
13
+ from digitalkin.logger import logger
14
+ from digitalkin.models.core.task_monitor import CancellationReason
15
+ from digitalkin.modules._base_module import BaseModule
16
+
17
+
18
class BaseTaskManager(ABC):
    """Base task manager with common lifecycle management.

    Provides shared functionality for task orchestration, monitoring, signaling, and cancellation.
    Subclasses implement specific execution strategies (local or remote).

    Supports async context manager protocol for automatic resource cleanup:
        async with LocalTaskManager() as manager:
            await manager.create_task(...)
        # Resources automatically cleaned up on exit
    """

    # task_id -> asyncio task registered by the concrete manager
    tasks: dict[str, asyncio.Task]
    # task_id -> session object; this is the dict used for capacity/duplicate checks
    tasks_sessions: dict[str, TaskSession]
    # Fallback timeout (seconds) used when a cancel call passes no (or a falsy) timeout
    default_timeout: float
    # Upper bound on the number of registered sessions
    max_concurrent_tasks: int
    # Set once shutdown() has been called
    _shutdown_event: asyncio.Event

    def __init__(
        self,
        default_timeout: float = 10.0,
        max_concurrent_tasks: int = 100,
    ) -> None:
        """Initialize task manager properties.

        Args:
            default_timeout: Default timeout for task operations in seconds
            max_concurrent_tasks: Maximum number of concurrent tasks
        """
        self.tasks = {}
        self.tasks_sessions = {}
        self.default_timeout = default_timeout
        self.max_concurrent_tasks = max_concurrent_tasks
        self._shutdown_event = asyncio.Event()

        logger.info(
            "%s initialized with max_concurrent_tasks: %d, default_timeout: %.1f",
            self.__class__.__name__,
            max_concurrent_tasks,
            default_timeout,
            extra={
                "max_concurrent_tasks": max_concurrent_tasks,
                "default_timeout": default_timeout,
            },
        )

    @property
    def task_count(self) -> int:
        """Number of managed tasks (counted by registered sessions)."""
        return len(self.tasks_sessions)

    @property
    def running_tasks(self) -> set[str]:
        """Get IDs of currently running tasks (asyncio tasks that are not done)."""
        return {task_id for task_id, task in self.tasks.items() if not task.done()}

    async def _cleanup_task(self, task_id: str, mission_id: str) -> None:
        """Clean up task resources and drop the task from the tracking dictionaries.

        Delegates to ``TaskSession.cleanup()`` when a session is registered, then
        removes ``task_id`` from both ``tasks_sessions`` and ``tasks``. The
        removal happens in a ``finally`` block so that a failing session cleanup
        cannot leave a stale entry behind (a stale entry would keep counting
        against ``max_concurrent_tasks`` and block reuse of the same task_id).
        Any exception raised by the session cleanup still propagates.

        Args:
            task_id: The ID of the task to clean up
            mission_id: The ID of the mission associated with the task
        """
        session = self.tasks_sessions.get(task_id)
        cancellation_reason = session.cancellation_reason.value if session else "no_session"
        final_status = session.status.value if session else "unknown"
        log_context = {
            "mission_id": mission_id,
            "task_id": task_id,
            "final_status": final_status,
            "cancellation_reason": cancellation_reason,
        }

        logger.debug("Cleaning up resources", extra={**log_context})

        try:
            if session:
                await session.cleanup()
        finally:
            # Always unregister, even when session.cleanup() raised.
            self.tasks_sessions.pop(task_id, None)
            self.tasks.pop(task_id, None)

        if session:
            logger.debug("Task session cleanup completed", extra={**log_context})

    async def _validate_task_creation(self, task_id: str, mission_id: str, coro: Coroutine[Any, Any, None]) -> None:
        """Validate task creation preconditions.

        On failure the supplied coroutine is closed before raising, so the
        caller does not end up with a never-awaited coroutine.

        Args:
            task_id: The ID of the task to create
            mission_id: The ID of the mission associated with the task
            coro: The coroutine to execute

        Raises:
            ValueError: If task_id already exists
            RuntimeError: If max concurrent tasks reached
        """
        if task_id in self.tasks_sessions:
            coro.close()
            logger.warning(
                "Task creation failed - task already exists: '%s'",
                task_id,
                extra={"mission_id": mission_id, "task_id": task_id},
            )
            msg = f"Task {task_id} already exists"
            raise ValueError(msg)

        if len(self.tasks_sessions) >= self.max_concurrent_tasks:
            coro.close()
            logger.error(
                "Task creation failed - max concurrent tasks reached: %d",
                self.max_concurrent_tasks,
                extra={
                    "mission_id": mission_id,
                    "task_id": task_id,
                    "current_count": len(self.tasks_sessions),
                    "max_concurrent": self.max_concurrent_tasks,
                },
            )
            msg = f"Maximum concurrent tasks ({self.max_concurrent_tasks}) reached"
            raise RuntimeError(msg)

    async def _create_session(
        self,
        task_id: str,
        mission_id: str,
        module: BaseModule,
        heartbeat_interval: datetime.timedelta,
        connection_timeout: datetime.timedelta,
    ) -> tuple[SurrealDBConnection, TaskSession]:
        """Create SurrealDB connection and task session.

        The session is registered in ``tasks_sessions`` only after the
        connection has been initialised.

        NOTE(review): validation and registration are separated by an ``await``,
        so two concurrent creations with the same task_id could both pass
        validation - confirm whether callers serialise task creation.

        Args:
            task_id: The ID of the task
            mission_id: The ID of the mission
            module: The module instance
            heartbeat_interval: Interval between heartbeats
            connection_timeout: Connection timeout for SurrealDB

        Returns:
            Tuple of (channel, session)
        """
        channel: SurrealDBConnection = SurrealDBConnection("task_manager", connection_timeout)
        await channel.init_surreal_instance()
        session = TaskSession(
            task_id=task_id,
            mission_id=mission_id,
            db=channel,
            module=module,
            heartbeat_interval=heartbeat_interval,
        )
        self.tasks_sessions[task_id] = session
        return channel, session

    @abstractmethod
    async def create_task(
        self,
        task_id: str,
        mission_id: str,
        module: BaseModule,
        coro: Coroutine[Any, Any, None],
        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
        connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Create and manage a new task.

        Subclasses implement specific execution strategies.

        Args:
            task_id: Unique identifier for the task
            mission_id: Mission identifier
            module: Module instance to execute
            coro: Coroutine to execute
            heartbeat_interval: Interval between heartbeats
            connection_timeout: Connection timeout for SurrealDB

        Raises:
            ValueError: If task_id duplicated
            RuntimeError: If task overload
        """
        ...

    async def send_signal(self, task_id: str, mission_id: str, signal_type: str, payload: dict) -> bool:
        """Send signal to a specific task.

        The signal is written through the task session's database connection
        (``update("signals", task_id, ...)``). Errors raised by that call are
        not caught here and propagate to the caller.

        Args:
            task_id: The ID of the task
            mission_id: The ID of the mission
            signal_type: Type of signal to send
            payload: Signal payload

        Returns:
            True if the signal was sent, False if no session exists for task_id
        """
        session = self.tasks_sessions.get(task_id)
        if session is None:
            logger.warning(
                "Cannot send signal - task not found: '%s'",
                task_id,
                extra={"mission_id": mission_id, "task_id": task_id, "signal_type": signal_type},
            )
            return False

        logger.info(
            "Sending signal '%s' to task: '%s'",
            signal_type,
            task_id,
            extra={"mission_id": mission_id, "task_id": task_id, "signal_type": signal_type, "payload": payload},
        )

        # Use the task session's db connection to send the signal
        await session.db.update("signals", task_id, {"type": signal_type, "payload": payload})
        return True

    async def cancel_task(self, task_id: str, mission_id: str, timeout: float | None = None) -> bool:
        """Cancel a task with graceful shutdown and fallback.

        Waits up to ``timeout`` seconds for the task to finish on its own; if it
        does not, the task is cancelled. Resources are cleaned up in every case.

        Args:
            task_id: The ID of the task to cancel
            mission_id: The ID of the mission
            timeout: Optional timeout for cancellation. A falsy value (``None``
                or ``0``) falls back to ``default_timeout``.

        Returns:
            True if the task finished or was cancelled (also when no task is
            registered under task_id, after cleaning any orphaned session);
            False if waiting on the task raised an unexpected exception.
        """
        if task_id not in self.tasks:
            logger.warning(
                "Cannot cancel - task not found: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id}
            )
            # Still cleanup any orphaned session
            await self._cleanup_task(task_id, mission_id)
            return True

        timeout = timeout or self.default_timeout
        task = self.tasks[task_id]

        logger.info(
            "Initiating task cancellation: '%s', timeout: %.1fs",
            task_id,
            timeout,
            extra={"mission_id": mission_id, "task_id": task_id, "timeout": timeout},
        )

        try:
            # Wait for graceful shutdown. Note: asyncio.wait_for itself cancels
            # the awaited task when the timeout expires.
            await asyncio.wait_for(task, timeout=timeout)

            logger.info(
                "Task cancelled gracefully: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id}
            )

        except asyncio.TimeoutError:
            # Set timeout as cancellation reason (unless a reason was already recorded)
            if task_id in self.tasks_sessions:
                session = self.tasks_sessions[task_id]
                if session.cancellation_reason == CancellationReason.UNKNOWN:
                    session.cancellation_reason = CancellationReason.TIMEOUT

            logger.warning(
                "Graceful cancellation timed out for task: '%s', forcing cancellation",
                task_id,
                extra={
                    "mission_id": mission_id,
                    "task_id": task_id,
                    "timeout": timeout,
                    "cancellation_reason": CancellationReason.TIMEOUT.value,
                },
            )

            # Phase 2: Force cancellation (safeguard; a no-op if the task is already done)
            task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

            logger.warning(
                "Task force-cancelled: '%s', reason: %s",
                task_id,
                CancellationReason.TIMEOUT.value,
                extra={
                    "mission_id": mission_id,
                    "task_id": task_id,
                    "cancellation_reason": CancellationReason.TIMEOUT.value,
                },
            )
            return True

        except Exception as e:
            logger.error(
                "Error during task cancellation: '%s'",
                task_id,
                extra={"mission_id": mission_id, "task_id": task_id, "error": str(e)},
                exc_info=True,
            )
            return False
        finally:
            # Runs on every path, including the early returns above.
            await self._cleanup_task(task_id, mission_id)
        return True

    async def clean_session(self, task_id: str, mission_id: str) -> bool:
        """Clean up the session of a task.

        This delegates to ``cancel_task``: if an asyncio task is still
        registered under task_id it is waited for (and cancelled on timeout)
        before the session is cleaned; if only an orphaned session exists, it
        is cleaned directly.

        Args:
            task_id: The ID of the task
            mission_id: The ID of the mission

        Returns:
            True if a session existed and cleanup was performed, False if no
            session is registered for task_id.
        """
        if task_id not in self.tasks_sessions:
            logger.warning(
                "Cannot clean session - task not found: '%s'",
                task_id,
                extra={"mission_id": mission_id, "task_id": task_id},
            )
            return False

        logger.info("Cleaning up session for task: '%s'", task_id, extra={"mission_id": mission_id, "task_id": task_id})
        await self.cancel_task(mission_id=mission_id, task_id=task_id)
        return True

    async def pause_task(self, task_id: str, mission_id: str) -> bool:
        """Pause a running task by sending it a ``pause`` signal.

        Args:
            task_id: The ID of the task
            mission_id: The ID of the mission

        Returns:
            True if the pause signal was sent, False otherwise
        """
        return await self.send_signal(task_id=task_id, mission_id=mission_id, signal_type="pause", payload={})

    async def resume_task(self, task_id: str, mission_id: str) -> bool:
        """Resume a paused task by sending it a ``resume`` signal.

        Args:
            task_id: The ID of the task
            mission_id: The ID of the mission

        Returns:
            True if the resume signal was sent, False otherwise
        """
        return await self.send_signal(task_id=task_id, mission_id=mission_id, signal_type="resume", payload={})

    async def get_task_status(self, task_id: str, mission_id: str) -> bool:
        """Request status from a task by sending it a ``status`` signal.

        The status itself is not returned by this method.

        Args:
            task_id: The ID of the task
            mission_id: The ID of the mission

        Returns:
            True if the status request was sent successfully, False otherwise
        """
        return await self.send_signal(task_id=task_id, mission_id=mission_id, signal_type="status", payload={})

    async def cancel_all_tasks(self, mission_id: str, timeout: float | None = None) -> dict[str, bool | BaseException]:
        """Cancel all running tasks.

        Args:
            mission_id: The ID of the mission
            timeout: Optional timeout for cancellation. A falsy value (``None``
                or ``0``) falls back to ``default_timeout``.

        Returns:
            Dictionary mapping task_id to cancellation success status. A task
            whose cancellation raised (including ``asyncio.CancelledError``)
            is reported as ``False``.
        """
        timeout = timeout or self.default_timeout
        task_ids = list(self.running_tasks)

        logger.info(
            "Cancelling all tasks in parallel: %d tasks",
            len(task_ids),
            extra={"mission_id": mission_id, "task_count": len(task_ids), "timeout": timeout},
        )

        # Cancel all tasks in parallel to reduce latency
        cancel_coros = [
            self.cancel_task(
                task_id=task_id,
                mission_id=mission_id,
                timeout=timeout,
            )
            for task_id in task_ids
        ]
        results_list = await asyncio.gather(*cancel_coros, return_exceptions=True)

        # Build results dictionary
        results: dict[str, bool | BaseException] = {}
        for task_id, result in zip(task_ids, results_list):
            # gather(return_exceptions=True) can also hand back BaseException
            # subclasses such as CancelledError; those are truthy objects, so
            # they must be mapped to False explicitly or callers would count
            # them as successes.
            if isinstance(result, BaseException):
                logger.error(
                    "Exception cancelling task: '%s', error: %s",
                    task_id,
                    result,
                    extra={
                        "mission_id": mission_id,
                        "task_id": task_id,
                        "error": str(result),
                    },
                )
                results[task_id] = False
            else:
                results[task_id] = result

        return results

    async def shutdown(self, mission_id: str, timeout: float = 30.0) -> None:
        """Graceful shutdown of all tasks.

        Sets the shutdown event, tags sessions without a recorded cancellation
        reason as SHUTDOWN, cancels every running task, and finally cleans up
        any session still registered afterwards.

        Args:
            mission_id: The ID of the mission
            timeout: Timeout for shutdown operations
        """
        logger.info(
            "TaskManager shutdown initiated, timeout: %.1fs",
            timeout,
            extra={"mission_id": mission_id, "timeout": timeout, "active_tasks": len(self.running_tasks)},
        )

        self._shutdown_event.set()

        # Mark all sessions with shutdown reason before cancellation
        for task_id, session in self.tasks_sessions.items():
            if session.cancellation_reason == CancellationReason.UNKNOWN:
                session.cancellation_reason = CancellationReason.SHUTDOWN
                logger.debug(
                    "Marking task for shutdown: '%s'",
                    task_id,
                    extra={
                        "mission_id": mission_id,
                        "task_id": task_id,
                        "cancellation_reason": CancellationReason.SHUTDOWN.value,
                    },
                )

        results = await self.cancel_all_tasks(mission_id, timeout)

        failed_tasks = [task_id for task_id, success in results.items() if not success]
        if failed_tasks:
            logger.error(
                "Failed to cancel %d tasks during shutdown: %s",
                len(failed_tasks),
                failed_tasks,
                extra={
                    "mission_id": mission_id,
                    "failed_tasks": failed_tasks,
                    "failed_count": len(failed_tasks),
                    "cancellation_reason": CancellationReason.SHUTDOWN.value,
                },
            )

        # Clean up any remaining sessions (in case cancellation didn't clean them)
        remaining_sessions = list(self.tasks_sessions.keys())
        if remaining_sessions:
            logger.info(
                "Cleaning up %d remaining task sessions after shutdown",
                len(remaining_sessions),
                extra={
                    "mission_id": mission_id,
                    "remaining_sessions": remaining_sessions,
                    "remaining_count": len(remaining_sessions),
                },
            )
            cleanup_coros = [self._cleanup_task(task_id, mission_id) for task_id in remaining_sessions]
            # Best effort: individual cleanup failures must not abort shutdown.
            await asyncio.gather(*cleanup_coros, return_exceptions=True)

        logger.info(
            "TaskManager shutdown completed, cancelled: %d, failed: %d",
            len(results) - len(failed_tasks),
            len(failed_tasks),
            extra={
                "mission_id": mission_id,
                "cancelled_count": len(results) - len(failed_tasks),
                "failed_count": len(failed_tasks),
            },
        )

    async def __aenter__(self) -> "BaseTaskManager":
        """Enter async context manager.

        Returns:
            Self for use in async with statements
        """
        logger.debug("Entering %s context", self.__class__.__name__)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Exit async context manager and clean up resources.

        Returns ``None``, so an exception raised inside the ``async with``
        body is not suppressed.

        Args:
            exc_type: Exception type if an exception occurred
            exc_val: Exception value if an exception occurred
            exc_tb: Exception traceback if an exception occurred
        """
        logger.debug(
            "Exiting %s context, exception: %s",
            self.__class__.__name__,
            exc_type,
            extra={"exc_type": exc_type, "exc_val": exc_val},
        )
        # Shutdown with default mission_id for context manager usage
        await self.shutdown(mission_id="context_manager_cleanup")
@@ -0,0 +1,108 @@
1
+ """Local task manager for single-process execution."""
2
+
3
+ import datetime
4
+ from collections.abc import Coroutine
5
+ from typing import Any
6
+
7
+ from digitalkin.core.task_manager.base_task_manager import BaseTaskManager
8
+ from digitalkin.core.task_manager.task_executor import TaskExecutor
9
+ from digitalkin.logger import logger
10
+ from digitalkin.modules._base_module import BaseModule
11
+
12
+
13
class LocalTaskManager(BaseTaskManager):
    """Task manager for local execution in the same process.

    Executes tasks locally using TaskExecutor with the supervisor pattern.
    Suitable for single-server deployments and development.
    """

    # Executor that turns a coroutine + session into a supervised asyncio task
    _executor: TaskExecutor

    def __init__(
        self,
        default_timeout: float = 10.0,
        max_concurrent_tasks: int = 100,
    ) -> None:
        """Initialize local task manager.

        Args:
            default_timeout: Default timeout for task operations in seconds
            max_concurrent_tasks: Maximum number of concurrent tasks
        """
        super().__init__(default_timeout, max_concurrent_tasks)
        self._executor = TaskExecutor()

    async def create_task(
        self,
        task_id: str,
        mission_id: str,
        module: BaseModule,
        coro: Coroutine[Any, Any, None],
        heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
        connection_timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Create and execute a task locally using TaskExecutor.

        Validates the request, creates the session, hands the coroutine to the
        executor and registers the returned supervisor task under ``task_id``.
        If anything fails after validation, the task's resources are cleaned up
        and the exception is re-raised. A coroutine that was never handed to
        the executor is closed so it does not linger un-awaited.

        Args:
            task_id: Unique identifier for the task
            mission_id: Mission identifier
            module: Module instance to execute
            coro: Coroutine to execute
            heartbeat_interval: Interval between heartbeats
            connection_timeout: Connection timeout for SurrealDB

        Raises:
            ValueError: If task_id duplicated
            RuntimeError: If task overload
        """
        # Validation (closes coro itself when it rejects the request)
        await self._validate_task_creation(task_id, mission_id, coro)

        logger.info(
            "Creating local task: '%s'",
            task_id,
            extra={
                "mission_id": mission_id,
                "task_id": task_id,
                "heartbeat_interval": heartbeat_interval,
                "connection_timeout": connection_timeout,
            },
        )

        # Tracks whether the executor has been given the coroutine; before that
        # point this method still owns it and must close it on failure.
        coro_handed_off = False
        try:
            # Create session
            channel, session = await self._create_session(
                task_id, mission_id, module, heartbeat_interval, connection_timeout
            )

            # Execute task using TaskExecutor
            coro_handed_off = True
            supervisor_task = await self._executor.execute_task(
                task_id,
                mission_id,
                coro,
                session,
                channel,
            )
            self.tasks[task_id] = supervisor_task

            logger.info(
                "Local task created and started: '%s'",
                task_id,
                extra={
                    "mission_id": mission_id,
                    "task_id": task_id,
                    "total_tasks": len(self.tasks),
                },
            )

        except Exception as e:
            logger.error(
                "Failed to create local task: '%s'",
                task_id,
                extra={"mission_id": mission_id, "task_id": task_id, "error": str(e)},
                exc_info=True,
            )
            if not coro_handed_off:
                # Session creation failed: the coroutine never started, so close
                # it the same way _validate_task_creation does on rejection.
                coro.close()
            # Cleanup on failure
            await self._cleanup_task(task_id, mission_id=mission_id)
            raise