digitalkin 0.2.25rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/grpc_servers/_base_server.py +1 -1
- digitalkin/grpc_servers/module_server.py +26 -42
- digitalkin/grpc_servers/module_servicer.py +30 -24
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +3 -3
- digitalkin/grpc_servers/utils/models.py +1 -1
- digitalkin/logger.py +60 -23
- digitalkin/mixins/__init__.py +19 -0
- digitalkin/mixins/base_mixin.py +10 -0
- digitalkin/mixins/callback_mixin.py +24 -0
- digitalkin/mixins/chat_history_mixin.py +108 -0
- digitalkin/mixins/cost_mixin.py +76 -0
- digitalkin/mixins/file_history_mixin.py +99 -0
- digitalkin/mixins/filesystem_mixin.py +47 -0
- digitalkin/mixins/logger_mixin.py +59 -0
- digitalkin/mixins/storage_mixin.py +79 -0
- digitalkin/models/module/__init__.py +2 -0
- digitalkin/models/module/module.py +9 -1
- digitalkin/models/module/module_context.py +90 -6
- digitalkin/models/module/module_types.py +5 -5
- digitalkin/models/module/task_monitor.py +51 -0
- digitalkin/models/services/__init__.py +9 -0
- digitalkin/models/services/storage.py +39 -5
- digitalkin/modules/_base_module.py +105 -74
- digitalkin/modules/job_manager/base_job_manager.py +12 -8
- digitalkin/modules/job_manager/single_job_manager.py +84 -78
- digitalkin/modules/job_manager/surrealdb_repository.py +225 -0
- digitalkin/modules/job_manager/task_manager.py +391 -0
- digitalkin/modules/job_manager/task_session.py +276 -0
- digitalkin/modules/job_manager/taskiq_job_manager.py +2 -2
- digitalkin/modules/tool_module.py +10 -2
- digitalkin/modules/trigger_handler.py +7 -6
- digitalkin/services/cost/__init__.py +9 -2
- digitalkin/services/storage/grpc_storage.py +1 -1
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/METADATA +18 -18
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/RECORD +39 -26
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/WHEEL +0 -0
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Background module manager with single instance."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
+
import datetime
|
|
4
5
|
import uuid
|
|
5
6
|
from collections.abc import AsyncGenerator, AsyncIterator
|
|
6
7
|
from contextlib import asynccontextmanager
|
|
@@ -9,15 +10,18 @@ from typing import Any, Generic
|
|
|
9
10
|
import grpc
|
|
10
11
|
|
|
11
12
|
from digitalkin.logger import logger
|
|
12
|
-
from digitalkin.models import ModuleStatus
|
|
13
13
|
from digitalkin.models.module import InputModelT, OutputModelT, SetupModelT
|
|
14
|
+
from digitalkin.models.module.module import ModuleCodeModel
|
|
15
|
+
from digitalkin.models.module.task_monitor import TaskStatus
|
|
14
16
|
from digitalkin.modules._base_module import BaseModule
|
|
15
17
|
from digitalkin.modules.job_manager.base_job_manager import BaseJobManager
|
|
16
|
-
from digitalkin.modules.job_manager.
|
|
18
|
+
from digitalkin.modules.job_manager.surrealdb_repository import SurrealDBConnection
|
|
19
|
+
from digitalkin.modules.job_manager.task_manager import TaskManager
|
|
20
|
+
from digitalkin.modules.job_manager.task_session import TaskSession
|
|
17
21
|
from digitalkin.services.services_models import ServicesMode
|
|
18
22
|
|
|
19
23
|
|
|
20
|
-
class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
24
|
+
class SingleJobManager(BaseJobManager, TaskManager, Generic[InputModelT, OutputModelT, SetupModelT]):
|
|
21
25
|
"""Manages a single instance of a module job.
|
|
22
26
|
|
|
23
27
|
This class ensures that only one instance of a module job is active at a time.
|
|
@@ -25,8 +29,10 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
25
29
|
to handle their output data.
|
|
26
30
|
"""
|
|
27
31
|
|
|
28
|
-
|
|
29
|
-
|
|
32
|
+
async def start(self) -> None:
    """Open the SurrealDB channel used by this manager.

    A ``SurrealDBConnection`` targeting the ``task_manager`` database with a
    5-second timeout is stored on ``self.channel`` and then initialised via
    ``init_surreal_instance()``.
    """
    connection = SurrealDBConnection("task_manager", datetime.timedelta(seconds=5))
    # The attribute is attached before initialisation, so it is set even if
    # the connection attempt below raises.
    self.channel = connection
    await connection.init_surreal_instance()
