digitalkin 0.3.0rc2__py3-none-any.whl → 0.3.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/core/common/__init__.py +9 -0
- digitalkin/core/common/factories.py +156 -0
- digitalkin/core/job_manager/base_job_manager.py +128 -28
- digitalkin/core/job_manager/single_job_manager.py +80 -25
- digitalkin/core/job_manager/taskiq_broker.py +114 -19
- digitalkin/core/job_manager/taskiq_job_manager.py +292 -39
- digitalkin/core/task_manager/base_task_manager.py +464 -0
- digitalkin/core/task_manager/local_task_manager.py +108 -0
- digitalkin/core/task_manager/remote_task_manager.py +87 -0
- digitalkin/core/task_manager/surrealdb_repository.py +43 -4
- digitalkin/core/task_manager/task_executor.py +173 -0
- digitalkin/core/task_manager/task_session.py +34 -12
- digitalkin/grpc_servers/module_server.py +2 -2
- digitalkin/grpc_servers/module_servicer.py +4 -3
- digitalkin/grpc_servers/registry_server.py +1 -1
- digitalkin/grpc_servers/registry_servicer.py +4 -4
- digitalkin/grpc_servers/utils/grpc_error_handler.py +53 -0
- digitalkin/models/grpc_servers/models.py +4 -4
- digitalkin/services/cost/grpc_cost.py +8 -41
- digitalkin/services/filesystem/grpc_filesystem.py +9 -38
- digitalkin/services/setup/default_setup.py +5 -6
- digitalkin/services/setup/grpc_setup.py +51 -14
- digitalkin/services/storage/grpc_storage.py +2 -2
- digitalkin/services/user_profile/__init__.py +1 -0
- digitalkin/services/user_profile/default_user_profile.py +55 -0
- digitalkin/services/user_profile/grpc_user_profile.py +69 -0
- digitalkin/services/user_profile/user_profile_strategy.py +40 -0
- {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/METADATA +7 -7
- {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/RECORD +33 -23
- digitalkin/core/task_manager/task_manager.py +0 -442
- {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/WHEEL +0 -0
- {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.3.0rc2.dist-info → digitalkin-0.3.1.dev0.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""SurrealDB connection management."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import datetime
|
|
4
5
|
import os
|
|
5
6
|
from collections.abc import AsyncGenerator
|
|
@@ -38,6 +39,7 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
38
39
|
|
|
39
40
|
db: TSurreal
|
|
40
41
|
timeout: datetime.timedelta
|
|
42
|
+
_live_queries: set[UUID] # Track active live queries for cleanup
|
|
41
43
|
|
|
42
44
|
@staticmethod
|
|
43
45
|
def _valid_id(raw_id: str, table_name: str) -> RecordID:
|
|
@@ -82,6 +84,7 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
82
84
|
self.password = os.getenv("SURREALDB_PASSWORD", "root")
|
|
83
85
|
self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
|
|
84
86
|
self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")
|
|
87
|
+
self._live_queries = set() # Initialize live queries tracker
|
|
85
88
|
|
|
86
89
|
async def init_surreal_instance(self) -> None:
|
|
87
90
|
"""Init a SurrealDB connection instance."""
|
|
@@ -92,7 +95,37 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
92
95
|
logger.debug("Successfully connected to SurrealDB")
|
|
93
96
|
|
|
94
97
|
async def close(self) -> None:
|
|
95
|
-
"""Close the SurrealDB connection if it exists.
|
|
98
|
+
"""Close the SurrealDB connection if it exists.
|
|
99
|
+
|
|
100
|
+
This will also kill all active live queries to prevent memory leaks.
|
|
101
|
+
"""
|
|
102
|
+
# Kill all tracked live queries before closing connection
|
|
103
|
+
if self._live_queries:
|
|
104
|
+
logger.debug("Killing %d active live queries before closing", len(self._live_queries))
|
|
105
|
+
live_query_ids = list(self._live_queries)
|
|
106
|
+
|
|
107
|
+
# Kill all queries concurrently, capturing any exceptions
|
|
108
|
+
results = await asyncio.gather(
|
|
109
|
+
*[self.db.kill(live_id) for live_id in live_query_ids], return_exceptions=True
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
# Process results and track failures
|
|
113
|
+
failed_queries = []
|
|
114
|
+
for live_id, result in zip(live_query_ids, results):
|
|
115
|
+
if isinstance(result, (ConnectionError, TimeoutError, Exception)):
|
|
116
|
+
failed_queries.append((live_id, str(result)))
|
|
117
|
+
else:
|
|
118
|
+
self._live_queries.discard(live_id)
|
|
119
|
+
|
|
120
|
+
# Log aggregated failures once instead of per-query
|
|
121
|
+
if failed_queries:
|
|
122
|
+
logger.warning(
|
|
123
|
+
"Failed to kill %d live queries: %s",
|
|
124
|
+
len(failed_queries),
|
|
125
|
+
failed_queries[:5], # Only log first 5 to avoid log spam
|
|
126
|
+
extra={"total_failed": len(failed_queries)},
|
|
127
|
+
)
|
|
128
|
+
|
|
96
129
|
logger.debug("Closing SurrealDB connection")
|
|
97
130
|
await self.db.close()
|
|
98
131
|
|
|
@@ -208,20 +241,26 @@ class SurrealDBConnection(Generic[TSurreal]):
|
|
|
208
241
|
) -> tuple[UUID, AsyncGenerator[dict[str, Any], None]]:
|
|
209
242
|
"""Create and subscribe to a live SurrealQL query.
|
|
210
243
|
|
|
244
|
+
The live query ID is tracked to ensure proper cleanup on connection close.
|
|
245
|
+
|
|
211
246
|
Args:
|
|
212
247
|
table_name: Name of the table to insert into
|
|
213
248
|
|
|
214
249
|
Returns:
|
|
215
|
-
|
|
250
|
+
tuple[UUID, AsyncGenerator]: Live query ID and subscription generator
|
|
216
251
|
"""
|
|
217
252
|
live_id = await self.db.live(table_name, diff=False)
|
|
253
|
+
self._live_queries.add(live_id) # Track for cleanup
|
|
254
|
+
logger.debug("Started live query %s for table %s (total: %d)", live_id, table_name, len(self._live_queries))
|
|
218
255
|
return live_id, await self.db.subscribe_live(live_id)
|
|
219
256
|
|
|
220
257
|
async def stop_live(self, live_id: UUID) -> None:
|
|
221
258
|
"""Kill a live SurrealQL query.
|
|
222
259
|
|
|
223
260
|
Args:
|
|
224
|
-
live_id:
|
|
261
|
+
live_id: Live query ID to kill
|
|
225
262
|
"""
|
|
226
|
-
logger.debug("
|
|
263
|
+
logger.debug("Killing live query: %s", live_id)
|
|
227
264
|
await self.db.kill(live_id)
|
|
265
|
+
self._live_queries.discard(live_id) # Remove from tracker
|
|
266
|
+
logger.debug("Stopped live query %s (remaining: %d)", live_id, len(self._live_queries))
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""Task executor for running tasks with full lifecycle management."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import datetime
|
|
5
|
+
from collections.abc import Coroutine
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from digitalkin.core.task_manager.surrealdb_repository import SurrealDBConnection
|
|
9
|
+
from digitalkin.core.task_manager.task_session import TaskSession
|
|
10
|
+
from digitalkin.logger import logger
|
|
11
|
+
from digitalkin.models.core.task_monitor import SignalMessage, SignalType, TaskStatus
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TaskExecutor:
|
|
15
|
+
"""Executes tasks with the supervisor pattern (main + heartbeat + signal listener).
|
|
16
|
+
|
|
17
|
+
Pure execution logic - no task registry or orchestration.
|
|
18
|
+
Used by workers to run distributed tasks or by TaskManager for local execution.
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
@staticmethod
|
|
22
|
+
async def execute_task( # noqa: C901, PLR0915
|
|
23
|
+
task_id: str,
|
|
24
|
+
mission_id: str,
|
|
25
|
+
coro: Coroutine[Any, Any, None],
|
|
26
|
+
session: TaskSession,
|
|
27
|
+
channel: SurrealDBConnection,
|
|
28
|
+
) -> asyncio.Task[None]:
|
|
29
|
+
"""Execute a task using the supervisor pattern.
|
|
30
|
+
|
|
31
|
+
Runs three concurrent sub-tasks:
|
|
32
|
+
- Main coroutine (the actual work)
|
|
33
|
+
- Heartbeat generator (sends heartbeats to SurrealDB)
|
|
34
|
+
- Signal listener (watches for stop/pause/resume signals)
|
|
35
|
+
|
|
36
|
+
The first task to complete determines the outcome.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
task_id: Unique identifier for the task
|
|
40
|
+
mission_id: Mission identifier for the task
|
|
41
|
+
coro: The coroutine to execute (module.start(...))
|
|
42
|
+
session: TaskSession for state management
|
|
43
|
+
channel: SurrealDB connection for signals
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
asyncio.Task: The supervisor task managing the lifecycle
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
async def signal_wrapper() -> None:
|
|
50
|
+
"""Create initial signal record and listen for signals."""
|
|
51
|
+
try:
|
|
52
|
+
await channel.create(
|
|
53
|
+
"tasks",
|
|
54
|
+
SignalMessage(
|
|
55
|
+
task_id=task_id,
|
|
56
|
+
mission_id=mission_id,
|
|
57
|
+
status=session.status,
|
|
58
|
+
action=SignalType.START,
|
|
59
|
+
).model_dump(),
|
|
60
|
+
)
|
|
61
|
+
await session.listen_signals()
|
|
62
|
+
except asyncio.CancelledError:
|
|
63
|
+
logger.debug("Signal listener cancelled", extra={"mission_id": mission_id, "task_id": task_id})
|
|
64
|
+
finally:
|
|
65
|
+
await channel.create(
|
|
66
|
+
"tasks",
|
|
67
|
+
SignalMessage(
|
|
68
|
+
task_id=task_id,
|
|
69
|
+
mission_id=mission_id,
|
|
70
|
+
status=session.status,
|
|
71
|
+
action=SignalType.STOP,
|
|
72
|
+
).model_dump(),
|
|
73
|
+
)
|
|
74
|
+
logger.info("Signal listener ended", extra={"mission_id": mission_id, "task_id": task_id})
|
|
75
|
+
|
|
76
|
+
async def heartbeat_wrapper() -> None:
|
|
77
|
+
"""Generate heartbeats for task health monitoring."""
|
|
78
|
+
try:
|
|
79
|
+
await session.generate_heartbeats()
|
|
80
|
+
except asyncio.CancelledError:
|
|
81
|
+
logger.debug("Heartbeat cancelled", extra={"mission_id": mission_id, "task_id": task_id})
|
|
82
|
+
finally:
|
|
83
|
+
logger.info("Heartbeat task ended", extra={"mission_id": mission_id, "task_id": task_id})
|
|
84
|
+
|
|
85
|
+
async def supervisor() -> None:
|
|
86
|
+
"""Supervise the three concurrent tasks and handle outcomes.
|
|
87
|
+
|
|
88
|
+
Raises:
|
|
89
|
+
RuntimeError: If the heartbeat task stops unexpectedly.
|
|
90
|
+
asyncio.CancelledError: If the supervisor task is cancelled.
|
|
91
|
+
"""
|
|
92
|
+
session.started_at = datetime.datetime.now(datetime.timezone.utc)
|
|
93
|
+
session.status = TaskStatus.RUNNING
|
|
94
|
+
|
|
95
|
+
# Create tasks with proper exception handling
|
|
96
|
+
main_task = None
|
|
97
|
+
hb_task = None
|
|
98
|
+
sig_task = None
|
|
99
|
+
|
|
100
|
+
try:
|
|
101
|
+
main_task = asyncio.create_task(coro, name=f"{task_id}_main")
|
|
102
|
+
hb_task = asyncio.create_task(heartbeat_wrapper(), name=f"{task_id}_heartbeat")
|
|
103
|
+
sig_task = asyncio.create_task(signal_wrapper(), name=f"{task_id}_listener")
|
|
104
|
+
done, pending = await asyncio.wait(
|
|
105
|
+
[main_task, sig_task, hb_task],
|
|
106
|
+
return_when=asyncio.FIRST_COMPLETED,
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
# One task completed -> cancel the others
|
|
110
|
+
for t in pending:
|
|
111
|
+
t.cancel()
|
|
112
|
+
|
|
113
|
+
# Propagate exception/result from the finished task
|
|
114
|
+
completed = next(iter(done))
|
|
115
|
+
await completed
|
|
116
|
+
|
|
117
|
+
# Determine final status based on which task completed
|
|
118
|
+
if completed is main_task:
|
|
119
|
+
session.status = TaskStatus.COMPLETED
|
|
120
|
+
logger.info(
|
|
121
|
+
"Main task completed successfully",
|
|
122
|
+
extra={"mission_id": mission_id, "task_id": task_id},
|
|
123
|
+
)
|
|
124
|
+
elif completed is sig_task or (completed is hb_task and sig_task.done()):
|
|
125
|
+
logger.debug(
|
|
126
|
+
f"Task cancelled due to signal {sig_task=}",
|
|
127
|
+
extra={"mission_id": mission_id, "task_id": task_id},
|
|
128
|
+
)
|
|
129
|
+
session.status = TaskStatus.CANCELLED
|
|
130
|
+
elif completed is hb_task:
|
|
131
|
+
session.status = TaskStatus.FAILED
|
|
132
|
+
logger.error(
|
|
133
|
+
f"Heartbeat stopped for {task_id}",
|
|
134
|
+
extra={"mission_id": mission_id, "task_id": task_id},
|
|
135
|
+
)
|
|
136
|
+
msg = f"Heartbeat stopped for {task_id}"
|
|
137
|
+
raise RuntimeError(msg) # noqa: TRY301
|
|
138
|
+
|
|
139
|
+
except asyncio.CancelledError:
|
|
140
|
+
session.status = TaskStatus.CANCELLED
|
|
141
|
+
logger.info("Task cancelled", extra={"mission_id": mission_id, "task_id": task_id})
|
|
142
|
+
raise
|
|
143
|
+
except Exception:
|
|
144
|
+
session.status = TaskStatus.FAILED
|
|
145
|
+
logger.exception("Task failed", extra={"mission_id": mission_id, "task_id": task_id})
|
|
146
|
+
raise
|
|
147
|
+
finally:
|
|
148
|
+
session.completed_at = datetime.datetime.now(datetime.timezone.utc)
|
|
149
|
+
# Ensure all tasks are cleaned up
|
|
150
|
+
tasks_to_cleanup = [t for t in [main_task, hb_task, sig_task] if t is not None]
|
|
151
|
+
for t in tasks_to_cleanup:
|
|
152
|
+
if not t.done():
|
|
153
|
+
t.cancel()
|
|
154
|
+
if tasks_to_cleanup:
|
|
155
|
+
await asyncio.gather(*tasks_to_cleanup, return_exceptions=True)
|
|
156
|
+
|
|
157
|
+
logger.info(
|
|
158
|
+
"Task execution completed with status: %s",
|
|
159
|
+
session.status,
|
|
160
|
+
extra={
|
|
161
|
+
"mission_id": mission_id,
|
|
162
|
+
"task_id": task_id,
|
|
163
|
+
"status": session.status,
|
|
164
|
+
"duration": (
|
|
165
|
+
(session.completed_at - session.started_at).total_seconds()
|
|
166
|
+
if session.started_at and session.completed_at
|
|
167
|
+
else None
|
|
168
|
+
),
|
|
169
|
+
},
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
# Return the supervisor task to be awaited by caller
|
|
173
|
+
return asyncio.create_task(supervisor(), name=f"{task_id}_supervisor")
|
|
@@ -42,23 +42,14 @@ class TaskSession:
|
|
|
42
42
|
db: SurrealDBConnection,
|
|
43
43
|
module: BaseModule,
|
|
44
44
|
heartbeat_interval: datetime.timedelta = datetime.timedelta(seconds=2),
|
|
45
|
-
queue_maxsize: int = 1000,
|
|
46
45
|
) -> None:
|
|
47
|
-
"""Initialize Task Session.
|
|
48
|
-
|
|
49
|
-
Args:
|
|
50
|
-
task_id: Unique task identifier
|
|
51
|
-
mission_id: Mission identifier
|
|
52
|
-
db: SurrealDB connection
|
|
53
|
-
module: Module instance
|
|
54
|
-
heartbeat_interval: Interval between heartbeats
|
|
55
|
-
queue_maxsize: Maximum size for the queue (0 = unlimited)
|
|
56
|
-
"""
|
|
46
|
+
"""Initialize Task Session."""
|
|
57
47
|
self.db = db
|
|
58
48
|
self.module = module
|
|
59
49
|
|
|
60
50
|
self.status = TaskStatus.PENDING
|
|
61
|
-
|
|
51
|
+
# Bounded queue to prevent unbounded memory growth (max 1000 items)
|
|
52
|
+
self.queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
|
|
62
53
|
|
|
63
54
|
self.task_id = task_id
|
|
64
55
|
self.mission_id = mission_id
|
|
@@ -326,3 +317,34 @@ class TaskSession:
|
|
|
326
317
|
self.task_id,
|
|
327
318
|
extra={"task_id": self.task_id},
|
|
328
319
|
)
|
|
320
|
+
|
|
321
|
+
async def cleanup(self) -> None:
|
|
322
|
+
"""Clean up task session resources.
|
|
323
|
+
|
|
324
|
+
This includes:
|
|
325
|
+
- Clearing queue to free memory
|
|
326
|
+
- Stopping module
|
|
327
|
+
- Closing database connection
|
|
328
|
+
- Clearing module reference
|
|
329
|
+
"""
|
|
330
|
+
# Clear queue to free memory
|
|
331
|
+
try:
|
|
332
|
+
while not self.queue.empty():
|
|
333
|
+
self.queue.get_nowait()
|
|
334
|
+
except asyncio.QueueEmpty:
|
|
335
|
+
pass
|
|
336
|
+
|
|
337
|
+
# Stop module
|
|
338
|
+
try:
|
|
339
|
+
await self.module.stop()
|
|
340
|
+
except Exception:
|
|
341
|
+
logger.exception(
|
|
342
|
+
"Error stopping module during cleanup",
|
|
343
|
+
extra={"mission_id": self.mission_id, "task_id": self.task_id},
|
|
344
|
+
)
|
|
345
|
+
|
|
346
|
+
# Close DB connection (kills all live queries)
|
|
347
|
+
await self.db.close()
|
|
348
|
+
|
|
349
|
+
# Clear module reference to allow garbage collection
|
|
350
|
+
self.module = None # type: ignore
|
|
@@ -4,11 +4,11 @@ import uuid
|
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
|
|
6
6
|
import grpc
|
|
7
|
-
from digitalkin_proto.
|
|
7
|
+
from digitalkin_proto.agentic_mesh_protocol.module.v1 import (
|
|
8
8
|
module_service_pb2,
|
|
9
9
|
module_service_pb2_grpc,
|
|
10
10
|
)
|
|
11
|
-
from digitalkin_proto.
|
|
11
|
+
from digitalkin_proto.agentic_mesh_protocol.module_registry.v1 import (
|
|
12
12
|
metadata_pb2,
|
|
13
13
|
module_registry_service_pb2_grpc,
|
|
14
14
|
registration_pb2,
|
|
@@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator
|
|
|
5
5
|
from typing import Any
|
|
6
6
|
|
|
7
7
|
import grpc
|
|
8
|
-
from digitalkin_proto.
|
|
8
|
+
from digitalkin_proto.agentic_mesh_protocol.module.v1 import (
|
|
9
9
|
information_pb2,
|
|
10
10
|
lifecycle_pb2,
|
|
11
11
|
module_service_pb2_grpc,
|
|
@@ -172,7 +172,8 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
172
172
|
)
|
|
173
173
|
# Process the module input
|
|
174
174
|
# TODO: Check failure of input data format
|
|
175
|
-
input_data = self.module_class.create_input_model(
|
|
175
|
+
input_data = self.module_class.create_input_model(json_format.MessageToDict(request.input))
|
|
176
|
+
|
|
176
177
|
setup_data_class = self.setup.get_setup(
|
|
177
178
|
setup_dict={
|
|
178
179
|
"setup_id": request.setup_id,
|
|
@@ -225,7 +226,7 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
|
|
|
225
226
|
proto = json_format.ParseDict(message, struct_pb2.Struct(), ignore_unknown_fields=True)
|
|
226
227
|
yield lifecycle_pb2.StartModuleResponse(success=True, output=proto, job_id=job_id)
|
|
227
228
|
finally:
|
|
228
|
-
await self.job_manager.
|
|
229
|
+
await self.job_manager.wait_for_completion(job_id)
|
|
229
230
|
await self.job_manager.clean_session(job_id, mission_id=request.mission_id)
|
|
230
231
|
|
|
231
232
|
logger.info("Job %s finished", job_id)
|
|
@@ -9,7 +9,7 @@ from collections.abc import Iterator
|
|
|
9
9
|
from enum import Enum
|
|
10
10
|
|
|
11
11
|
import grpc
|
|
12
|
-
from digitalkin_proto.
|
|
12
|
+
from digitalkin_proto.agentic_mesh_protocol.module_registry.v1 import (
|
|
13
13
|
discover_pb2,
|
|
14
14
|
metadata_pb2,
|
|
15
15
|
module_registry_service_pb2_grpc,
|
|
@@ -344,7 +344,7 @@ class RegistryServicer(module_registry_service_pb2_grpc.ModuleRegistryServiceSer
|
|
|
344
344
|
return status_pb2.ModuleStatusResponse()
|
|
345
345
|
|
|
346
346
|
module = self.registered_modules[request.module_id]
|
|
347
|
-
return status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.
|
|
347
|
+
return status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.value)
|
|
348
348
|
|
|
349
349
|
def ListModuleStatus( # noqa: N802
|
|
350
350
|
self,
|
|
@@ -379,7 +379,7 @@ class RegistryServicer(module_registry_service_pb2_grpc.ModuleRegistryServiceSer
|
|
|
379
379
|
list_size = len(self.registered_modules)
|
|
380
380
|
|
|
381
381
|
modules_statuses = [
|
|
382
|
-
status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.
|
|
382
|
+
status_pb2.ModuleStatusResponse(module_id=module.module_id, status=module.status.value)
|
|
383
383
|
for module in list(self.registered_modules.values())[request.offset : request.offset + list_size]
|
|
384
384
|
]
|
|
385
385
|
|
|
@@ -409,7 +409,7 @@ class RegistryServicer(module_registry_service_pb2_grpc.ModuleRegistryServiceSer
|
|
|
409
409
|
for module in self.registered_modules.values():
|
|
410
410
|
yield status_pb2.ModuleStatusResponse(
|
|
411
411
|
module_id=module.module_id,
|
|
412
|
-
status=module.status.
|
|
412
|
+
status=module.status.value,
|
|
413
413
|
)
|
|
414
414
|
|
|
415
415
|
def UpdateModuleStatus( # noqa: N802
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Shared error handling utilities for gRPC services."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Generator
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from digitalkin.grpc_servers.utils.exceptions import ServerError
|
|
8
|
+
from digitalkin.logger import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GrpcErrorHandlerMixin:
|
|
12
|
+
"""Mixin class providing common gRPC error handling functionality."""
|
|
13
|
+
|
|
14
|
+
@contextmanager
|
|
15
|
+
def handle_grpc_errors( # noqa: PLR6301
|
|
16
|
+
self,
|
|
17
|
+
operation: str,
|
|
18
|
+
service_error_class: type[Exception] | None = None,
|
|
19
|
+
) -> Generator[Any, Any, Any]:
|
|
20
|
+
"""Handle gRPC errors for the given operation.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
operation: Name of the operation being performed.
|
|
24
|
+
service_error_class: Optional specific service exception class to raise.
|
|
25
|
+
If not provided, uses the generic ServerError.
|
|
26
|
+
|
|
27
|
+
Yields:
|
|
28
|
+
Context for the operation.
|
|
29
|
+
|
|
30
|
+
Raises:
|
|
31
|
+
ServerError: For gRPC-related errors.
|
|
32
|
+
service_error_class: For service-specific errors if provided.
|
|
33
|
+
"""
|
|
34
|
+
if service_error_class is None:
|
|
35
|
+
service_error_class = ServerError
|
|
36
|
+
|
|
37
|
+
try:
|
|
38
|
+
yield
|
|
39
|
+
except service_error_class as e:
|
|
40
|
+
# Re-raise service-specific errors as-is
|
|
41
|
+
msg = f"{service_error_class.__name__} in {operation}: {e}"
|
|
42
|
+
logger.exception(msg)
|
|
43
|
+
raise service_error_class(msg) from e
|
|
44
|
+
except ServerError as e:
|
|
45
|
+
# Handle gRPC server errors
|
|
46
|
+
msg = f"gRPC {operation} failed: {e}"
|
|
47
|
+
logger.exception(msg)
|
|
48
|
+
raise ServerError(msg) from e
|
|
49
|
+
except Exception as e:
|
|
50
|
+
# Handle unexpected errors
|
|
51
|
+
msg = f"Unexpected error in {operation}: {e}"
|
|
52
|
+
logger.exception(msg)
|
|
53
|
+
raise service_error_class(msg) from e
|
|
@@ -175,8 +175,8 @@ class ClientConfig(ChannelConfig):
|
|
|
175
175
|
credentials: ClientCredentials | None = Field(None, description="Client credentials for secure mode")
|
|
176
176
|
channel_options: list[tuple[str, Any]] = Field(
|
|
177
177
|
default_factory=lambda: [
|
|
178
|
-
("grpc.max_receive_message_length",
|
|
179
|
-
("grpc.max_send_message_length",
|
|
178
|
+
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
|
179
|
+
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
|
180
180
|
],
|
|
181
181
|
description="Additional channel options",
|
|
182
182
|
)
|
|
@@ -223,8 +223,8 @@ class ServerConfig(ChannelConfig):
|
|
|
223
223
|
credentials: ServerCredentials | None = Field(None, description="Server credentials for secure mode")
|
|
224
224
|
server_options: list[tuple[str, Any]] = Field(
|
|
225
225
|
default_factory=lambda: [
|
|
226
|
-
("grpc.max_receive_message_length",
|
|
227
|
-
("grpc.max_send_message_length",
|
|
226
|
+
("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB
|
|
227
|
+
("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB
|
|
228
228
|
],
|
|
229
229
|
description="Additional server options",
|
|
230
230
|
)
|
|
@@ -1,14 +1,12 @@
|
|
|
1
1
|
"""This module implements the gRPC Cost strategy."""
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
from contextlib import contextmanager
|
|
5
|
-
from typing import Any, Literal
|
|
3
|
+
from typing import Literal
|
|
6
4
|
|
|
7
|
-
from digitalkin_proto.
|
|
5
|
+
from digitalkin_proto.agentic_mesh_protocol.cost.v1 import cost_pb2, cost_service_pb2_grpc
|
|
8
6
|
from google.protobuf import json_format
|
|
9
7
|
|
|
10
|
-
from digitalkin.grpc_servers.utils.exceptions import ServerError
|
|
11
8
|
from digitalkin.grpc_servers.utils.grpc_client_wrapper import GrpcClientWrapper
|
|
9
|
+
from digitalkin.grpc_servers.utils.grpc_error_handler import GrpcErrorHandlerMixin
|
|
12
10
|
from digitalkin.logger import logger
|
|
13
11
|
from digitalkin.models.grpc_servers.models import ClientConfig
|
|
14
12
|
from digitalkin.services.cost.cost_strategy import (
|
|
@@ -20,40 +18,9 @@ from digitalkin.services.cost.cost_strategy import (
|
|
|
20
18
|
)
|
|
21
19
|
|
|
22
20
|
|
|
23
|
-
class GrpcCost(CostStrategy, GrpcClientWrapper):
|
|
21
|
+
class GrpcCost(CostStrategy, GrpcClientWrapper, GrpcErrorHandlerMixin):
|
|
24
22
|
"""This class implements the default Cost strategy."""
|
|
25
23
|
|
|
26
|
-
@staticmethod
|
|
27
|
-
@contextmanager
|
|
28
|
-
def _handle_grpc_errors(operation: str) -> Generator[Any, Any, Any]:
|
|
29
|
-
"""Context manager for consistent gRPC error handling.
|
|
30
|
-
|
|
31
|
-
Yields:
|
|
32
|
-
Allow error handling in context.
|
|
33
|
-
|
|
34
|
-
Args:
|
|
35
|
-
operation: Description of the operation being performed.
|
|
36
|
-
|
|
37
|
-
Raises:
|
|
38
|
-
ValueError: Error with the model validation.
|
|
39
|
-
ServerError: from gRPC Client.
|
|
40
|
-
CostServiceError: Unexpected error.
|
|
41
|
-
"""
|
|
42
|
-
try:
|
|
43
|
-
yield
|
|
44
|
-
except CostServiceError as e:
|
|
45
|
-
msg = f"CostServiceError in {operation}: {e}"
|
|
46
|
-
logger.exception(msg)
|
|
47
|
-
raise CostServiceError(msg) from e
|
|
48
|
-
except ServerError as e:
|
|
49
|
-
msg = f"gRPC {operation} failed: {e}"
|
|
50
|
-
logger.exception(msg)
|
|
51
|
-
raise ServerError(msg) from e
|
|
52
|
-
except Exception as e:
|
|
53
|
-
msg = f"Unexpected error in {operation}"
|
|
54
|
-
logger.exception(msg)
|
|
55
|
-
raise CostServiceError(msg) from e
|
|
56
|
-
|
|
57
24
|
def __init__(
|
|
58
25
|
self,
|
|
59
26
|
mission_id: str,
|
|
@@ -66,7 +33,7 @@ class GrpcCost(CostStrategy, GrpcClientWrapper):
|
|
|
66
33
|
super().__init__(mission_id=mission_id, setup_id=setup_id, setup_version_id=setup_version_id, config=config)
|
|
67
34
|
channel = self._init_channel(client_config)
|
|
68
35
|
self.stub = cost_service_pb2_grpc.CostServiceStub(channel)
|
|
69
|
-
logger.debug("Channel client 'Cost' initialized
|
|
36
|
+
logger.debug("Channel client 'Cost' initialized successfully")
|
|
70
37
|
|
|
71
38
|
def add(
|
|
72
39
|
self,
|
|
@@ -84,7 +51,7 @@ class GrpcCost(CostStrategy, GrpcClientWrapper):
|
|
|
84
51
|
Raises:
|
|
85
52
|
CostServiceError: If the cost config is invalid
|
|
86
53
|
"""
|
|
87
|
-
with self.
|
|
54
|
+
with self.handle_grpc_errors("AddCost", CostServiceError):
|
|
88
55
|
cost_config = self.config.get(cost_config_name)
|
|
89
56
|
if cost_config is None:
|
|
90
57
|
msg = f"Cost config {cost_config_name} not found in the configuration."
|
|
@@ -122,7 +89,7 @@ class GrpcCost(CostStrategy, GrpcClientWrapper):
|
|
|
122
89
|
Returns:
|
|
123
90
|
CostData: The cost data
|
|
124
91
|
"""
|
|
125
|
-
with self.
|
|
92
|
+
with self.handle_grpc_errors("GetCost", CostServiceError):
|
|
126
93
|
request = cost_pb2.GetCostRequest(name=name, mission_id=self.mission_id)
|
|
127
94
|
response: cost_pb2.GetCostResponse = self.exec_grpc_query("GetCost", request)
|
|
128
95
|
cost_data_list = [
|
|
@@ -150,7 +117,7 @@ class GrpcCost(CostStrategy, GrpcClientWrapper):
|
|
|
150
117
|
Returns:
|
|
151
118
|
list[CostData]: The cost data
|
|
152
119
|
"""
|
|
153
|
-
with self.
|
|
120
|
+
with self.handle_grpc_errors("GetCosts", CostServiceError):
|
|
154
121
|
request = cost_pb2.GetCostsRequest(
|
|
155
122
|
mission_id=self.mission_id,
|
|
156
123
|
filter=cost_pb2.CostFilter(
|