digitalkin 0.2.25rc1-py3-none-any.whl → 0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__version__.py +1 -1
- digitalkin/grpc_servers/_base_server.py +1 -1
- digitalkin/grpc_servers/module_server.py +26 -42
- digitalkin/grpc_servers/module_servicer.py +30 -24
- digitalkin/grpc_servers/utils/grpc_client_wrapper.py +3 -3
- digitalkin/grpc_servers/utils/models.py +1 -1
- digitalkin/logger.py +60 -23
- digitalkin/mixins/__init__.py +19 -0
- digitalkin/mixins/base_mixin.py +10 -0
- digitalkin/mixins/callback_mixin.py +24 -0
- digitalkin/mixins/chat_history_mixin.py +108 -0
- digitalkin/mixins/cost_mixin.py +76 -0
- digitalkin/mixins/file_history_mixin.py +99 -0
- digitalkin/mixins/filesystem_mixin.py +47 -0
- digitalkin/mixins/logger_mixin.py +59 -0
- digitalkin/mixins/storage_mixin.py +79 -0
- digitalkin/models/module/__init__.py +2 -0
- digitalkin/models/module/module.py +9 -1
- digitalkin/models/module/module_context.py +90 -6
- digitalkin/models/module/module_types.py +5 -5
- digitalkin/models/module/task_monitor.py +51 -0
- digitalkin/models/services/__init__.py +9 -0
- digitalkin/models/services/storage.py +39 -5
- digitalkin/modules/_base_module.py +105 -74
- digitalkin/modules/job_manager/base_job_manager.py +12 -8
- digitalkin/modules/job_manager/single_job_manager.py +84 -78
- digitalkin/modules/job_manager/surrealdb_repository.py +225 -0
- digitalkin/modules/job_manager/task_manager.py +391 -0
- digitalkin/modules/job_manager/task_session.py +276 -0
- digitalkin/modules/job_manager/taskiq_job_manager.py +2 -2
- digitalkin/modules/tool_module.py +10 -2
- digitalkin/modules/trigger_handler.py +7 -6
- digitalkin/services/cost/__init__.py +9 -2
- digitalkin/services/storage/grpc_storage.py +1 -1
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/METADATA +18 -18
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/RECORD +39 -26
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/WHEEL +0 -0
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/licenses/LICENSE +0 -0
- {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/top_level.txt +0 -0
digitalkin/__version__.py
CHANGED

digitalkin/grpc_servers/_base_server.py
CHANGED

@@ -299,7 +299,7 @@ class BaseServer(abc.ABC):
         self._add_reflection()

         # Start the server
-        logger.debug("Starting gRPC server on %s", self.config.address)
+        logger.debug("Starting gRPC server on %s", self.config.address, extra={"config": self.config})
         try:
             if self.config.mode == ServerMode.ASYNC:
                 # For async server, use the event loop
digitalkin/grpc_servers/module_server.py
CHANGED

@@ -1,20 +1,9 @@
 """Module gRPC server implementation for DigitalKin."""

-from pathlib import Path
 import uuid
+from pathlib import Path

 import grpc
-
-from digitalkin.grpc_servers._base_server import BaseServer
-from digitalkin.grpc_servers.module_servicer import ModuleServicer
-from digitalkin.grpc_servers.utils.exceptions import ServerError
-from digitalkin.grpc_servers.utils.models import (
-    ClientConfig,
-    ModuleServerConfig,
-    SecurityMode,
-)
-from digitalkin.modules._base_module import BaseModule
-
 from digitalkin_proto.digitalkin.module.v2 import (
     module_service_pb2,
     module_service_pb2_grpc,
@@ -24,7 +13,17 @@ from digitalkin_proto.digitalkin.module_registry.v2 import (
     module_registry_service_pb2_grpc,
     registration_pb2,
 )
+
+from digitalkin.grpc_servers._base_server import BaseServer
+from digitalkin.grpc_servers.module_servicer import ModuleServicer
+from digitalkin.grpc_servers.utils.exceptions import ServerError
+from digitalkin.grpc_servers.utils.models import (
+    ClientConfig,
+    ModuleServerConfig,
+    SecurityMode,
+)
 from digitalkin.logger import logger
+from digitalkin.modules._base_module import BaseModule


 class ModuleServer(BaseServer):
@@ -50,7 +49,8 @@ class ModuleServer(BaseServer):

         Args:
             module_class: The module instance to be served.
-
+            server_config: Server configuration including registry address if auto-registration is desired.
+            client_config: Client configuration used by services.
         """
         super().__init__(server_config)
         self.module_class = module_class
@@ -79,10 +79,9 @@ class ModuleServer(BaseServer):

     def start(self) -> None:
         """Start the module server and register with the registry if configured."""
-        logger.info("Starting module server",extra={"server_config": self.server_config})
+        logger.info("Starting module server", extra={"server_config": self.server_config})
         super().start()

-        logger.debug("Starting module server",extra={"server_config": self.server_config})
         # If a registry address is provided, register the module
         if self.server_config.registry_address:
             try:
@@ -91,14 +90,12 @@ class ModuleServer(BaseServer):
                 logger.exception("Failed to register with registry")

         if self.module_servicer is not None:
-            logger.debug(
-                "Setup post init started",extra={"client_config": self.client_config}
-            )
+            logger.debug("Setup post init started", extra={"client_config": self.client_config})
             self.module_servicer.setup.__post_init__(self.client_config)

     async def start_async(self) -> None:
         """Start the module server and register with the registry if configured."""
-        logger.info("Starting module server",extra={"server_config": self.server_config})
+        logger.info("Starting module server", extra={"server_config": self.server_config})
         await super().start_async()
         # If a registry address is provided, register the module
         if self.server_config.registry_address:
@@ -108,10 +105,8 @@ class ModuleServer(BaseServer):
                 logger.exception("Failed to register with registry")

         if self.module_servicer is not None:
-            logger.info(
-                "Setup post init started", extra={"client_config": self.client_config}
-            )
-            await self.module_servicer.job_manager._start()
+            logger.info("Setup post init started", extra={"client_config": self.client_config})
+            await self.module_servicer.job_manager.start()
             self.module_servicer.setup.__post_init__(self.client_config)

     def stop(self, grace: float | None = None) -> None:
@@ -134,6 +129,7 @@ class ModuleServer(BaseServer):
         logger.debug(
             "Registering module with registry at %s",
             self.server_config.registry_address,
+            extra={"server_config": self.server_config},
         )

         # Create appropriate channel based on security mode
@@ -148,16 +144,11 @@ class ModuleServer(BaseServer):

         metadata = metadata_pb2.Metadata(
             name=self.module_class.metadata["name"],
-            tags=[
-                metadata_pb2.Tag(tag=tag)
-                for tag in self.module_class.metadata["tags"]
-            ],
+            tags=[metadata_pb2.Tag(tag=tag) for tag in self.module_class.metadata["tags"]],
             description=self.module_class.metadata["description"],
         )

-        self.module_class.metadata["module_id"] = (
-            f"{self.module_class.metadata['name']}:{uuid.uuid4()}"
-        )
+        self.module_class.metadata["module_id"] = f"{self.module_class.metadata['name']}:{uuid.uuid4()}"
         # Create registration request
         request = registration_pb2.RegisterRequest(
             module_id=self.module_class.metadata["module_id"],
@@ -173,6 +164,7 @@ class ModuleServer(BaseServer):
             "Request sent to registry for module: %s:%s",
             self.module_class.metadata["name"],
             self.module_class.metadata["module_id"],
+            extra={"module_info": self.module_class.metadata},
         )
         response = stub.RegisterModule(request)

@@ -234,9 +226,7 @@ class ModuleServer(BaseServer):
         ):
             # Secure channel
             # Secure channel
-            root_certificates = Path(
-                self.client_config.credentials.root_cert_path
-            ).read_bytes()
+            root_certificates = Path(self.client_config.credentials.root_cert_path).read_bytes()

             # mTLS channel
             private_key = None
@@ -245,12 +235,8 @@ class ModuleServer(BaseServer):
                 self.client_config.credentials.client_cert_path is not None
                 and self.client_config.credentials.client_key_path is not None
             ):
-                private_key = Path(
-                    self.client_config.credentials.client_key_path
-                ).read_bytes()
-                certificate_chain = Path(
-                    self.client_config.credentials.client_cert_path
-                ).read_bytes()
+                private_key = Path(self.client_config.credentials.client_key_path).read_bytes()
+                certificate_chain = Path(self.client_config.credentials.client_cert_path).read_bytes()

             # Create channel credentials
             channel_credentials = grpc.ssl_channel_credentials(
@@ -258,9 +244,7 @@ class ModuleServer(BaseServer):
                 certificate_chain=certificate_chain,
                 private_key=private_key,
             )
-            return grpc.secure_channel(
-                self.server_config.registry_address, channel_credentials
-            )
+            return grpc.secure_channel(self.server_config.registry_address, channel_credentials)
         # Insecure channel
         return grpc.insecure_channel(self.server_config.registry_address)

digitalkin/grpc_servers/module_servicer.py
CHANGED

@@ -76,9 +76,9 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
         self.job_manager = job_manager_class(module_class, self.args.services_mode)

         logger.debug(
-            "ModuleServicer initialized with job manager: %s
+            "ModuleServicer initialized with job manager: %s",
             self.args.job_manager_mode,
-            self.job_manager,
+            extra={"job_manager": self.job_manager},
         )
         self.setup = GrpcSetup() if self.args.services_mode == ServicesMode.REMOTE else DefaultSetup()

@@ -201,27 +201,33 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
             yield lifecycle_pb2.StartModuleResponse(success=False)
             return

-
-            async
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            async with self.job_manager.generate_stream_consumer(job_id) as stream:  # type: ignore
+                async for message in stream:
+                    if message.get("error", None) is not None:
+                        logger.error("Error in output_data", extra={"message": message})
+                        context.set_code(message["error"]["code"])
+                        context.set_details(message["error"]["error_message"])
+                        yield lifecycle_pb2.StartModuleResponse(success=False, job_id=job_id)
+                        break
+
+                    if message.get("exception", None) is not None:
+                        logger.error("Exception in output_data", extra={"message": message})
+                        context.set_code(message["short_description"])
+                        context.set_details(message["exception"])
+                        yield lifecycle_pb2.StartModuleResponse(success=False, job_id=job_id)
+                        break
+
+                    if message.get("code", None) is not None and message.get("code") == "__END_OF_STREAM__":
+                        yield lifecycle_pb2.StartModuleResponse(success=True, job_id=job_id)
+                        break
+
+                    proto = json_format.ParseDict(message, struct_pb2.Struct(), ignore_unknown_fields=True)
+                    yield lifecycle_pb2.StartModuleResponse(success=True, output=proto, job_id=job_id)
+        finally:
+            await self.job_manager.tasks[job_id]
+            await self.job_manager.clean_session(job_id)
+
         logger.info("Job %s finished", job_id)

     async def StopModule( # noqa: N802
@@ -248,7 +254,7 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
             context.set_details(message)
             return lifecycle_pb2.StopModuleResponse(success=False)

-        logger.debug("Job %s stopped successfully", request.job_id)
+        logger.debug("Job %s stopped successfully", request.job_id, extra={"job_id": request.job_id})
         return lifecycle_pb2.StopModuleResponse(success=True)

     async def GetModuleStatus( # noqa: N802
digitalkin/grpc_servers/utils/grpc_client_wrapper.py
CHANGED

@@ -62,11 +62,11 @@ class GrpcClientWrapper:
         """
         try:
             # Call the register method
-            logger.debug("send request to %s", query_endpoint)
+            logger.debug("send request to %s", query_endpoint, extra={"request": request})
             response = getattr(self.stub, query_endpoint)(request)
-            logger.debug("receive response from request to
+            logger.debug("receive response from request to %s", query_endpoint, extra={"response": response})
         except grpc.RpcError as e:
-            logger.exception("RPC error during %s
+            logger.exception("RPC error during %s", query_endpoint, extra={"error": e.details()})
             raise ServerError
         else:
             return response
digitalkin/grpc_servers/utils/models.py
CHANGED

@@ -262,7 +262,7 @@ class ModuleServerConfig(ServerConfig):
         registry_address: Address of the registry server
     """

-    registry_address: str
+    registry_address: str = Field(..., description="Address of the registry server")


 class RegistryServerConfig(ServerConfig):
digitalkin/logger.py
CHANGED

@@ -48,12 +48,10 @@ class ColorJSONFormatter(logging.Formatter):
         log_obj: dict[str, Any] = {
             "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
             "level": record.levelname.lower(),
-            "logger": record.name,
             "message": record.getMessage(),
-            "
-            "
+            "module": record.module,
+            "location": f"{record.pathname}:{record.lineno}:{record.funcName}",
         }
-
         # Add exception info if present
         if record.exc_info:
             log_obj["exception"] = self.formatException(record.exc_info)
@@ -98,23 +96,62 @@ class ColorJSONFormatter(logging.Formatter):
         return f"{color}{json_str}{self.reset}"


-
-
-
-
+def setup_logger(
+    name: str,
+    level: int = logging.INFO,
+    additional_loggers: dict[str, int] | None = None,
+    *,
+    is_production: bool | None = None,
+    configure_root: bool = True,
+) -> logging.Logger:
+    """Set up a logger with the ColorJSONFormatter.
+
+    Args:
+        name: Name of the logger to create
+        level: Logging level (default: logging.INFO)
+        is_production: Whether running in production. If None, checks RAILWAY_SERVICE_NAME env var
+        configure_root: Whether to configure root logger (default: True)
+        additional_loggers: Dict of additional logger names and their levels to configure
+
+    Returns:
+        logging.Logger: Configured logger instance
+    """
+    # Determine if we're in production
+    if is_production is None:
+        is_production = os.getenv("RAILWAY_SERVICE_NAME") is not None
+
+    # Configure root logger if requested
+    if configure_root:
+        logging.basicConfig(
+            level=logging.DEBUG,
+            stream=sys.stdout,
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+
+    # Configure additional loggers
+    if additional_loggers:
+        for logger_name, logger_level in additional_loggers.items():
+            logging.getLogger(logger_name).setLevel(logger_level)
+
+    # Create and configure the main logger
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    # Only add handler if not already configured
+    if not logger.handlers:
+        ch = logging.StreamHandler()
+        ch.setLevel(level)
+        ch.setFormatter(ColorJSONFormatter(is_production=is_production))
+        logger.addHandler(ch)
+        logger.propagate = False
+
+    return logger
+
+
+logger = setup_logger(
+    "digitalkin",
+    level=logging.INFO,
+    additional_loggers={
+        "grpc": logging.DEBUG,
+        "asyncio": logging.DEBUG,
+    },
 )
-
-logging.getLogger("grpc").setLevel(logging.DEBUG)
-logging.getLogger("asyncio").setLevel(logging.DEBUG)
-
-
-logger = logging.getLogger("digitalkin")
-is_production = os.getenv("RAILWAY_SERVICE_NAME") is not None
-
-if not logger.handlers:
-    ch = logging.StreamHandler()
-    ch.setLevel(logging.INFO)
-    ch.setFormatter(ColorJSONFormatter(is_production=is_production))
-
-    logger.addHandler(ch)
-    logger.propagate = False
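For orientation, here is a minimal usage sketch of the new setup_logger helper; the module name, the extra logger levels, and the log fields below are illustrative, not part of the package:

import logging

from digitalkin.logger import setup_logger

# Build a JSON-formatted logger; when is_production is left as None,
# the RAILWAY_SERVICE_NAME environment variable decides the output style.
app_logger = setup_logger(
    "my_module",
    level=logging.DEBUG,
    additional_loggers={"httpx": logging.WARNING},
)
app_logger.info("job started", extra={"job_id": "demo-123"})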
digitalkin/mixins/__init__.py
ADDED

@@ -0,0 +1,19 @@
+"""Mixin definitions."""
+
+from digitalkin.mixins.base_mixin import BaseMixin
+from digitalkin.mixins.callback_mixin import UserMessageMixin
+from digitalkin.mixins.chat_history_mixin import ChatHistoryMixin
+from digitalkin.mixins.cost_mixin import CostMixin
+from digitalkin.mixins.filesystem_mixin import FilesystemMixin
+from digitalkin.mixins.logger_mixin import LoggerMixin
+from digitalkin.mixins.storage_mixin import StorageMixin
+
+__all__ = [
+    "BaseMixin",
+    "ChatHistoryMixin",
+    "CostMixin",
+    "FilesystemMixin",
+    "LoggerMixin",
+    "StorageMixin",
+    "UserMessageMixin",
+]
digitalkin/mixins/base_mixin.py
ADDED

@@ -0,0 +1,10 @@
+"""Simple toolkit class with basic and simple API access in the Triggers."""
+
+from digitalkin.mixins.chat_history_mixin import ChatHistoryMixin
+from digitalkin.mixins.cost_mixin import CostMixin
+from digitalkin.mixins.file_history_mixin import FileHistoryMixin
+from digitalkin.mixins.logger_mixin import LoggerMixin
+
+
+class BaseMixin(CostMixin, ChatHistoryMixin, FileHistoryMixin, LoggerMixin):
+    """Base Mixin to access to minimum Module Context functionnalities in the Triggers."""
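As a rough illustration of how this toolkit is meant to be consumed, a hypothetical trigger class could inherit BaseMixin to pick up the cost, chat-history, file-history, and logging helpers in one step (the class and method names here are invented):

from digitalkin.mixins import BaseMixin
from digitalkin.models.module.module_context import ModuleContext


class SummaryTrigger(BaseMixin):
    """Hypothetical trigger toolkit combining the bundled mixins."""

    def on_start(self, context: ModuleContext) -> None:
        # log_debug is provided by LoggerMixin via BaseMixin.
        self.log_debug(context, "summary trigger started")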
digitalkin/mixins/callback_mixin.py
ADDED

@@ -0,0 +1,24 @@
+"""User callback to send a message from the Trigger."""
+
+from typing import Generic
+
+from digitalkin.models.module.module_context import ModuleContext
+from digitalkin.models.module.module_types import OutputModelT
+
+
+class UserMessageMixin(Generic[OutputModelT]):
+    """Mixin providing callback operations through the callbacks .
+
+    This mixin wraps callback strategy calls to provide a cleaner API
+    for direct messaging in trigger handlers.
+    """
+
+    @staticmethod
+    async def send_message(context: ModuleContext, output: OutputModelT) -> None:
+        """Send a message using the callbacks strategy.
+
+        Args:
+            context: Module context containing the callbacks strategy.
+            output: Message to send with the Module defined output Type.
+        """
+        await context.callbacks.send_message(output)
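A hypothetical handler sketch for this callback mixin; only UserMessageMixin.send_message comes from the file above, while the output model and the handler itself are illustrative:

from pydantic import RootModel

from digitalkin.mixins import UserMessageMixin
from digitalkin.models.module.module_context import ModuleContext


class DemoOutput(RootModel[str]):
    """Illustrative output model; real modules define their own OutputModelT."""


class EchoTrigger(UserMessageMixin[DemoOutput]):
    async def handle(self, context: ModuleContext, text: str) -> None:
        # Push the reply back to the caller through the module's callback strategy.
        await self.send_message(context, DemoOutput(text))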
digitalkin/mixins/chat_history_mixin.py
ADDED

@@ -0,0 +1,108 @@
+"""Context mixins providing ergonomic access to service strategies.
+
+This module provides mixins that wrap service strategy calls with cleaner APIs,
+following Django/FastAPI patterns where context is passed explicitly to each method.
+"""
+
+from typing import Any, Generic
+
+from digitalkin.mixins.callback_mixin import UserMessageMixin
+from digitalkin.mixins.logger_mixin import LoggerMixin
+from digitalkin.mixins.storage_mixin import StorageMixin
+from digitalkin.models.module.module_context import ModuleContext
+from digitalkin.models.module.module_types import InputModelT, OutputModelT
+from digitalkin.models.services.storage import BaseMessage, ChatHistory, Role
+
+
+class ChatHistoryMixin(UserMessageMixin, StorageMixin, LoggerMixin, Generic[InputModelT, OutputModelT]):
+    """Mixin providing chat history operations through storage strategy.
+
+    This mixin provides a higher-level API for managing chat history,
+    using the storage strategy as the underlying persistence mechanism.
+    """
+
+    CHAT_HISTORY_COLLECTION = "chat_history"
+    CHAT_HISTORY_RECORD_ID = "full_chat_history"
+
+    def _get_history_key(self, context: ModuleContext) -> str:
+        """Get session-specific history key.
+
+        Args:
+            context: Module context containing session information
+
+        Returns:
+            Unique history key for the current session
+        """
+        mission_id = getattr(context.session, "mission_id", None) or "default"
+        return f"{self.CHAT_HISTORY_RECORD_ID}_{mission_id}"
+
+    def load_chat_history(self, context: ModuleContext) -> ChatHistory:
+        """Load chat history for the current session.
+
+        Args:
+            context: Module context containing storage strategy
+
+        Returns:
+            Chat history object, empty if none exists or loading fails
+        """
+        history_key = self._get_history_key(context)
+
+        if (raw_history := self.read_storage(context, self.CHAT_HISTORY_COLLECTION, history_key)) is not None:
+            return ChatHistory.model_validate(raw_history.data)
+        return ChatHistory(messages=[])
+
+    def append_chat_history_message(
+        self,
+        context: ModuleContext,
+        role: Role,
+        content: Any,  # noqa: ANN401
+    ) -> None:
+        """Append a message to chat history.
+
+        Args:
+            context: Module context containing storage strategy
+            role: Message role (user, assistant, system)
+            content: Message content
+
+        Raises:
+            StorageServiceError: If history update fails
+        """
+        history_key = self._get_history_key(context)
+        chat_history = self.load_chat_history(context)
+
+        chat_history.messages.append(BaseMessage(role=role, content=content))
+        if len(chat_history.messages) == 1:
+            # Create new record
+            self.log_debug(context, f"Creating new chat history for session: {history_key}")
+            self.store_storage(
+                context,
+                self.CHAT_HISTORY_COLLECTION,
+                history_key,
+                chat_history.model_dump(),
+                data_type="OUTPUT",
+            )
+        else:
+            self.log_debug(context, f"Updating chat history for session: {history_key}")
+            self.update_storage(
+                context,
+                self.CHAT_HISTORY_COLLECTION,
+                history_key,
+                chat_history.model_dump(),
+            )
+
+    async def save_send_message(
+        self,
+        context: ModuleContext,
+        output: OutputModelT,
+        role: Role,
+    ) -> None:
+        """Save the output message to the chat history and send a response to the Module request.
+
+        Args:
+            context: Module context containing storage strategy
+            role: Message role (user, assistant, system)
+            output: Message content as Pydantic Class
+        """
+        # TO-DO: we should define a default output message type to ease user experience
+        self.append_chat_history_message(context=context, role=role, content=str(output.root))
+        await self.send_message(context=context, output=output)
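A hypothetical sketch of persisting a conversation with ChatHistoryMixin; the trigger class and the Role.USER member name are assumptions, while the mixin methods come from the file above:

from digitalkin.mixins import ChatHistoryMixin
from digitalkin.models.module.module_context import ModuleContext
from digitalkin.models.services.storage import Role


class ConversationTrigger(ChatHistoryMixin):
    def record_user_turn(self, context: ModuleContext, user_text: str) -> None:
        # Append the incoming message, then inspect the running transcript.
        self.append_chat_history_message(context, Role.USER, user_text)  # Role.USER member name assumed
        history = self.load_chat_history(context)
        self.log_debug(context, f"history now holds {len(history.messages)} messages")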
digitalkin/mixins/cost_mixin.py
ADDED

@@ -0,0 +1,76 @@
+"""Cost Mixin to ease trigger deveolpment."""
+
+from typing import Literal
+
+from digitalkin.models.module.module_context import ModuleContext
+from digitalkin.services.cost.cost_strategy import CostData
+
+
+class CostMixin:
+    """Mixin providing cost tracking operations through the cost strategy.
+
+    This mixin wraps cost strategy calls to provide a cleaner API
+    for cost tracking in trigger handlers.
+    """
+
+    @staticmethod
+    def add_cost(context: ModuleContext, name: str, cost_config_name: str, quantity: float) -> None:
+        """Add a cost entry using the cost strategy.
+
+        Args:
+            context: Module context containing the cost strategy
+            name: Name/identifier for this cost entry
+            cost_config_name: Name of the cost configuration to use
+            quantity: Quantity of units consumed
+
+        Raises:
+            CostServiceError: If cost addition fails
+        """
+        return context.cost.add(name, cost_config_name, quantity)
+
+    @staticmethod
+    def get_cost(context: ModuleContext, name: str) -> list[CostData]:
+        """Get cost entries for a specific name.
+
+        Args:
+            context: Module context containing the cost strategy
+            name: Name/identifier to get costs for
+
+        Returns:
+            List of cost data entries
+
+        Raises:
+            CostServiceError: If cost retrieval fails
+        """
+        return context.cost.get(name)
+
+    @staticmethod
+    def get_costs(
+        context: ModuleContext,
+        names: list[str] | None = None,
+        cost_types: list[
+            Literal[
+                "TOKEN_INPUT",
+                "TOKEN_OUTPUT",
+                "API_CALL",
+                "STORAGE",
+                "TIME",
+                "OTHER",
+            ]
+        ]
+        | None = None,
+    ) -> list[CostData]:
+        """Get filtered cost entries.
+
+        Args:
+            context: Module context containing the cost strategy
+            names: Optional list of names to filter by
+            cost_types: Optional list of cost types to filter by
+
+        Returns:
+            List of filtered cost data entries
+
+        Raises:
+            CostServiceError: If cost retrieval fails
+        """
+        return context.cost.get_filtered(names, cost_types)