digitalkin 0.3.0rc2__py3-none-any.whl → 0.3.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/core/common/__init__.py +9 -0
  3. digitalkin/core/common/factories.py +156 -0
  4. digitalkin/core/job_manager/base_job_manager.py +128 -28
  5. digitalkin/core/job_manager/single_job_manager.py +80 -25
  6. digitalkin/core/job_manager/taskiq_broker.py +114 -19
  7. digitalkin/core/job_manager/taskiq_job_manager.py +292 -39
  8. digitalkin/core/task_manager/base_task_manager.py +464 -0
  9. digitalkin/core/task_manager/local_task_manager.py +108 -0
  10. digitalkin/core/task_manager/remote_task_manager.py +87 -0
  11. digitalkin/core/task_manager/surrealdb_repository.py +43 -4
  12. digitalkin/core/task_manager/task_executor.py +173 -0
  13. digitalkin/core/task_manager/task_session.py +34 -12
  14. digitalkin/grpc_servers/module_server.py +2 -2
  15. digitalkin/grpc_servers/module_servicer.py +4 -3
  16. digitalkin/grpc_servers/registry_server.py +1 -1
  17. digitalkin/grpc_servers/registry_servicer.py +4 -4
  18. digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
  19. digitalkin/models/grpc_servers/models.py +4 -4
  20. digitalkin/services/cost/grpc_cost.py +8 -41
  21. digitalkin/services/filesystem/grpc_filesystem.py +9 -38
  22. digitalkin/services/setup/default_setup.py +5 -6
  23. digitalkin/services/setup/grpc_setup.py +51 -14
  24. digitalkin/services/storage/grpc_storage.py +2 -2
  25. digitalkin/services/user_profile/__init__.py +1 -0
  26. digitalkin/services/user_profile/default_user_profile.py +55 -0
  27. digitalkin/services/user_profile/grpc_user_profile.py +69 -0
  28. digitalkin/services/user_profile/user_profile_strategy.py +40 -0
  29. {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/METADATA +7 -7
  30. {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/RECORD +33 -23
  31. digitalkin/core/task_manager/task_manager.py +0 -442
  32. {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/WHEEL +0 -0
  33. {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/licenses/LICENSE +0 -0
  34. {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  """SurrealDB connection management."""
2
2
 
3
+ import asyncio
3
4
  import datetime
4
5
  import os
5
6
  from collections.abc import AsyncGenerator
@@ -38,6 +39,7 @@ class SurrealDBConnection(Generic[TSurreal]):
38
39
 
39
40
  db: TSurreal
40
41
  timeout: datetime.timedelta
42
+ _live_queries: set[UUID] # Track active live queries for cleanup
41
43
 
42
44
  @staticmethod
43
45
  def _valid_id(raw_id: str, table_name: str) -> RecordID:
@@ -82,6 +84,7 @@ class SurrealDBConnection(Generic[TSurreal]):
82
84
  self.password = os.getenv("SURREALDB_PASSWORD", "root")
83
85
  self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
84
86
  self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")
87
+ self._live_queries = set() # Initialize live queries tracker
85
88
 
86
89
  async def init_surreal_instance(self) -> None:
87
90
  """Init a SurrealDB connection instance."""
@@ -92,7 +95,37 @@ class SurrealDBConnection(Generic[TSurreal]):
92
95
  logger.debug("Successfully connected to SurrealDB")
93
96
 
94
97
  async def close(self) -> None:
95
- """Close the SurrealDB connection if it exists."""
98
+ """Close the SurrealDB connection if it exists.
99
+
100
+ This will also kill all active live queries to prevent memory leaks.
101
+ """
102
+ # Kill all tracked live queries before closing connection
103
+ if self._live_queries:
104
+ logger.debug("Killing %d active live queries before closing", len(self._live_queries))
105
+ live_query_ids = list(self._live_queries)
106
+
107
+ # Kill all queries concurrently, capturing any exceptions
108
+ results = await asyncio.gather(
109
+ *[self.db.kill(live_id) for live_id in live_query_ids], return_exceptions=True
110
+ )
111
+
112
+ # Process results and track failures
113
+ failed_queries = []
114
+ for live_id, result in zip(live_query_ids, results):
115
+ if isinstance(result, (ConnectionError, TimeoutError, Exception)):
116
+ failed_queries.append((live_id, str(result)))
117
+ else:
118
+ self._live_queries.discard(live_id)
119
+
120
+ # Log aggregated failures once instead of per-query
121
+ if failed_queries:
122
+ logger.warning(
123
+ "Failed to kill %d live queries: %s",
124
+ len(failed_queries),
125
+ failed_queries[:5], # Only log first 5 to avoid log spam
126
+ extra={"total_failed": len(failed_queries)},
127
+ )
128
+
96
129
  logger.debug("Closing SurrealDB connection")
97
130
  await self.db.close()
98
131
 
@@ -208,20 +241,26 @@ class SurrealDBConnection(Generic[TSurreal]):
208
241
  ) -> tuple[UUID, AsyncGenerator[dict[str, Any], None]]:
209
242
  """Create and subscribe to a live SurrealQL query.
210
243
 
244
+ The live query ID is tracked to ensure proper cleanup on connection close.
245
+
211
246
  Args:
212
247
  table_name: Name of the table to insert into
213
248
 
214
249
  Returns:
215
- List[Dict[str, Any]]: Query results
250
+ tuple[UUID, AsyncGenerator]: Live query ID and subscription generator
216
251
  """
217
252
  live_id = await self.db.live(table_name, diff=False)
253
+ self._live_queries.add(live_id) # Track for cleanup
254
+ logger.debug("Started live query %s for table %s (total: %d)", live_id, table_name, len(self._live_queries))
218
255
  return live_id, await self.db.subscribe_live(live_id)
219
256
 
220
257
  async def stop_live(self, live_id: UUID) -> None:
221
258
  """Kill a live SurrealQL query.
222
259
 
223
260
  Args:
224
- live_id: record ID to watch for
261
+ live_id: Live query ID to kill
225
262
  """
226
- logger.debug("KILL Subscribe live for: %s", live_id)
263
+ logger.debug("Killing live query: %s", live_id)
227
264
  await self.db.kill(live_id)
265
+ self._live_queries.discard(live_id) # Remove from tracker
266
+ logger.debug("Stopped live query %s (remaining: %d)", live_id, len(self._live_queries))
@@ -0,0 +1,173 @@
1
+ """Task executor for running tasks with full lifecycle management."""
2
+
3
+ import asyncio
4
+ import datetime
5
+ from collections.abc import Coroutine
6
+ from typing import Any
7
+
8
+ from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
9
+ from digitalkin.core.task_manager.task_session import TaskSession
10
+ from digitalkin.logger import logger
11
+ from digitalkin.models.core.task_monitor import SignalMessage, SignalType, TaskStatus
12
+
13
+
14
+ class TaskExecutor:
15
+ """Executes tasks with the supervisor pattern (main + heartbeat + signal listener).
16
+
17
+ Pure execution logic - no task registry or orchestration.
18
+ Used by workers to run distributed tasks or by TaskManager for local execution.
19
+ """
20
+
21
+ @staticmethod
22
+ async def execute_task( # noqa: C901, PLR0915
23
+ task_id: str,
24
+ mission_id: str,
25
+ coro: Coroutine[Any, Any, None],
26
+ session: TaskSession,
27
+ channel: SurrealDBConnection,
28
+ ) -> asyncio.Task[None]:
29
+ """Execute a task using the supervisor pattern.
30
+
31
+ Runs three concurrent sub-tasks:
32
+ - Main coroutine (the actual work)
33
+ - Heartbeat generator (sends heartbeats to SurrealDB)
34
+ - Signal listener (watches for stop/pause/resume signals)
35
+
36
+ The first task to complete determines the outcome.
37
+
38
+ Args:
39
+ task_id: Unique identifier for the task
40
+ mission_id: Mission identifier for the task
41
+ coro: The coroutine to execute (module.start(...))
42
+ session: TaskSession for state management
43
+ channel: SurrealDB connection for signals
44
+
45
+ Returns:
46
+ asyncio.Task: The supervisor task managing the lifecycle
47
+ """
48
+
49
+ async def signal_wrapper() -> None:
50
+ """Create initial signal record and listen for signals."""
51
+ try:
52
+ await channel.create(
53
+ "tasks",
54
+ SignalMessage(
55
+ task_id=task_id,
56
+ mission_id=mission_id,
57
+ status=session.status,
58
+ action=SignalType.START,
59
+ ).model_dump(),
60
+ )
61
+ await session.listen_signals()
62
+ except asyncio.CancelledError:
63
+ logger.debug("Signal listener cancelled", extra={"mission_id": mission_id, "task_id": task_id})
64
+ finally:
65
+ await channel.create(
66
+ "tasks",
67
+ SignalMessage(
68
+ task_id=task_id,
69
+ mission_id=mission_id,
70
+ status=session.status,
71
+ action=SignalType.STOP,
72
+ ).model_dump(),
73
+ )
74
+ logger.info("Signal listener ended", extra={"mission_id": mission_id, "task_id": task_id})
75
+
76
+ async def heartbeat_wrapper() -> None:
77
+ """Generate heartbeats for task health monitoring."""
78
+ try:
79
+ await session.generate_heartbeats()
80
+ except asyncio.CancelledError:
81
+ logger.debug("Heartbeat cancelled", extra={"mission_id": mission_id, "task_id": task_id})
82
+ finally:
83
+ logger.info("Heartbeat task ended", extra={"mission_id": mission_id, "task_id": task_id})
84
+
85
+ async def supervisor() -> None:
86
+ """Supervise the three concurrent tasks and handle outcomes.
87
+
88
+ Raises:
89
+ RuntimeError: If the heartbeat task stops unexpectedly.
90
+ asyncio.CancelledError: If the supervisor task is cancelled.
91
+ """
92
+ session.started_at = datetime.datetime.now(datetime.timezone.utc)
93
+ session.status = TaskStatus.RUNNING
94
+
95
+ # Create tasks with proper exception handling
96
+ main_task = None
97
+ hb_task = None
98
+ sig_task = None
99
+
100
+ try:
101
+ main_task = asyncio.create_task(coro, name=f"{task_id}_main")
102
+ hb_task = asyncio.create_task(heartbeat_wrapper(), name=f"{task_id}_heartbeat")
103
+ sig_task = asyncio.create_task(signal_wrapper(), name=f"{task_id}_listener")
104
+ done, pending = await asyncio.wait(
105
+ [main_task, sig_task, hb_task],
106
+ return_when=asyncio.FIRST_COMPLETED,
107
+ )
108
+
109
+ # One task completed -> cancel the others
110
+ for t in pending:
111
+ t.cancel()
112
+
113
+ # Propagate exception/result from the finished task
114
+ completed = next(iter(done))
115
+ await completed
116
+
117
+ # Determine final status based on which task completed
118
+ if completed is main_task:
119
+ session.status = TaskStatus.COMPLETED
120
+ logger.info(
121
+ "Main task completed successfully",
122
+ extra={"mission_id": mission_id, "task_id": task_id},
123
+ )
124
+ elif completed is sig_task or (completed is hb_task and sig_task.done()):
125
+ logger.debug(
126
+ f"Task cancelled due to signal {sig_task=}",
127
+ extra={"mission_id": mission_id, "task_id": task_id},
128
+ )
129
+ session.status = TaskStatus.CANCELLED
130
+ elif completed is hb_task:
131
+ session.status = TaskStatus.FAILED
132
+ logger.error(
133
+ f"Heartbeat stopped for {task_id}",
134
+ extra={"mission_id": mission_id, "task_id": task_id},
135
+ )
136
+ msg = f"Heartbeat stopped for {task_id}"
137
+ raise RuntimeError(msg) # noqa: TRY301
138
+
139
+ except asyncio.CancelledError:
140
+ session.status = TaskStatus.CANCELLED
141
+ logger.info("Task cancelled", extra={"mission_id": mission_id, "task_id": task_id})
142
+ raise
143
+ except Exception:
144
+ session.status = TaskStatus.FAILED
145
+ logger.exception("Task failed", extra={"mission_id": mission_id, "task_id": task_id})
146
+ raise
147
+ finally:
148
+ session.completed_at = datetime.datetime.now(datetime.timezone.utc)
149
+ # Ensure all tasks are cleaned up
150
+ tasks_to_cleanup = [t for t in [main_task, hb_task, sig_task] if t is not None]
151
+ for t in tasks_to_cleanup:
152
+ if not t.done():
153
+ t.cancel()
154
+ if tasks_to_cleanup:
155
+ await asyncio.gather(*tasks_to_cleanup, return_exceptions=True)
156
+
157
+ logger.info(
158
+ "Task execution completed with status: %s",
159
+ session.status,
160
+ extra={
161
+ "mission_id": mission_id,
162
+ "task_id": task_id,
163
+ "status": session.status,
164
+ "duration": (
165
+ (session.completed_at - session.started_at).total_seconds()
166
+ if session.started_at and session.completed_at
167
+ else None
168
+ ),
169
+ },
170
+ )
171
+
172
+ # Return the supervisor task to be awaited by caller
173
+ return asyncio.create_task(supervisor(), name=f"{task_id}_supervisor")
@@ -42,23 +42,14 @@ class TaskSession:
42
42
  db: SurrealDBConnection,
43
43
  module: BaseModule,
44
44
  heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
45
- queue_maxsize: int = 1000,
46
45
  ) -> None:
47
- """Initialize Task Session.
48
-
49
- Args:
50
- task_id: Unique task identifier
51
- mission_id: Mission identifier
52
- db: SurrealDB connection
53
- module: Module instance
54
- heartbeat_interval: Interval between heartbeats
55
- queue_maxsize: Maximum size for the queue (0 = unlimited)
56
- """
46
+ """Initialize Task Session."""
57
47
  self.db = db
58
48
  self.module = module
59
49
 
60
50
  self.status = TaskStatus.PENDING
61
- self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
51
+ # Bounded queue to prevent unbounded memory growth (max 1000 items)
52
+ self.queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
62
53
 
63
54
  self.task_id = task_id
64
55
  self.mission_id = mission_id
@@ -326,3 +317,34 @@ class TaskSession:
326
317
  self.task_id,
327
318
  extra={"task_id": self.task_id},
328
319
  )
320
+
321
+ async def cleanup(self) -> None:
322
+ """Clean up task session resources.
323
+
324
+ This includes:
325
+ - Clearing queue to free memory
326
+ - Stopping module
327
+ - Closing database connection
328
+ - Clearing module reference
329
+ """
330
+ # Clear queue to free memory
331
+ try:
332
+ while not self.queue.empty():
333
+ self.queue.get_nowait()
334
+ except asyncio.QueueEmpty:
335
+ pass
336
+
337
+ # Stop module
338
+ try:
339
+ await self.module.stop()
340
+ except Exception:
341
+ logger.exception(
342
+ "Error stopping module during cleanup",
343
+ extra={"mission_id": self.mission_id, "task_id": self.task_id},
344
+ )
345
+
346
+ # Close DB connection (kills all live queries)
347
+ await self.db.close()
348
+
349
+ # Clear module reference to allow garbage collection
350
+ self.module = None # type: ignore
@@ -4,11 +4,11 @@ import uuid
4
4
  from pathlib import Path
5
5
 
6
6
  import grpc
7
- from digitalkin_proto.digitalkin.module.v2 import (
7
+ from digitalkin_proto.agentic_mesh_protocol.module.v1 import (
8
8
  module_service_pb2,
9
9
  module_service_pb2_grpc,
10
10
  )
11
- from digitalkin_proto.digitalkin.module_registry.v2 import (
11
+ from digitalkin_proto.agentic_mesh_protocol.module_registry.v1 import (
12
12
  metadata_pb2,
13
13
  module_registry_service_pb2_grpc,
14
14
  registration_pb2,
@@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator
5
5
  from typing import Any
6
6
 
7
7
  import grpc
8
- from digitalkin_proto.digitalkin.module.v2 import (
8
+ from digitalkin_proto.agentic_mesh_protocol.module.v1 import (
9
9
  information_pb2,
10
10
  lifecycle_pb2,
11
11
  module_service_pb2_grpc,
@@ -172,7 +172,8 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
172
172
  )
173
173
  # Process the module input
174
174
  # TODO: Check failure of input data format
175
- input_data = self.module_class.create_input_model(dict(request.input.items()))
175
+ input_data = self.module_class.create_input_model(json_format.MessageToDict(request.input))
176
+
176
177
  setup_data_class = self.setup.get_setup(
177
178
  setup_dict={
178
179
  "setup_id": request.setup_id,
@@ -225,7 +226,7 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
225
226
  proto = json_format.ParseDict(message, struct_pb2.Struct(), ignore_unknown_fields=True)
226
227
  yield lifecycle_pb2.StartModuleResponse(success=True, output=proto, job_id=job_id)
227
228
  finally:
228
- await self.job_manager.tasks[job_id]
229
+ await self.job_manager.wait_for_completion(job_id)
229
230
  await self.job_manager.clean_session(job_id, mission_id=request.mission_id)
230
231
 
231
232
  logger.info("Job %s finished", job_id)
@@ -1,6 +1,6 @@
1
1
  """Registry gRPC server implementation for DigitalKin."""
2
2
 
3
- from digitalkin_proto.digitalkin.module_registry.v2 import (
3
+ from digitalkin_proto.agentic_mesh_protocol.module_registry.v1 import (
4
4
  module_registry_service_pb2,
5
5
  module_registry_service_pb2_grpc,
6
6
  )
@@ -9,7 +9,7 @@ from collections.abc import Iterator
9
9
  from enum import Enum
10
10
 
11
11
  import grpc
12
- from digitalkin_proto.digitalkin.module_registry.v2 import (
12
+ from digitalkin_proto.agentic_mesh_protocol.module_registry.v1 import (
13
13
  discover_pb2,
14
14
  metadata_pb2,
15
15
  module_registry_service_pb2_grpc,
@@ -344,7 +344,7 @@ class RegistryServicer(module_registry_service_pb2_grpc.ModuleRegistryServiceSer
344
344
  return status_pb2.ModuleStatusResponse()
345
345
 
346
346
  module = self.registered_modules[request.module_id]
347
- return status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.name)
347
+ return status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.value)
348
348
 
349
349
  def ListModuleStatus( # noqa: N802
350
350
  self,
@@ -379,7 +379,7 @@ class RegistryServicer(module_registry_service_pb2_grpc.ModuleRegistryServiceSer
379
379
  list_size = len(self.registered_modules)
380
380
 
381
381
  modules_statuses = [
382
- status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.name)
382
+ status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.value)
383
383
  for module in list(self.registered_modules.values())[request.offset : request.offset + list_size]
384
384
  ]
385
385
 
@@ -409,7 +409,7 @@ class RegistryServicer(module_registry_service_pb2_grpc.ModuleRegistryServiceSer
409
409
  for module in self.registered_modules.values():
410
410
  yield status_pb2.ModuleStatusResponse(
411
411
  module_id=module.module_id,
412
- status=module.status.name,
412
+ status=module.status.value,
413
413
  )
414
414
 
415
415
  def UpdateModuleStatus( # noqa: N802
@@ -0,0 +1,53 @@
1
+ """Shared error handling utilities for gRPC services."""
2
+
3
+ from collections.abc import Generator
4
+ from contextlib import contextmanager
5
+ from typing import Any
6
+
7
+ from digitalkin.grpc_servers.utils.exceptions import ServerError
8
+ from digitalkin.logger import logger
9
+
10
+
11
+ class GrpcErrorHandlerMixin:
12
+ """Mixin class providing common gRPC error handling functionality."""
13
+
14
+ @contextmanager
15
+ def handle_grpc_errors( # noqa: PLR6301
16
+ self,
17
+ operation: str,
18
+ service_error_class: type[Exception] | None = None,
19
+ ) -> Generator[Any, Any, Any]:
20
+ """Handle gRPC errors for the given operation.
21
+
22
+ Args:
23
+ operation: Name of the operation being performed.
24
+ service_error_class: Optional specific service exception class to raise.
25
+ If not provided, uses the generic ServerError.
26
+
27
+ Yields:
28
+ Context for the operation.
29
+
30
+ Raises:
31
+ ServerError: For gRPC-related errors.
32
+ service_error_class: For service-specific errors if provided.
33
+ """
34
+ if service_error_class is None:
35
+ service_error_class = ServerError
36
+
37
+ try:
38
+ yield
39
+ except service_error_class as e:
40
+ # Re-raise service-specific errors as-is
41
+ msg = f"{service_error_class.__name__} in {operation}: {e}"
42
+ logger.exception(msg)
43
+ raise service_error_class(msg) from e
44
+ except ServerError as e:
45
+ # Handle gRPC server errors
46
+ msg = f"gRPC {operation} failed: {e}"
47
+ logger.exception(msg)
48
+ raise ServerError(msg) from e
49
+ except Exception as e:
50
+ # Handle unexpected errors
51
+ msg = f"Unexpected error in {operation}: {e}"
52
+ logger.exception(msg)
53
+ raise service_error_class(msg) from e
@@ -175,8 +175,8 @@ class ClientConfig(ChannelConfig):
175
175
  credentials: ClientCredentials | None = Field(None, description="Client credentials for secure mode")
176
176
  channel_options: list[tuple[str, Any]] = Field(
177
177
  default_factory=lambda: [
178
- ("grpc.max_receive_message_length", 100 * 1024 * 1024), # 100MB
179
- ("grpc.max_send_message_length", 100 * 1024 * 1024), # 100MB
178
+ ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
179
+ ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
180
180
  ],
181
181
  description="Additional channel options",
182
182
  )
@@ -223,8 +223,8 @@ class ServerConfig(ChannelConfig):
223
223
  credentials: ServerCredentials | None = Field(None, description="Server credentials for secure mode")
224
224
  server_options: list[tuple[str, Any]] = Field(
225
225
  default_factory=lambda: [
226
- ("grpc.max_receive_message_length", 100 * 1024 * 1024), # 100MB
227
- ("grpc.max_send_message_length", 100 * 1024 * 1024), # 100MB
226
+ ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
227
+ ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
228
228
  ],
229
229
  description="Additional server options",
230
230
  )
@@ -1,14 +1,12 @@
1
1
  """This module implements the gRPC Cost strategy."""
2
2
 
3
- from collections.abc import Generator
4
- from contextlib import contextmanager
5
- from typing import Any, Literal
3
+ from typing import Literal
6
4
 
7
- from digitalkin_proto.digitalkin.cost.v1 import cost_pb2, cost_service_pb2_grpc
5
+ from digitalkin_proto.agentic_mesh_protocol.cost.v1 import cost_pb2, cost_service_pb2_grpc
8
6
  from google.protobuf import json_format
9
7
 
10
- from digitalkin.grpc_servers.utils.exceptions import ServerError
11
8
  from digitalkin.grpc_servers.utils.grpc_client_wrapper import GrpcClientWrapper
9
+ from digitalkin.grpc_servers.utils.grpc_error_handler import GrpcErrorHandlerMixin
12
10
  from digitalkin.logger import logger
13
11
  from digitalkin.models.grpc_servers.models import ClientConfig
14
12
  from digitalkin.services.cost.cost_strategy import (
@@ -20,40 +18,9 @@ from digitalkin.services.cost.cost_strategy import (
20
18
  )
21
19
 
22
20
 
23
- class GrpcCost(CostStrategy, GrpcClientWrapper):
21
+ class GrpcCost(CostStrategy, GrpcClientWrapper, GrpcErrorHandlerMixin):
24
22
  """This class implements the default Cost strategy."""
25
23
 
26
- @staticmethod
27
- @contextmanager
28
- def _handle_grpc_errors(operation: str) -> Generator[Any, Any, Any]:
29
- """Context manager for consistent gRPC error handling.
30
-
31
- Yields:
32
- Allow error handling in context.
33
-
34
- Args:
35
- operation: Description of the operation being performed.
36
-
37
- Raises:
38
- ValueError: Error with the model validation.
39
- ServerError: from gRPC Client.
40
- CostServiceError: Unexpected error.
41
- """
42
- try:
43
- yield
44
- except CostServiceError as e:
45
- msg = f"CostServiceError in {operation}: {e}"
46
- logger.exception(msg)
47
- raise CostServiceError(msg) from e
48
- except ServerError as e:
49
- msg = f"gRPC {operation} failed: {e}"
50
- logger.exception(msg)
51
- raise ServerError(msg) from e
52
- except Exception as e:
53
- msg = f"Unexpected error in {operation}"
54
- logger.exception(msg)
55
- raise CostServiceError(msg) from e
56
-
57
24
  def __init__(
58
25
  self,
59
26
  mission_id: str,
@@ -66,7 +33,7 @@ class GrpcCost(CostStrategy, GrpcClientWrapper):
66
33
  super().__init__(mission_id=mission_id, setup_id=setup_id, setup_version_id=setup_version_id, config=config)
67
34
  channel = self._init_channel(client_config)
68
35
  self.stub = cost_service_pb2_grpc.CostServiceStub(channel)
69
- logger.debug("Channel client 'Cost' initialized succesfully")
36
+ logger.debug("Channel client 'Cost' initialized successfully")
70
37
 
71
38
  def add(
72
39
  self,
@@ -84,7 +51,7 @@ class GrpcCost(CostStrategy, GrpcClientWrapper):
84
51
  Raises:
85
52
  CostServiceError: If the cost config is invalid
86
53
  """
87
- with self._handle_grpc_errors("AddCost"):
54
+ with self.handle_grpc_errors("AddCost", CostServiceError):
88
55
  cost_config = self.config.get(cost_config_name)
89
56
  if cost_config is None:
90
57
  msg = f"Cost config {cost_config_name} not found in the configuration."
@@ -122,7 +89,7 @@ class GrpcCost(CostStrategy, GrpcClientWrapper):
122
89
  Returns:
123
90
  CostData: The cost data
124
91
  """
125
- with self._handle_grpc_errors("GetCost"):
92
+ with self.handle_grpc_errors("GetCost", CostServiceError):
126
93
  request = cost_pb2.GetCostRequest(name=name, mission_id=self.mission_id)
127
94
  response: cost_pb2.GetCostResponse = self.exec_grpc_query("GetCost", request)
128
95
  cost_data_list = [
@@ -150,7 +117,7 @@ class GrpcCost(CostStrategy, GrpcClientWrapper):
150
117
  Returns:
151
118
  list[CostData]: The cost data
152
119
  """
153
- with self._handle_grpc_errors("GetCosts"):
120
+ with self.handle_grpc_errors("GetCosts", CostServiceError):
154
121
  request = cost_pb2.GetCostsRequest(
155
122
  mission_id=self.mission_id,
156
123
  filter=cost_pb2.CostFilter(