digitalkin 0.2.25rc1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. digitalkin/__version__.py +1 -1
  2. digitalkin/grpc_servers/_base_server.py +1 -1
  3. digitalkin/grpc_servers/module_server.py +26 -42
  4. digitalkin/grpc_servers/module_servicer.py +30 -24
  5. digitalkin/grpc_servers/utils/grpc_client_wrapper.py +3 -3
  6. digitalkin/grpc_servers/utils/models.py +1 -1
  7. digitalkin/logger.py +60 -23
  8. digitalkin/mixins/__init__.py +19 -0
  9. digitalkin/mixins/base_mixin.py +10 -0
  10. digitalkin/mixins/callback_mixin.py +24 -0
  11. digitalkin/mixins/chat_history_mixin.py +108 -0
  12. digitalkin/mixins/cost_mixin.py +76 -0
  13. digitalkin/mixins/file_history_mixin.py +99 -0
  14. digitalkin/mixins/filesystem_mixin.py +47 -0
  15. digitalkin/mixins/logger_mixin.py +59 -0
  16. digitalkin/mixins/storage_mixin.py +79 -0
  17. digitalkin/models/module/__init__.py +2 -0
  18. digitalkin/models/module/module.py +9 -1
  19. digitalkin/models/module/module_context.py +90 -6
  20. digitalkin/models/module/module_types.py +5 -5
  21. digitalkin/models/module/task_monitor.py +51 -0
  22. digitalkin/models/services/__init__.py +9 -0
  23. digitalkin/models/services/storage.py +39 -5
  24. digitalkin/modules/_base_module.py +105 -74
  25. digitalkin/modules/job_manager/base_job_manager.py +12 -8
  26. digitalkin/modules/job_manager/single_job_manager.py +84 -78
  27. digitalkin/modules/job_manager/surrealdb_repository.py +225 -0
  28. digitalkin/modules/job_manager/task_manager.py +391 -0
  29. digitalkin/modules/job_manager/task_session.py +276 -0
  30. digitalkin/modules/job_manager/taskiq_job_manager.py +2 -2
  31. digitalkin/modules/tool_module.py +10 -2
  32. digitalkin/modules/trigger_handler.py +7 -6
  33. digitalkin/services/cost/__init__.py +9 -2
  34. digitalkin/services/storage/grpc_storage.py +1 -1
  35. {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/METADATA +18 -18
  36. {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/RECORD +39 -26
  37. {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/WHEEL +0 -0
  38. {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/licenses/LICENSE +0 -0
  39. {digitalkin-0.2.25rc1.dist-info → digitalkin-0.3.0.dist-info}/top_level.txt +0 -0
digitalkin/__version__.py CHANGED
@@ -5,4 +5,4 @@ from importlib.metadata import PackageNotFoundError, version
5
5
  try:
6
6
  __version__ = version("digitalkin")
7
7
  except PackageNotFoundError:
8
- __version__ = "0.2.25rc1"
8
+ __version__ = "0.3.0"
digitalkin/grpc_servers/_base_server.py CHANGED
@@ -299,7 +299,7 @@ class BaseServer(abc.ABC):
299
299
  self._add_reflection()
300
300
 
301
301
  # Start the server
302
- logger.debug("Starting gRPC server on %s", self.config.address)
302
+ logger.debug("Starting gRPC server on %s", self.config.address, extra={"config": self.config})
303
303
  try:
304
304
  if self.config.mode == ServerMode.ASYNC:
305
305
  # For async server, use the event loop
digitalkin/grpc_servers/module_server.py CHANGED
@@ -1,20 +1,9 @@
1
1
  """Module gRPC server implementation for DigitalKin."""
2
2
 
3
- from pathlib import Path
4
3
  import uuid
4
+ from pathlib import Path
5
5
 
6
6
  import grpc
7
-
8
- from digitalkin.grpc_servers._base_server import BaseServer
9
- from digitalkin.grpc_servers.module_servicer import ModuleServicer
10
- from digitalkin.grpc_servers.utils.exceptions import ServerError
11
- from digitalkin.grpc_servers.utils.models import (
12
- ClientConfig,
13
- ModuleServerConfig,
14
- SecurityMode,
15
- )
16
- from digitalkin.modules._base_module import BaseModule
17
-
18
7
  from digitalkin_proto.digitalkin.module.v2 import (
19
8
  module_service_pb2,
20
9
  module_service_pb2_grpc,
@@ -24,7 +13,17 @@ from digitalkin_proto.digitalkin.module_registry.v2 import (
24
13
  module_registry_service_pb2_grpc,
25
14
  registration_pb2,
26
15
  )
16
+
17
+ from digitalkin.grpc_servers._base_server import BaseServer
18
+ from digitalkin.grpc_servers.module_servicer import ModuleServicer
19
+ from digitalkin.grpc_servers.utils.exceptions import ServerError
20
+ from digitalkin.grpc_servers.utils.models import (
21
+ ClientConfig,
22
+ ModuleServerConfig,
23
+ SecurityMode,
24
+ )
27
25
  from digitalkin.logger import logger
26
+ from digitalkin.modules._base_module import BaseModule
28
27
 
29
28
 
30
29
  class ModuleServer(BaseServer):
@@ -50,7 +49,8 @@ class ModuleServer(BaseServer):
50
49
 
51
50
  Args:
52
51
  module_class: The module instance to be served.
53
- config: Server configuration including registry address if auto-registration is desired.
52
+ server_config: Server configuration including registry address if auto-registration is desired.
53
+ client_config: Client configuration used by services.
54
54
  """
55
55
  super().__init__(server_config)
56
56
  self.module_class = module_class
@@ -79,10 +79,9 @@ class ModuleServer(BaseServer):
79
79
 
80
80
  def start(self) -> None:
81
81
  """Start the module server and register with the registry if configured."""
82
- logger.info("Starting module server",extra={"server_config": self.server_config})
82
+ logger.info("Starting module server", extra={"server_config": self.server_config})
83
83
  super().start()
84
84
 
85
- logger.debug("Starting module server",extra={"server_config": self.server_config})
86
85
  # If a registry address is provided, register the module
87
86
  if self.server_config.registry_address:
88
87
  try:
@@ -91,14 +90,12 @@ class ModuleServer(BaseServer):
91
90
  logger.exception("Failed to register with registry")
92
91
 
93
92
  if self.module_servicer is not None:
94
- logger.debug(
95
- "Setup post init started",extra={"client_config": self.client_config}
96
- )
93
+ logger.debug("Setup post init started", extra={"client_config": self.client_config})
97
94
  self.module_servicer.setup.__post_init__(self.client_config)
98
95
 
99
96
  async def start_async(self) -> None:
100
97
  """Start the module server and register with the registry if configured."""
101
- logger.info("Starting module server",extra={"server_config": self.server_config})
98
+ logger.info("Starting module server", extra={"server_config": self.server_config})
102
99
  await super().start_async()
103
100
  # If a registry address is provided, register the module
104
101
  if self.server_config.registry_address:
@@ -108,10 +105,8 @@ class ModuleServer(BaseServer):
108
105
  logger.exception("Failed to register with registry")
109
106
 
110
107
  if self.module_servicer is not None:
111
- logger.info(
112
- "Setup post init started",extra={"client_config": self.client_config}
113
- )
114
- await self.module_servicer.job_manager._start()
108
+ logger.info("Setup post init started", extra={"client_config": self.client_config})
109
+ await self.module_servicer.job_manager.start()
115
110
  self.module_servicer.setup.__post_init__(self.client_config)
116
111
 
117
112
  def stop(self, grace: float | None = None) -> None:
@@ -134,6 +129,7 @@ class ModuleServer(BaseServer):
134
129
  logger.debug(
135
130
  "Registering module with registry at %s",
136
131
  self.server_config.registry_address,
132
+ extra={"server_config": self.server_config},
137
133
  )
138
134
 
139
135
  # Create appropriate channel based on security mode
@@ -148,16 +144,11 @@ class ModuleServer(BaseServer):
148
144
 
149
145
  metadata = metadata_pb2.Metadata(
150
146
  name=self.module_class.metadata["name"],
151
- tags=[
152
- metadata_pb2.Tag(tag=tag)
153
- for tag in self.module_class.metadata["tags"]
154
- ],
147
+ tags=[metadata_pb2.Tag(tag=tag) for tag in self.module_class.metadata["tags"]],
155
148
  description=self.module_class.metadata["description"],
156
149
  )
157
150
 
158
- self.module_class.metadata["module_id"] = (
159
- f"{self.module_class.metadata['name']}:{uuid.uuid4()}"
160
- )
151
+ self.module_class.metadata["module_id"] = f"{self.module_class.metadata['name']}:{uuid.uuid4()}"
161
152
  # Create registration request
162
153
  request = registration_pb2.RegisterRequest(
163
154
  module_id=self.module_class.metadata["module_id"],
@@ -173,6 +164,7 @@ class ModuleServer(BaseServer):
173
164
  "Request sent to registry for module: %s:%s",
174
165
  self.module_class.metadata["name"],
175
166
  self.module_class.metadata["module_id"],
167
+ extra={"module_info": self.module_class.metadata},
176
168
  )
177
169
  response = stub.RegisterModule(request)
178
170
 
@@ -234,9 +226,7 @@ class ModuleServer(BaseServer):
234
226
  ):
235
227
  # Secure channel
236
228
  # Secure channel
237
- root_certificates = Path(
238
- self.client_config.credentials.root_cert_path
239
- ).read_bytes()
229
+ root_certificates = Path(self.client_config.credentials.root_cert_path).read_bytes()
240
230
 
241
231
  # mTLS channel
242
232
  private_key = None
@@ -245,12 +235,8 @@ class ModuleServer(BaseServer):
245
235
  self.client_config.credentials.client_cert_path is not None
246
236
  and self.client_config.credentials.client_key_path is not None
247
237
  ):
248
- private_key = Path(
249
- self.client_config.credentials.client_key_path
250
- ).read_bytes()
251
- certificate_chain = Path(
252
- self.client_config.credentials.client_cert_path
253
- ).read_bytes()
238
+ private_key = Path(self.client_config.credentials.client_key_path).read_bytes()
239
+ certificate_chain = Path(self.client_config.credentials.client_cert_path).read_bytes()
254
240
 
255
241
  # Create channel credentials
256
242
  channel_credentials = grpc.ssl_channel_credentials(
@@ -258,9 +244,7 @@ class ModuleServer(BaseServer):
258
244
  certificate_chain=certificate_chain,
259
245
  private_key=private_key,
260
246
  )
261
- return grpc.secure_channel(
262
- self.server_config.registry_address, channel_credentials
263
- )
247
+ return grpc.secure_channel(self.server_config.registry_address, channel_credentials)
264
248
  # Insecure channel
265
249
  return grpc.insecure_channel(self.server_config.registry_address)
266
250
 
digitalkin/grpc_servers/module_servicer.py CHANGED
@@ -76,9 +76,9 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
76
76
  self.job_manager = job_manager_class(module_class, self.args.services_mode)
77
77
 
78
78
  logger.debug(
79
- "ModuleServicer initialized with job manager: %s | %s",
79
+ "ModuleServicer initialized with job manager: %s",
80
80
  self.args.job_manager_mode,
81
- self.job_manager,
81
+ extra={"job_manager": self.job_manager},
82
82
  )
83
83
  self.setup = GrpcSetup() if self.args.services_mode == ServicesMode.REMOTE else DefaultSetup()
84
84
 
@@ -201,27 +201,33 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
201
201
  yield lifecycle_pb2.StartModuleResponse(success=False)
202
202
  return
203
203
 
204
- async with self.job_manager.generate_stream_consumer(job_id) as stream: # type: ignore
205
- async for message in stream:
206
- if message.get("error", None) is not None:
207
- context.set_code(message["error"]["code"])
208
- context.set_details(message["error"]["error_message"])
209
- yield lifecycle_pb2.StartModuleResponse(success=False, job_id=job_id)
210
- break
211
-
212
- if message.get("exception", None) is not None:
213
- logger.error("Error in output_data")
214
- context.set_code(message["short_description"])
215
- context.set_details(message["exception"])
216
- yield lifecycle_pb2.StartModuleResponse(success=False, job_id=job_id)
217
- break
218
-
219
- if message.get("code", None) is not None and message.get("code") == "__END_OF_STREAM__":
220
- yield lifecycle_pb2.StartModuleResponse(success=True, job_id=job_id)
221
- break
222
-
223
- proto = json_format.ParseDict(message, struct_pb2.Struct(), ignore_unknown_fields=True)
224
- yield lifecycle_pb2.StartModuleResponse(success=True, output=proto, job_id=job_id)
204
+ try:
205
+ async with self.job_manager.generate_stream_consumer(job_id) as stream: # type: ignore
206
+ async for message in stream:
207
+ if message.get("error", None) is not None:
208
+ logger.error("Error in output_data", extra={"message": message})
209
+ context.set_code(message["error"]["code"])
210
+ context.set_details(message["error"]["error_message"])
211
+ yield lifecycle_pb2.StartModuleResponse(success=False, job_id=job_id)
212
+ break
213
+
214
+ if message.get("exception", None) is not None:
215
+ logger.error("Exception in output_data", extra={"message": message})
216
+ context.set_code(message["short_description"])
217
+ context.set_details(message["exception"])
218
+ yield lifecycle_pb2.StartModuleResponse(success=False, job_id=job_id)
219
+ break
220
+
221
+ if message.get("code", None) is not None and message.get("code") == "__END_OF_STREAM__":
222
+ yield lifecycle_pb2.StartModuleResponse(success=True, job_id=job_id)
223
+ break
224
+
225
+ proto = json_format.ParseDict(message, struct_pb2.Struct(), ignore_unknown_fields=True)
226
+ yield lifecycle_pb2.StartModuleResponse(success=True, output=proto, job_id=job_id)
227
+ finally:
228
+ await self.job_manager.tasks[job_id]
229
+ await self.job_manager.clean_session(job_id)
230
+
225
231
  logger.info("Job %s finished", job_id)
226
232
 
227
233
  async def StopModule( # noqa: N802
@@ -248,7 +254,7 @@ class ModuleServicer(module_service_pb2_grpc.ModuleServiceServicer, ArgParser):
248
254
  context.set_details(message)
249
255
  return lifecycle_pb2.StopModuleResponse(success=False)
250
256
 
251
- logger.debug("Job %s stopped successfully", request.job_id)
257
+ logger.debug("Job %s stopped successfully", request.job_id, extra={"job_id": request.job_id})
252
258
  return lifecycle_pb2.StopModuleResponse(success=True)
253
259
 
254
260
  async def GetModuleStatus( # noqa: N802
digitalkin/grpc_servers/utils/grpc_client_wrapper.py CHANGED
@@ -62,11 +62,11 @@ class GrpcClientWrapper:
62
62
  """
63
63
  try:
64
64
  # Call the register method
65
- logger.debug("send request to %s", query_endpoint)
65
+ logger.debug("send request to %s", query_endpoint, extra={"request": request})
66
66
  response = getattr(self.stub, query_endpoint)(request)
67
- logger.debug("receive response from request to registry: %s", response)
67
+ logger.debug("receive response from request to %s", query_endpoint, extra={"response": response})
68
68
  except grpc.RpcError as e:
69
- logger.exception("RPC error during %s: %s", query_endpoint, e.details())
69
+ logger.exception("RPC error during %s", query_endpoint, extra={"error": e.details()})
70
70
  raise ServerError
71
71
  else:
72
72
  return response
digitalkin/grpc_servers/utils/models.py CHANGED
@@ -262,7 +262,7 @@ class ModuleServerConfig(ServerConfig):
262
262
  registry_address: Address of the registry server
263
263
  """
264
264
 
265
- registry_address: str | None = Field(None, description="Address of the registry server")
265
+ registry_address: str = Field(..., description="Address of the registry server")
266
266
 
267
267
 
268
268
  class RegistryServerConfig(ServerConfig):
digitalkin/logger.py CHANGED
@@ -48,12 +48,10 @@ class ColorJSONFormatter(logging.Formatter):
48
48
  log_obj: dict[str, Any] = {
49
49
  "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
50
50
  "level": record.levelname.lower(),
51
- "logger": record.name,
52
51
  "message": record.getMessage(),
53
- "location": f"{record.filename}:{record.lineno}",
54
- "function": record.funcName,
52
+ "module": record.module,
53
+ "location": f"{record.pathname}:{record.lineno}:{record.funcName}",
55
54
  }
56
-
57
55
  # Add exception info if present
58
56
  if record.exc_info:
59
57
  log_obj["exception"] = self.formatException(record.exc_info)
@@ -98,23 +96,62 @@ class ColorJSONFormatter(logging.Formatter):
98
96
  return f"{color}{json_str}{self.reset}"
99
97
 
100
98
 
101
- logging.basicConfig(
102
- level=logging.DEBUG,
103
- stream=sys.stdout,
104
- datefmt="%Y-%m-%d %H:%M:%S",
99
+ def setup_logger(
100
+ name: str,
101
+ level: int = logging.INFO,
102
+ additional_loggers: dict[str, int] | None = None,
103
+ *,
104
+ is_production: bool | None = None,
105
+ configure_root: bool = True,
106
+ ) -> logging.Logger:
107
+ """Set up a logger with the ColorJSONFormatter.
108
+
109
+ Args:
110
+ name: Name of the logger to create
111
+ level: Logging level (default: logging.INFO)
112
+ is_production: Whether running in production. If None, checks RAILWAY_SERVICE_NAME env var
113
+ configure_root: Whether to configure root logger (default: True)
114
+ additional_loggers: Dict of additional logger names and their levels to configure
115
+
116
+ Returns:
117
+ logging.Logger: Configured logger instance
118
+ """
119
+ # Determine if we're in production
120
+ if is_production is None:
121
+ is_production = os.getenv("RAILWAY_SERVICE_NAME") is not None
122
+
123
+ # Configure root logger if requested
124
+ if configure_root:
125
+ logging.basicConfig(
126
+ level=logging.DEBUG,
127
+ stream=sys.stdout,
128
+ datefmt="%Y-%m-%d %H:%M:%S",
129
+ )
130
+
131
+ # Configure additional loggers
132
+ if additional_loggers:
133
+ for logger_name, logger_level in additional_loggers.items():
134
+ logging.getLogger(logger_name).setLevel(logger_level)
135
+
136
+ # Create and configure the main logger
137
+ logger = logging.getLogger(name)
138
+ logger.setLevel(level)
139
+ # Only add handler if not already configured
140
+ if not logger.handlers:
141
+ ch = logging.StreamHandler()
142
+ ch.setLevel(level)
143
+ ch.setFormatter(ColorJSONFormatter(is_production=is_production))
144
+ logger.addHandler(ch)
145
+ logger.propagate = False
146
+
147
+ return logger
148
+
149
+
150
+ logger = setup_logger(
151
+ "digitalkin",
152
+ level=logging.INFO,
153
+ additional_loggers={
154
+ "grpc": logging.DEBUG,
155
+ "asyncio": logging.DEBUG,
156
+ },
105
157
  )
106
-
107
- logging.getLogger("grpc").setLevel(logging.DEBUG)
108
- logging.getLogger("asyncio").setLevel(logging.DEBUG)
109
-
110
-
111
- logger = logging.getLogger("digitalkin")
112
- is_production = os.getenv("RAILWAY_SERVICE_NAME") is not None
113
-
114
- if not logger.handlers:
115
- ch = logging.StreamHandler()
116
- ch.setLevel(logging.INFO)
117
- ch.setFormatter(ColorJSONFormatter(is_production=is_production))
118
-
119
- logger.addHandler(ch)
120
- logger.propagate = False
digitalkin/mixins/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """Mixin definitions."""
2
+
3
+ from digitalkin.mixins.base_mixin import BaseMixin
4
+ from digitalkin.mixins.callback_mixin import UserMessageMixin
5
+ from digitalkin.mixins.chat_history_mixin import ChatHistoryMixin
6
+ from digitalkin.mixins.cost_mixin import CostMixin
7
+ from digitalkin.mixins.filesystem_mixin import FilesystemMixin
8
+ from digitalkin.mixins.logger_mixin import LoggerMixin
9
+ from digitalkin.mixins.storage_mixin import StorageMixin
10
+
11
+ __all__ = [
12
+ "BaseMixin",
13
+ "ChatHistoryMixin",
14
+ "CostMixin",
15
+ "FilesystemMixin",
16
+ "LoggerMixin",
17
+ "StorageMixin",
18
+ "UserMessageMixin",
19
+ ]
digitalkin/mixins/base_mixin.py ADDED
@@ -0,0 +1,10 @@
1
+ """Simple toolkit class with basic and simple API access in the Triggers."""
2
+
3
+ from digitalkin.mixins.chat_history_mixin import ChatHistoryMixin
4
+ from digitalkin.mixins.cost_mixin import CostMixin
5
+ from digitalkin.mixins.file_history_mixin import FileHistoryMixin
6
+ from digitalkin.mixins.logger_mixin import LoggerMixin
7
+
8
+
9
+ class BaseMixin(CostMixin, ChatHistoryMixin, FileHistoryMixin, LoggerMixin):
10
+ """Base Mixin to access to minimum Module Context functionnalities in the Triggers."""
digitalkin/mixins/callback_mixin.py ADDED
@@ -0,0 +1,24 @@
1
+ """User callback to send a message from the Trigger."""
2
+
3
+ from typing import Generic
4
+
5
+ from digitalkin.models.module.module_context import ModuleContext
6
+ from digitalkin.models.module.module_types import OutputModelT
7
+
8
+
9
+ class UserMessageMixin(Generic[OutputModelT]):
10
+ """Mixin providing callback operations through the callbacks .
11
+
12
+ This mixin wraps callback strategy calls to provide a cleaner API
13
+ for direct messaging in trigger handlers.
14
+ """
15
+
16
+ @staticmethod
17
+ async def send_message(context: ModuleContext, output: OutputModelT) -> None:
18
+ """Send a message using the callbacks strategy.
19
+
20
+ Args:
21
+ context: Module context containing the callbacks strategy.
22
+ output: Message to send with the Module defined output Type.
23
+ """
24
+ await context.callbacks.send_message(output)
digitalkin/mixins/chat_history_mixin.py ADDED
@@ -0,0 +1,108 @@
1
+ """Context mixins providing ergonomic access to service strategies.
2
+
3
+ This module provides mixins that wrap service strategy calls with cleaner APIs,
4
+ following Django/FastAPI patterns where context is passed explicitly to each method.
5
+ """
6
+
7
+ from typing import Any, Generic
8
+
9
+ from digitalkin.mixins.callback_mixin import UserMessageMixin
10
+ from digitalkin.mixins.logger_mixin import LoggerMixin
11
+ from digitalkin.mixins.storage_mixin import StorageMixin
12
+ from digitalkin.models.module.module_context import ModuleContext
13
+ from digitalkin.models.module.module_types import InputModelT, OutputModelT
14
+ from digitalkin.models.services.storage import BaseMessage, ChatHistory, Role
15
+
16
+
17
+ class ChatHistoryMixin(UserMessageMixin, StorageMixin, LoggerMixin, Generic[InputModelT, OutputModelT]):
18
+ """Mixin providing chat history operations through storage strategy.
19
+
20
+ This mixin provides a higher-level API for managing chat history,
21
+ using the storage strategy as the underlying persistence mechanism.
22
+ """
23
+
24
+ CHAT_HISTORY_COLLECTION = "chat_history"
25
+ CHAT_HISTORY_RECORD_ID = "full_chat_history"
26
+
27
+ def _get_history_key(self, context: ModuleContext) -> str:
28
+ """Get session-specific history key.
29
+
30
+ Args:
31
+ context: Module context containing session information
32
+
33
+ Returns:
34
+ Unique history key for the current session
35
+ """
36
+ mission_id = getattr(context.session, "mission_id", None) or "default"
37
+ return f"{self.CHAT_HISTORY_RECORD_ID}_{mission_id}"
38
+
39
+ def load_chat_history(self, context: ModuleContext) -> ChatHistory:
40
+ """Load chat history for the current session.
41
+
42
+ Args:
43
+ context: Module context containing storage strategy
44
+
45
+ Returns:
46
+ Chat history object, empty if none exists or loading fails
47
+ """
48
+ history_key = self._get_history_key(context)
49
+
50
+ if (raw_history := self.read_storage(context, self.CHAT_HISTORY_COLLECTION, history_key)) is not None:
51
+ return ChatHistory.model_validate(raw_history.data)
52
+ return ChatHistory(messages=[])
53
+
54
+ def append_chat_history_message(
55
+ self,
56
+ context: ModuleContext,
57
+ role: Role,
58
+ content: Any, # noqa: ANN401
59
+ ) -> None:
60
+ """Append a message to chat history.
61
+
62
+ Args:
63
+ context: Module context containing storage strategy
64
+ role: Message role (user, assistant, system)
65
+ content: Message content
66
+
67
+ Raises:
68
+ StorageServiceError: If history update fails
69
+ """
70
+ history_key = self._get_history_key(context)
71
+ chat_history = self.load_chat_history(context)
72
+
73
+ chat_history.messages.append(BaseMessage(role=role, content=content))
74
+ if len(chat_history.messages) == 1:
75
+ # Create new record
76
+ self.log_debug(context, f"Creating new chat history for session: {history_key}")
77
+ self.store_storage(
78
+ context,
79
+ self.CHAT_HISTORY_COLLECTION,
80
+ history_key,
81
+ chat_history.model_dump(),
82
+ data_type="OUTPUT",
83
+ )
84
+ else:
85
+ self.log_debug(context, f"Updating chat history for session: {history_key}")
86
+ self.update_storage(
87
+ context,
88
+ self.CHAT_HISTORY_COLLECTION,
89
+ history_key,
90
+ chat_history.model_dump(),
91
+ )
92
+
93
+ async def save_send_message(
94
+ self,
95
+ context: ModuleContext,
96
+ output: OutputModelT,
97
+ role: Role,
98
+ ) -> None:
99
+ """Save the output message to the chat history and send a response to the Module request.
100
+
101
+ Args:
102
+ context: Module context containing storage strategy
103
+ role: Message role (user, assistant, system)
104
+ output: Message content as Pydantic Class
105
+ """
106
+ # TO-DO: we should define a default output message type to ease user experience
107
+ self.append_chat_history_message(context=context, role=role, content=str(output.root))
108
+ await self.send_message(context=context, output=output)
digitalkin/mixins/cost_mixin.py ADDED
@@ -0,0 +1,76 @@
1
+ """Cost Mixin to ease trigger deveolpment."""
2
+
3
+ from typing import Literal
4
+
5
+ from digitalkin.models.module.module_context import ModuleContext
6
+ from digitalkin.services.cost.cost_strategy import CostData
7
+
8
+
9
+ class CostMixin:
10
+ """Mixin providing cost tracking operations through the cost strategy.
11
+
12
+ This mixin wraps cost strategy calls to provide a cleaner API
13
+ for cost tracking in trigger handlers.
14
+ """
15
+
16
+ @staticmethod
17
+ def add_cost(context: ModuleContext, name: str, cost_config_name: str, quantity: float) -> None:
18
+ """Add a cost entry using the cost strategy.
19
+
20
+ Args:
21
+ context: Module context containing the cost strategy
22
+ name: Name/identifier for this cost entry
23
+ cost_config_name: Name of the cost configuration to use
24
+ quantity: Quantity of units consumed
25
+
26
+ Raises:
27
+ CostServiceError: If cost addition fails
28
+ """
29
+ return context.cost.add(name, cost_config_name, quantity)
30
+
31
+ @staticmethod
32
+ def get_cost(context: ModuleContext, name: str) -> list[CostData]:
33
+ """Get cost entries for a specific name.
34
+
35
+ Args:
36
+ context: Module context containing the cost strategy
37
+ name: Name/identifier to get costs for
38
+
39
+ Returns:
40
+ List of cost data entries
41
+
42
+ Raises:
43
+ CostServiceError: If cost retrieval fails
44
+ """
45
+ return context.cost.get(name)
46
+
47
+ @staticmethod
48
+ def get_costs(
49
+ context: ModuleContext,
50
+ names: list[str] | None = None,
51
+ cost_types: list[
52
+ Literal[
53
+ "TOKEN_INPUT",
54
+ "TOKEN_OUTPUT",
55
+ "API_CALL",
56
+ "STORAGE",
57
+ "TIME",
58
+ "OTHER",
59
+ ]
60
+ ]
61
+ | None = None,
62
+ ) -> list[CostData]:
63
+ """Get filtered cost entries.
64
+
65
+ Args:
66
+ context: Module context containing the cost strategy
67
+ names: Optional list of names to filter by
68
+ cost_types: Optional list of cost types to filter by
69
+
70
+ Returns:
71
+ List of filtered cost data entries
72
+
73
+ Raises:
74
+ CostServiceError: If cost retrieval fails
75
+ """
76
+ return context.cost.get_filtered(names, cost_types)