|
|
30
36
|
|
|
31
37
|
def __init__(
|
|
32
38
|
self,
|
|
@@ -40,12 +46,9 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
40
46
|
services_mode: The mode of operation for the services (e.g., ASYNC or SYNC).
|
|
41
47
|
"""
|
|
42
48
|
super().__init__(module_class, services_mode)
|
|
43
|
-
|
|
44
49
|
self._lock = asyncio.Lock()
|
|
45
|
-
self.modules: dict[str, BaseModule] = {}
|
|
46
|
-
self.queues: dict[str, asyncio.Queue] = {}
|
|
47
50
|
|
|
48
|
-
async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT:
|
|
51
|
+
async def generate_config_setup_module_response(self, job_id: str) -> SetupModelT | ModuleCodeModel:
|
|
49
52
|
"""Generate a stream consumer for a module's output data.
|
|
50
53
|
|
|
51
54
|
This method creates an asynchronous generator that streams output data
|
|
@@ -56,16 +59,19 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
56
59
|
job_id: The unique identifier of the job.
|
|
57
60
|
|
|
58
61
|
Returns:
|
|
59
|
-
SetupModelT: the SetupModelT object fully processed.
|
|
62
|
+
SetupModelT | ModuleCodeModel: the SetupModelT object fully processed.
|
|
60
63
|
"""
|
|
61
|
-
|
|
62
|
-
|
|
64
|
+
if (session := self.tasks_sessions.get(job_id, None)) is None:
|
|
65
|
+
return ModuleCodeModel(
|
|
66
|
+
code=str(grpc.StatusCode.NOT_FOUND),
|
|
67
|
+
message=f"Module {job_id} not found",
|
|
68
|
+
)
|
|
63
69
|
|
|
70
|
+
logger.debug("Module %s found: %s", job_id, session.module)
|
|
64
71
|
try:
|
|
65
|
-
return await
|
|
72
|
+
return await session.queue.get()
|
|
66
73
|
finally:
|
|
67
|
-
logger.info(f"{job_id=}: {
|
|
68
|
-
del self.queues[job_id]
|
|
74
|
+
logger.info(f"{job_id=}: {session.queue.empty()}")
|
|
69
75
|
|
|
70
76
|
async def create_config_setup_instance_job(
|
|
71
77
|
self,
|
|
@@ -95,8 +101,7 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
95
101
|
job_id = str(uuid.uuid4())
|
|
96
102
|
# TODO: Ensure the job_id is unique.
|
|
97
103
|
module = self.module_class(job_id, mission_id=mission_id, setup_id=setup_id, setup_version_id=setup_version_id)
|
|
98
|
-
self.
|
|
99
|
-
self.queues[job_id] = asyncio.Queue()
|
|
104
|
+
self.tasks_sessions[job_id] = TaskSession(job_id, self.channel, module)
|
|
100
105
|
|
|
101
106
|
try:
|
|
102
107
|
await module.start_config_setup(
|
|
@@ -106,13 +111,13 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
106
111
|
logger.debug("Module %s (%s) started successfully", job_id, module.name)
|
|
107
112
|
except Exception:
|
|
108
113
|
# Remove the module from the manager in case of an error.
|
|
109
|
-
del self.
|
|
114
|
+
del self.tasks_sessions[job_id]
|
|
110
115
|
logger.exception("Failed to start module %s: %s", job_id)
|
|
111
116
|
raise
|
|
112
117
|
else:
|
|
113
118
|
return job_id
|
|
114
119
|
|
|
115
|
-
async def add_to_queue(self, job_id: str, output_data: OutputModelT) -> None:
|
|
120
|
+
async def add_to_queue(self, job_id: str, output_data: OutputModelT | ModuleCodeModel) -> None:
|
|
116
121
|
"""Add output data to the queue for a specific job.
|
|
117
122
|
|
|
118
123
|
This method is used as a callback to handle output data generated by a module job.
|
|
@@ -121,7 +126,7 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
121
126
|
job_id: The unique identifier of the job.
|
|
122
127
|
output_data: The output data produced by the job.
|
|
123
128
|
"""
|
|
124
|
-
await self.
|
|
129
|
+
await self.tasks_sessions[job_id].queue.put(output_data.model_dump())
|
|
125
130
|
|
|
126
131
|
@asynccontextmanager # type: ignore
|
|
127
132
|
async def generate_stream_consumer(self, job_id: str) -> AsyncIterator[AsyncGenerator[dict[str, Any], None]]: # type: ignore
|
|
@@ -137,39 +142,48 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
137
142
|
Yields:
|
|
138
143
|
AsyncGenerator: A stream of output data or error messages.
|
|
139
144
|
"""
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
logger.debug("Module %s found: %s", job_id, module)
|
|
145
|
+
if (session := self.tasks_sessions.get(job_id, None)) is None:
|
|
143
146
|
|
|
144
|
-
|
|
145
|
-
|
|
147
|
+
async def _error_gen() -> AsyncGenerator[dict[str, Any], None]: # noqa: RUF029
|
|
148
|
+
"""Generate an error message for a non-existent module.
|
|
146
149
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
if module is None:
|
|
150
|
+
Yields:
|
|
151
|
+
AsyncGenerator: A generator yielding an error message.
|
|
152
|
+
"""
|
|
151
153
|
yield {
|
|
152
154
|
"error": {
|
|
153
155
|
"error_message": f"Module {job_id} not found",
|
|
154
156
|
"code": grpc.StatusCode.NOT_FOUND,
|
|
155
157
|
}
|
|
156
158
|
}
|
|
157
|
-
return
|
|
158
159
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
}
|
|
167
|
-
):
|
|
168
|
-
logger.info(f"{job_id=}: {module.status=}")
|
|
169
|
-
yield await self.queues[job_id].get()
|
|
160
|
+
yield _error_gen()
|
|
161
|
+
return
|
|
162
|
+
|
|
163
|
+
logger.debug("Session: %s with Module %s", job_id, session.module)
|
|
164
|
+
|
|
165
|
+
async def _stream() -> AsyncGenerator[dict[str, Any], Any]:
|
|
166
|
+
"""Stream output data from the module.
|
|
170
167
|
|
|
171
|
-
|
|
172
|
-
|
|
168
|
+
Yields:
|
|
169
|
+
dict: Output data generated by the module.
|
|
170
|
+
"""
|
|
171
|
+
while True:
|
|
172
|
+
# if queue is empty but producer not finished yet, block on get()
|
|
173
|
+
msg = await session.queue.get()
|
|
174
|
+
try:
|
|
175
|
+
yield msg
|
|
176
|
+
finally:
|
|
177
|
+
session.queue.task_done()
|
|
178
|
+
|
|
179
|
+
# If the producer marked finished and no more items, break soon:
|
|
180
|
+
if (
|
|
181
|
+
session.is_cancelled.is_set()
|
|
182
|
+
or (session.status is TaskStatus.COMPLETED and session.queue.empty())
|
|
183
|
+
or session.status is TaskStatus.FAILED
|
|
184
|
+
):
|
|
185
|
+
# and session.queue.empty():
|
|
186
|
+
break
|
|
173
187
|
|
|
174
188
|
yield _stream()
|
|
175
189
|
|
|
@@ -200,32 +214,21 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
200
214
|
Exception: If the module fails to start.
|
|
201
215
|
"""
|
|
202
216
|
job_id = str(uuid.uuid4())
|
|
203
|
-
# TODO: Ensure the job_id is unique.
|
|
204
217
|
module = self.module_class(
|
|
205
218
|
job_id,
|
|
206
219
|
mission_id=mission_id,
|
|
207
220
|
setup_id=setup_id,
|
|
208
221
|
setup_version_id=setup_version_id,
|
|
209
222
|
)
|
|
210
|
-
self.modules[job_id] = module
|
|
211
|
-
self.queues[job_id] = asyncio.Queue()
|
|
212
223
|
callback = await self.job_specific_callback(self.add_to_queue, job_id)
|
|
213
224
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
logger.debug("Module %s (%s) started successfully", job_id, module.name)
|
|
222
|
-
except Exception:
|
|
223
|
-
# Remove the module from the manager in case of an error.
|
|
224
|
-
del self.modules[job_id]
|
|
225
|
-
logger.exception("Failed to start module %s: %s", job_id)
|
|
226
|
-
raise
|
|
227
|
-
else:
|
|
228
|
-
return job_id
|
|
225
|
+
await self.create_task(
|
|
226
|
+
job_id,
|
|
227
|
+
module,
|
|
228
|
+
module.start(input_data, setup_data, callback, done_callback=None),
|
|
229
|
+
)
|
|
230
|
+
logger.info("Managed task started: '%s'", job_id, extra={"task_id": job_id})
|
|
231
|
+
return job_id
|
|
229
232
|
|
|
230
233
|
async def stop_module(self, job_id: str) -> bool:
|
|
231
234
|
"""Stop a running module job.
|
|
@@ -239,34 +242,37 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
239
242
|
Raises:
|
|
240
243
|
Exception: If an error occurs while stopping the module.
|
|
241
244
|
"""
|
|
245
|
+
logger.critical(f"STOP {job_id=} | {self.tasks_sessions.keys()}")
|
|
246
|
+
|
|
242
247
|
async with self._lock:
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
248
|
+
session = self.tasks_sessions.get(job_id)
|
|
249
|
+
|
|
250
|
+
if not session:
|
|
251
|
+
logger.warning(f"session with id: {job_id} not found")
|
|
246
252
|
return False
|
|
247
253
|
try:
|
|
248
|
-
await module.stop()
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
logger.debug(f"
|
|
254
|
+
await session.module.stop()
|
|
255
|
+
|
|
256
|
+
if job_id in self.tasks:
|
|
257
|
+
await self.cancel_task(job_id)
|
|
258
|
+
logger.debug(f"session {job_id} ({session.module.name}) stopped successfully")
|
|
253
259
|
except Exception as e:
|
|
254
260
|
logger.error(f"Error while stopping module {job_id}: {e}")
|
|
255
261
|
raise
|
|
256
262
|
else:
|
|
257
263
|
return True
|
|
258
264
|
|
|
259
|
-
async def get_module_status(self, job_id: str) -> TaskStatus:
    """Retrieve the status of a module job.

    Args:
        job_id: The unique identifier of the job.

    Returns:
        TaskStatus: The status tracked by the job's task session. When no
        session exists for ``job_id`` this returns ``TaskStatus.FAILED``
        instead of raising, so unknown job IDs are reported as failed.
    """
    session = self.tasks_sessions.get(job_id, None)
    return session.status if session is not None else TaskStatus.FAILED
|
|
270
276
|
|
|
271
277
|
async def stop_all_modules(self) -> None:
    """Stop all currently running module jobs.

    This method ensures that all active jobs are gracefully terminated.
    The job IDs are snapshotted under ``self._lock``. Each job is then
    stopped concurrently through ``stop_module``. Per-job exceptions are
    collected by ``asyncio.gather(..., return_exceptions=True)`` instead of
    being propagated.
    """
    # Take the snapshot under the lock, but release the lock before stopping.
    # stop_module() acquires self._lock itself and asyncio.Lock is not
    # reentrant, so gathering while still holding the lock would deadlock as
    # soon as at least one job exists.
    async with self._lock:
        job_ids = list(self.tasks_sessions.keys())

    if job_ids:
        await asyncio.gather(*(self.stop_module(job_id) for job_id in job_ids), return_exceptions=True)
|
|
280
286
|
|
|
@@ -286,9 +292,9 @@ class SingleJobManager(BaseJobManager, Generic[InputModelT, SetupModelT]):
|
|
|
286
292
|
"""
|
|
287
293
|
return {
|
|
288
294
|
job_id: {
|
|
289
|
-
"name": module.name,
|
|
290
|
-
"status": module.status,
|
|
291
|
-
"class": module.__class__.__name__,
|
|
295
|
+
"name": session.module.name,
|
|
296
|
+
"status": session.module.status,
|
|
297
|
+
"class": session.module.__class__.__name__,
|
|
292
298
|
}
|
|
293
|
-
for job_id,
|
|
299
|
+
for job_id, session in self.tasks_sessions.items()
|
|
294
300
|
}
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
"""SurrealDB connection management."""
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
import os
|
|
5
|
+
from collections.abc import AsyncGenerator
|
|
6
|
+
from typing import Any, Generic, TypeVar
|
|
7
|
+
from uuid import UUID
|
|
8
|
+
|
|
9
|
+
from surrealdb import AsyncHttpSurrealConnection, AsyncSurreal, AsyncWsSurrealConnection, RecordID
|
|
10
|
+
|
|
11
|
+
from digitalkin.logger import logger
|
|
12
|
+
|
|
13
|
+
# Type of the underlying SurrealDB client stored in ``SurrealDBConnection.db``:
# either the HTTP or the WebSocket async connection class.
TSurreal = TypeVar("TSurreal", bound=AsyncHttpSurrealConnection | AsyncWsSurrealConnection)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SurrealDBSetupBadIDError(Exception):
    """Signal that a raw SurrealDB record ID string was rejected.

    Raised when the supplied ID does not have the expected ``<table>:<id>``
    shape or names a table other than the one required.
    """
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SurrealDBSetupVersionBadIDError(Exception):
    """Signal that a SurrealDB setup-version ID was rejected as malformed.

    Use it when an ID does not match the format or criteria expected for a
    valid setup-version identifier. It carries only the message supplied by
    the raiser.
    """
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SurrealDBConnection(Generic[TSurreal]):
    """Async wrapper around a single SurrealDB client connection.

    Connection settings are read from environment variables when the object is
    constructed. The client itself is only created and authenticated by
    :meth:`init_surreal_instance`, so that method must be awaited before any
    other database method is used. The remaining methods delegate to the
    client and log their inputs and results at DEBUG level.
    """

    db: TSurreal
    timeout: datetime.timedelta

    @staticmethod
    def _valid_id(raw_id: str, table_name: str) -> RecordID:
        """Parse a ``"<table>:<id>"`` string into a RecordID for ``table_name``.

        Only the first ``:`` separates the table from the identifier, so an
        identifier that itself contains ``:`` is kept intact instead of being
        silently truncated.

        Args:
            raw_id: The raw ID string to validate, e.g. ``"task:abc"``.
            table_name: Table name the ID must belong to.

        Raises:
            SurrealDBSetupBadIDError: If the table prefix is not ``table_name``,
                or the ID has no ``:`` separator or an empty identifier part.

        Returns:
            RecordID: The parsed record ID (never ``None``).
        """
        table, separator, identifier = raw_id.partition(":")
        if table != table_name:
            msg = f"Invalid table name for ID: {raw_id}"
            raise SurrealDBSetupBadIDError(msg)
        if not separator or not identifier:
            msg = f"Missing record identifier in ID: {raw_id}"
            raise SurrealDBSetupBadIDError(msg)
        return RecordID(table, identifier)

    def __init__(
        self,
        database: str | None = None,
        timeout: datetime.timedelta = datetime.timedelta(seconds=5),
    ) -> None:
        """Read connection settings from the environment (no network I/O).

        Args:
            database: Name of the SurrealDB database to use. An empty or
                missing value falls back to the ``SURREALDB_DATABASE``
                environment variable, then to ``"task_manager"``.
            timeout: Timeout for database operations. It is stored on the
                instance but not applied by any method of this class.
        """
        self.timeout = timeout
        self.url = f"{os.getenv('SURREALDB_URL', 'ws://localhost')}:{os.getenv('SURREALDB_PORT', '8000')}/rpc"
        # NOTE(review): credentials fall back to root/root when the env vars are
        # unset. That is convenient for local development; confirm production
        # deployments always set them.
        self.username = os.getenv("SURREALDB_USERNAME", "root")
        self.password = os.getenv("SURREALDB_PASSWORD", "root")
        self.namespace = os.getenv("SURREALDB_NAMESPACE", "test")
        self.database = database or os.getenv("SURREALDB_DATABASE", "task_manager")

    async def init_surreal_instance(self) -> None:
        """Create the client, sign in, and select the namespace and database."""
        logger.debug("Connecting to SurrealDB at %s", self.url)
        self.db = AsyncSurreal(self.url)  # type: ignore
        await self.db.signin({"username": self.username, "password": self.password})
        await self.db.use(self.namespace, self.database)
        logger.debug("Successfully connected to SurrealDB")

    async def close(self) -> None:
        """Close the SurrealDB connection if it has been opened.

        Safe to call before :meth:`init_surreal_instance`. In that case no
        client exists yet and this is a no-op, instead of raising
        ``AttributeError`` on the missing ``db`` attribute.
        """
        db = getattr(self, "db", None)
        if db is None:
            return
        logger.debug("Closing SurrealDB connection")
        await db.close()

    async def create(
        self,
        table_name: str,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Create a new record.

        Args:
            table_name: Name of the table to insert into.
            data: Data to insert.

        Returns:
            list[dict[str, Any]] | dict[str, Any]: The created record(s) as
            returned by the database.
        """
        logger.debug("Creating record in %s with data: %s", table_name, data)
        result = await self.db.create(table_name, data)
        logger.debug("create result: %s", result)
        return result

    async def merge(
        self,
        table_name: str,
        record_id: str | RecordID,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Merge ``data`` into an existing record via the client's ``merge`` call.

        Args:
            table_name: Table the record belongs to. It is used to validate a
                string ``record_id``.
            record_id: Record ID to modify, either a ``RecordID`` or a raw
                ``"<table>:<id>"`` string.
            data: Fields to merge into the record.

        Raises:
            SurrealDBSetupBadIDError: If a string ``record_id`` is invalid.

        Returns:
            list[dict[str, Any]] | dict[str, Any]: The record(s) as returned
            by the database.
        """
        if isinstance(record_id, str):
            # Validate a raw string ID against the expected table.
            record_id = self._valid_id(record_id, table_name)
        logger.debug("Updating record in %s with data: %s", record_id, data)
        result = await self.db.merge(record_id, data)
        logger.debug("update result: %s", result)
        return result

    async def update(
        self,
        table_name: str,
        record_id: str | RecordID,
        data: dict[str, Any],
    ) -> list[dict[str, Any]] | dict[str, Any]:
        """Update an existing record via the client's ``update`` call.

        Args:
            table_name: Table the record belongs to. It is used to validate a
                string ``record_id``.
            record_id: Record ID to update, either a ``RecordID`` or a raw
                ``"<table>:<id>"`` string.
            data: New record data.

        Raises:
            SurrealDBSetupBadIDError: If a string ``record_id`` is invalid.

        Returns:
            list[dict[str, Any]] | dict[str, Any]: The record(s) as returned
            by the database.
        """
        if isinstance(record_id, str):
            # Validate a raw string ID against the expected table.
            record_id = self._valid_id(record_id, table_name)
        logger.debug("Updating record in %s with data: %s", record_id, data)
        result = await self.db.update(record_id, data)
        logger.debug("update result: %s", result)
        return result

    async def execute_query(self, query: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
        """Execute a custom SurrealQL query.

        Args:
            query: SurrealQL query.
            params: Query parameters (defaults to no parameters).

        Returns:
            list[dict[str, Any]]: Query results. A single ``dict`` result is
            wrapped in a one-element list.
        """
        logger.debug("execute_query: %s with params: %s", query, params)
        result = await self.db.query(query, params or {})
        logger.debug("execute_query result: %s", result)
        return [result] if isinstance(result, dict) else result

    async def select_by_task_id(self, table: str, value: str) -> dict[str, Any]:
        """Fetch the first record of ``table`` whose ``task_id`` equals ``value``.

        Both the table name and the value are passed as query parameters, not
        interpolated into the SurrealQL string.

        Args:
            table: Table name.
            value: ``task_id`` value to match.

        Raises:
            ValueError: If no records are found.

        Returns:
            dict[str, Any]: The first matching record.
        """
        query = "SELECT * FROM type::table($table) WHERE task_id = $value;"
        params = {"table": table, "value": value}

        result = await self.execute_query(query, params)
        if not result:
            msg = f"No records found in table '{table}' with task_id '{value}'"
            logger.error(msg)
            raise ValueError(msg)

        return result[0]

    async def start_live(
        self,
        table_name: str,
    ) -> tuple[UUID, AsyncGenerator[dict[str, Any], None]]:
        """Start a live query on ``table_name`` and subscribe to it.

        Args:
            table_name: Name of the table to watch.

        Returns:
            tuple[UUID, AsyncGenerator[dict[str, Any], None]]: The live query
            ID (pass it to :meth:`stop_live`) and the async generator returned
            by the client's ``subscribe_live``.
        """
        live_id = await self.db.live(table_name, diff=False)
        return live_id, await self.db.subscribe_live(live_id)

    async def stop_live(self, live_id: UUID) -> None:
        """Kill a live SurrealQL query.

        Args:
            live_id: ID of the live query to kill, as returned by
                :meth:`start_live`.
        """
        logger.debug("KILL Subscribe live for: %s", live_id)
        await self.db.kill(live_id)