digitalkin 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- digitalkin/__init__.py +18 -0
- digitalkin/__version__.py +11 -0
- digitalkin/grpc/__init__.py +31 -0
- digitalkin/grpc/_base_server.py +488 -0
- digitalkin/grpc/module_server.py +233 -0
- digitalkin/grpc/module_servicer.py +304 -0
- digitalkin/grpc/registry_server.py +63 -0
- digitalkin/grpc/registry_servicer.py +451 -0
- digitalkin/grpc/utils/exceptions.py +33 -0
- digitalkin/grpc/utils/factory.py +178 -0
- digitalkin/grpc/utils/models.py +169 -0
- digitalkin/grpc/utils/types.py +24 -0
- digitalkin/logger.py +17 -0
- digitalkin/models/__init__.py +11 -0
- digitalkin/models/module/__init__.py +5 -0
- digitalkin/models/module/module.py +31 -0
- digitalkin/models/services/__init__.py +6 -0
- digitalkin/models/services/cost.py +53 -0
- digitalkin/models/services/storage.py +10 -0
- digitalkin/modules/__init__.py +7 -0
- digitalkin/modules/_base_module.py +177 -0
- digitalkin/modules/archetype_module.py +14 -0
- digitalkin/modules/job_manager.py +158 -0
- digitalkin/modules/tool_module.py +14 -0
- digitalkin/modules/trigger_module.py +14 -0
- digitalkin/py.typed +0 -0
- digitalkin/services/__init__.py +28 -0
- digitalkin/services/agent/__init__.py +6 -0
- digitalkin/services/agent/agent_strategy.py +22 -0
- digitalkin/services/agent/default_agent.py +16 -0
- digitalkin/services/cost/__init__.py +6 -0
- digitalkin/services/cost/cost_strategy.py +15 -0
- digitalkin/services/cost/default_cost.py +13 -0
- digitalkin/services/default_service.py +13 -0
- digitalkin/services/development_service.py +10 -0
- digitalkin/services/filesystem/__init__.py +6 -0
- digitalkin/services/filesystem/default_filesystem.py +29 -0
- digitalkin/services/filesystem/filesystem_strategy.py +31 -0
- digitalkin/services/identity/__init__.py +6 -0
- digitalkin/services/identity/default_identity.py +15 -0
- digitalkin/services/identity/identity_strategy.py +12 -0
- digitalkin/services/registry/__init__.py +6 -0
- digitalkin/services/registry/default_registry.py +13 -0
- digitalkin/services/registry/registry_strategy.py +17 -0
- digitalkin/services/service_provider.py +27 -0
- digitalkin/services/snapshot/__init__.py +6 -0
- digitalkin/services/snapshot/default_snapshot.py +39 -0
- digitalkin/services/snapshot/snapshot_strategy.py +31 -0
- digitalkin/services/storage/__init__.py +6 -0
- digitalkin/services/storage/default_storage.py +91 -0
- digitalkin/services/storage/grpc_storage.py +207 -0
- digitalkin/services/storage/storage_strategy.py +42 -0
- digitalkin/utils/__init__.py +1 -0
- digitalkin/utils/arg_parser.py +136 -0
- digitalkin-0.1.1.dist-info/METADATA +588 -0
- digitalkin-0.1.1.dist-info/RECORD +59 -0
- digitalkin-0.1.1.dist-info/WHEEL +5 -0
- digitalkin-0.1.1.dist-info/licenses/LICENSE +430 -0
- digitalkin-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"""Data models for gRPC server configurations."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
|
8
|
+
|
|
9
|
+
from digitalkin.grpc.utils.exceptions import ConfigurationError, SecurityError
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ServerMode(str, Enum):
    """How the gRPC server executes requests.

    Members are also plain strings, so they compare equal to their values.
    """

    # Blocking server backed by a worker pool.
    SYNC = "sync"
    # asyncio-based server.
    ASYNC = "async"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SecurityMode(str, Enum):
    """Whether the gRPC server listens with TLS credentials or without.

    Members are also plain strings, so they compare equal to their values.
    """

    # Requires ServerCredentials.
    SECURE = "secure"
    # Plain-text transport.
    INSECURE = "insecure"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ServerCredentials(BaseModel):
    """TLS material used when the server runs in secure mode.

    Attributes:
        server_key_path: Filesystem path of the server private key.
        server_cert_path: Filesystem path of the server certificate.
        root_cert_path: Optional filesystem path of a root certificate.
    """

    server_key_path: Path = Field(..., description="Path to the server private key")
    server_cert_path: Path = Field(..., description="Path to the server certificate")
    root_cert_path: Path | None = Field(None, description="Path to the root certificate")

    # Strict, immutable model: unknown keys are rejected and instances are frozen.
    model_config = {
        "frozen": True,
        "extra": "forbid",
        "arbitrary_types_allowed": True,
        "validate_assignment": True,
        "use_enum_values": True,
    }

    @field_validator("server_key_path", "server_cert_path", "root_cert_path")
    @classmethod
    def check_path_exists(cls, v: Path | None) -> Path | None:
        """Reject credential paths that do not exist on disk.

        Args:
            v: Candidate path, or None (only reachable for ``root_cert_path``).

        Returns:
            The path, unchanged.

        Raises:
            SecurityError: If the path does not exist.
        """
        if v is None or v.exists():
            return v
        msg = f"File not found: {v}"
        raise SecurityError(msg)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ServerConfig(BaseModel):
    """Base configuration for gRPC servers.

    Attributes:
        host: Host address to bind the server to
        port: Port to listen on
        max_workers: Maximum number of workers for sync mode
        mode: Server operation mode (sync/async)
        security: Security mode (secure/insecure)
        credentials: Server credentials for secure mode (required when secure)
        server_options: Additional server options
        enable_reflection: Enable reflection for the server
        enable_health_check: Enable health check service
    """

    host: str = Field("0.0.0.0", description="Host address to bind the server to")  # noqa: S104
    port: int = Field(50051, description="Port to listen on")
    max_workers: int = Field(10, description="Maximum number of workers for sync mode")
    mode: ServerMode = Field(ServerMode.SYNC, description="Server operation mode (sync/async)")
    security: SecurityMode = Field(SecurityMode.INSECURE, description="Security mode (secure/insecure)")
    # validate_default=True makes validate_credentials run even when the caller
    # omits `credentials`; otherwise secure mode without credentials slips through,
    # because pydantic does not validate default values by default.
    credentials: ServerCredentials | None = Field(
        None,
        description="Server credentials for secure mode",
        validate_default=True,
    )
    server_options: list[tuple[str, Any]] = Field(default_factory=list, description="Additional server options")
    enable_reflection: bool = Field(default=True, description="Enable reflection for the server")
    enable_health_check: bool = Field(default=True, description="Enable health check service")

    # Reject unknown keys and re-validate fields on assignment.
    # use_enum_values stores enum fields as their string values.
    model_config = {
        "extra": "forbid",
        "arbitrary_types_allowed": True,
        "validate_assignment": True,
        "use_enum_values": True,
    }

    @field_validator("credentials")
    @classmethod
    def validate_credentials(cls, v: ServerCredentials | None, info: ValidationInfo) -> ServerCredentials | None:
        """Validate that credentials are provided when in secure mode.

        Args:
            v: The credentials value
            info: ValidationInfo containing other field values

        Returns:
            The validated credentials

        Raises:
            ConfigurationError: If credentials are missing in secure mode
        """
        # `security` is declared before `credentials`, so its validated value is
        # already in info.data. It is stored as a plain string (use_enum_values),
        # which still compares equal to the str-based enum member.
        security = info.data.get("security")

        if security == SecurityMode.SECURE and v is None:
            msg = "Credentials must be provided when using secure mode"
            raise ConfigurationError(msg)
        return v

    @field_validator("port")
    @classmethod
    def validate_port(cls, v: int) -> int:
        """Validate that the port is in a valid range.

        Args:
            v: Port number to validate

        Returns:
            The validated port number

        Raises:
            ConfigurationError: If port is outside valid range
        """
        if not 0 < v < 65536:  # noqa: PLR2004
            msg = f"Port must be between 1 and 65535, got {v}"
            raise ConfigurationError(msg)
        return v

    @property
    def address(self) -> str:
        """Get the server address.

        IPv6 literals (e.g. ``::``) are wrapped in brackets, because gRPC
        requires the ``[host]:port`` form for them.

        Returns:
            The formatted address string
        """
        host = self.host
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"
        return f"{host}:{self.port}"
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class ModuleServerConfig(ServerConfig):
    """gRPC server settings for a Module server.

    Extends ServerConfig with:
        registry_address: Optional address of the registry server.
    """

    registry_address: str | None = Field(None, description="Address of the registry server")
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class RegistryServerConfig(ServerConfig):
    """gRPC server settings for the Registry server.

    Extends ServerConfig with:
        database_url: Optional URL of the database that stores registry data.
    """

    database_url: str | None = Field(None, description="Database URL for registry data storage")
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Type definitions for gRPC utilities."""
|
|
2
|
+
|
|
3
|
+
from typing import Protocol, TypeVar
|
|
4
|
+
|
|
5
|
+
import grpc
|
|
6
|
+
from grpc import aio as grpc_aio
|
|
7
|
+
|
|
8
|
+
# Generic type variable for servicer implementations.
T = TypeVar("T")

# Either a blocking gRPC server or an asyncio (grpc.aio) server.
GrpcServer = grpc.Server | grpc_aio.Server
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ServiceObject(Protocol):
    """Structural type for one service entry of a gRPC descriptor.

    Any object exposing a string ``full_name`` attribute matches.
    """

    # Fully qualified service name.
    full_name: str
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Structural type for gRPC service descriptors.
class ServiceDescriptor(Protocol):
    """Structural type for a gRPC file/service descriptor.

    Matches any object exposing a ``services_by_name`` mapping.
    """

    # Service name -> service entry.
    services_by_name: dict[str, ServiceObject]
|
digitalkin/logger.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""This module sets up a logger."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
# Import-time logging setup: the root logger writes DEBUG and above to stdout
# as "timestamp - logger name - level - message".
# NOTE(review): calling basicConfig from a library configures logging for the
# whole host application (and is a no-op if the root logger already has
# handlers); consider leaving this to the application.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
    stream=sys.stdout,
)

# Let the grpc and asyncio library loggers emit DEBUG records too.
logging.getLogger("grpc").setLevel(logging.DEBUG)
logging.getLogger("asyncio").setLevel(logging.DEBUG)

# Package-wide logger, imported elsewhere as `from digitalkin.logger import logger`.
logger = logging.getLogger("digitalkin")
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Module model."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum, auto
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ModuleStatus(Enum):
    """Lifecycle states of a module instance."""

    CREATED = auto()  # instantiated, not started yet
    STARTING = auto()  # in the process of starting
    RUNNING = auto()  # currently running
    STOPPING = auto()  # in the process of stopping
    STOPPED = auto()  # stopped successfully
    FAILED = auto()  # stopped because of an internal error
    NOT_FOUND = auto()  # presumably: no module for the requested id (confirm with callers)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Module(BaseModel):
    """Description of a module: identity, version and its serialized schemas.

    All fields are required strings; the ``*_schema`` fields hold schema text.
    """

    name: str
    # Serialized schemas describing the module's cost, I/O, setup and secrets.
    cost_schema: str
    input_schema: str
    output_schema: str
    setup_schema: str
    secret_schema: str
    # Module kind (field name intentionally mirrors the external API).
    type: str
    version: str
    description: str
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Pydantic models for cost service."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CostTypeEnum(Enum):
    """Categories of cost that a service can report."""

    # Language-model token usage, split by direction.
    TOKEN_INPUT = "token_input"
    TOKEN_OUTPUT = "token_output"
    # Per-call, storage and time-based costs.
    API_CALL = "api_call"
    STORAGE = "storage"
    TIME = "time"
    # Anything not covered above.
    CUSTOM = "custom"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CostConfig(BaseModel):
    """Pydantic model that defines a cost configuration.

    :param name: Name of the cost (unique identifier in the service).
    :param type: The type/category of the cost.
    :param description: Optional short description of the cost.
    :param unit: The unit of measurement (e.g. token, call, MB).
    :param rate: The cost per unit (e.g. dollars per token).
    """

    name: str
    type: CostTypeEnum
    description: str | None = None
    unit: str
    rate: float
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class CostEvent(BaseModel):
    """Pydantic model that represents a cost event registered during service execution.

    :param name: Identifier of the cost configuration this event refers to.
    :param usage: The amount or units consumed.
    :param amount: The cost amount of this event; required, not computed by the model.
    :param timestamp: When the event was recorded; defaults to the current UTC time.
    :param metadata: Optional additional contextual information about the event.
    """

    name: str
    usage: float
    amount: float
    # default_factory so each event gets its own timezone-aware creation time.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: dict[str, Any] | None = None
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""BaseModule is the abstract base for all modules in the DigitalKin SDK."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import contextlib
|
|
5
|
+
import json
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from collections.abc import Callable
|
|
8
|
+
from typing import Any, ClassVar, Generic, TypeVar
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
|
|
12
|
+
from digitalkin.logger import logger
|
|
13
|
+
from digitalkin.models.module import ModuleStatus
|
|
14
|
+
from digitalkin.services.service_provider import ServiceProvider
|
|
15
|
+
from digitalkin.services.storage.storage_strategy import StorageStrategy
|
|
16
|
+
|
|
17
|
+
# Type parameters of BaseModule; each one is a pydantic model class.
InputModelT = TypeVar("InputModelT", bound=BaseModel)  # input payload model
OutputModelT = TypeVar("OutputModelT", bound=BaseModel)  # output payload model
SetupModelT = TypeVar("SetupModelT", bound=BaseModel)  # setup payload model
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BaseModule(ABC, Generic[InputModelT, OutputModelT, SetupModelT]):
    """BaseModule is the abstract base for all modules in the DigitalKin SDK.

    Subclasses declare ``input_format``/``output_format``/``setup_format`` models
    and implement ``initialize``, ``run`` and ``cleanup``. ``start`` initializes
    the module and then executes ``run`` in a background asyncio task; ``stop``
    cancels that task if it is still running and then calls ``cleanup``.
    """

    input_format: type[InputModelT]
    output_format: type[OutputModelT]
    setup_format: type[SetupModelT]
    metadata: ClassVar[dict[str, Any]]

    local_services: type[ServiceProvider]
    dev_services: type[ServiceProvider]

    storage: StorageStrategy

    def __init__(
        self,
        job_id: str,
        name: str | None = None,
    ) -> None:
        """Initialize the module.

        Args:
            job_id: Identifier of the job this module instance belongs to.
            name: Display name; defaults to the class name.
        """
        self.job_id: str = job_id
        self.name = name or self.__class__.__name__
        self._status = ModuleStatus.CREATED
        self._task: asyncio.Task | None = None

    @property
    def status(self) -> ModuleStatus:
        """Get the module status.

        Returns:
            The module status
        """
        return self._status

    @classmethod
    def _format_schema(cls, attribute: str, llm_format: bool) -> str:  # noqa: FBT001
        """Serialize the model stored in the class attribute named *attribute*.

        Args:
            attribute: Name of the format attribute (e.g. ``"input_format"``).
            llm_format: If True, ``json.dumps`` the attribute value itself instead
                of its pydantic JSON schema.

        Returns:
            The JSON text.

        Raises:
            NotImplementedError: If the class does not define *attribute*.
        """
        # The *_format attributes are bare annotations on BaseModule, so a
        # subclass that never assigns one has no attribute at all.
        model = getattr(cls, attribute, None)
        if model is None:
            msg = f"'{cls.__name__}' class does not define an '{attribute}'."
            raise NotImplementedError(msg)
        if llm_format:
            # NOTE(review): json.dumps cannot serialize a pydantic model class and
            # raises TypeError; confirm what the llm_format output should be.
            return json.dumps(model, indent=2)
        return json.dumps(model.model_json_schema(), indent=2)

    @classmethod
    def get_input_format(cls, llm_format: bool) -> str:  # noqa: FBT001
        """Get the JSON schema of the input format model.

        Args:
            llm_format: See ``_format_schema``.

        Raises:
            NotImplementedError: If the `input_format` is not defined.

        Returns:
            The JSON schema of the input format as a string.
        """
        return cls._format_schema("input_format", llm_format)

    @classmethod
    def get_output_format(cls, llm_format: bool) -> str:  # noqa: FBT001
        """Get the JSON schema of the output format model.

        Args:
            llm_format: See ``_format_schema``.

        Raises:
            NotImplementedError: If the `output_format` is not defined.

        Returns:
            The JSON schema of the output format as a string.
        """
        return cls._format_schema("output_format", llm_format)

    @classmethod
    def get_setup_format(cls, llm_format: bool) -> str:  # noqa: FBT001
        """Get the JSON schema of the setup format model.

        Args:
            llm_format: See ``_format_schema``.

        Raises:
            NotImplementedError: If the `setup_format` is not defined.

        Returns:
            The JSON schema of the setup format as a string.
        """
        return cls._format_schema("setup_format", llm_format)

    @abstractmethod
    async def initialize(self, setup_data: dict[str, Any]) -> None:
        """Initialize the module."""
        raise NotImplementedError

    @abstractmethod
    async def run(
        self,
        input_data: dict[str, Any],
        setup_data: dict[str, Any],
        callback: Callable,
    ) -> None:
        """Run the module."""
        raise NotImplementedError

    @abstractmethod
    async def cleanup(self) -> None:
        """Release resources held by the module."""
        raise NotImplementedError

    async def _run_lifecycle(
        self,
        input_data: dict[str, Any],
        setup_data: dict[str, Any],
        callback: Callable,
    ) -> None:
        """Run the module, then stop it (body of the background task).

        Errors raised by ``run`` are logged and mark the module FAILED.

        Raises:
            asyncio.CancelledError: If the module task is cancelled
        """
        try:
            await self.run(input_data, setup_data, callback)
            # stop() performs cleanup and records the final STOPPED/FAILED status.
            await self.stop()
        except asyncio.CancelledError:
            logger.info("Module %s cancelled", self.name)
            # Propagate so the task ends in the cancelled state, per asyncio convention.
            raise
        except Exception:
            self._status = ModuleStatus.FAILED
            logger.exception("Error inside module %s", self.name)

    async def start(
        self,
        input_data: dict[str, Any],
        setup_data: dict[str, Any],
        callback: Callable,
    ) -> None:
        """Initialize the module and launch its lifecycle in a background task.

        Errors during initialization are logged and mark the module FAILED;
        they are not propagated to the caller.
        """
        try:
            self._status = ModuleStatus.STARTING
            await self.initialize(setup_data=setup_data)
            self._status = ModuleStatus.RUNNING
            self._task = asyncio.create_task(self._run_lifecycle(input_data, setup_data, callback))
        except Exception:
            self._status = ModuleStatus.FAILED
            logger.exception("Error starting module")

    async def stop(self) -> None:
        """Stop the module: cancel its task if still running, then clean up.

        Does nothing unless the module is RUNNING. On success the status becomes
        STOPPED; if cancellation or cleanup raises, it becomes FAILED and the
        error is logged.
        """
        if self._status != ModuleStatus.RUNNING:
            return

        try:
            self._status = ModuleStatus.STOPPING
            task = self._task
            # When the module finishes normally, stop() runs inside its own
            # lifecycle task. Cancelling or awaiting that task from itself is
            # invalid, so only do it for a stop requested from elsewhere.
            if task is not None and task is not asyncio.current_task() and not task.done():
                task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await task
            await self.cleanup()
        except Exception:
            self._status = ModuleStatus.FAILED
            logger.exception("Error stopping module")
        else:
            self._status = ModuleStatus.STOPPED
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""ArchetypeModule extends BaseModule to implement specific module types."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
|
|
5
|
+
from digitalkin.modules._base_module import BaseModule
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ArchetypeModule(BaseModule, ABC):
    """ArchetypeModule extends BaseModule to implement specific module types."""

    def __init__(self, job_id: str, name: str | None = None) -> None:
        """Initialize the module.

        Args:
            job_id: Identifier of the job this module belongs to;
                ``JobManager.create_job`` passes it as the first positional argument.
            name: Optional display name; defaults to the class name.
        """
        # job_id must come from the caller: the instance has no job_id
        # attribute until BaseModule.__init__ sets it.
        super().__init__(job_id, name=name)
        self.capabilities = ["archetype"]
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""Background module manager."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import uuid
|
|
5
|
+
from argparse import ArgumentParser, Namespace
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from digitalkin.logger import logger
|
|
10
|
+
from digitalkin.models import ModuleStatus
|
|
11
|
+
from digitalkin.modules._base_module import BaseModule
|
|
12
|
+
from digitalkin.utils.arg_parser import ArgParser, DevelopmentModeMappingAction
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class JobManager(ArgParser):
    """Background module manager.

    Selects the service providers for ``module_class`` from parsed arguments,
    then creates, tracks and stops module jobs running as asyncio tasks.
    """

    args: Namespace

    def __init__(self, module_class: type[BaseModule]) -> None:
        """Initialize the job manager.

        Args:
            module_class: Module class instantiated for every new job.
        """
        # module_class must be set before super().__init__(): _add_parser_args
        # reads it, presumably during argument parsing in ArgParser.__init__.
        self.module_class = module_class
        self.modules: dict[str, BaseModule] = {}
        self._lock = asyncio.Lock()
        super().__init__()

        providers = self.args.service_providers
        explicit_fields = [name for name in providers.__class_vars__ if name in providers.__dict__]

        # Copy each service explicitly declared by the selected provider onto the
        # module class, so modules created later can use it without extra setup.
        for service_name in explicit_fields:
            setattr(self.module_class, service_name, getattr(providers, service_name))

    def _add_parser_args(self, parser: ArgumentParser) -> None:
        """Register ``--dev-mode`` on top of the base parser arguments.

        Args:
            parser: Parser to extend.
        """
        # Maps each --dev-mode choice to the provider declared on the module class.
        class_mapping = {
            "local": self.module_class.local_services,
            "development": self.module_class.dev_services,
        }

        super()._add_parser_args(parser)
        parser.add_argument(
            "-d",
            "--dev-mode",
            env_var="SERVICE_PROVIDER",
            class_mapping=class_mapping,
            choices=class_mapping.keys(),
            default="local",
            action=DevelopmentModeMappingAction,
            dest="service_providers",
            help="Define Module Service configurations for endpoints",
        )

    async def create_job(
        self,
        input_data: dict[str, Any],
        setup_data: dict[str, Any],
        callback: Callable,
        *args: tuple,
        **kwargs: dict,
    ) -> tuple[str, BaseModule]:
        """Start a new module job in the background (asyncio).

        Args:
            input_data: Input payload forwarded to the module's ``start``.
            setup_data: Setup payload forwarded to the module's ``start``.
            callback: Callback forwarded to the module's ``start``.
            *args: Extra positional arguments for the module constructor.
            **kwargs: Extra keyword arguments for the module constructor.

        Returns:
            tuple[str, BaseModule]: The new job_id and the created module.

        Raises:
            Exception: Anything raised by ``module.start`` is re-raised after the
                module has been unregistered.
        """
        job_id = str(uuid.uuid4())
        # uuid4 collisions are practically impossible, but never overwrite a live job.
        while job_id in self.modules:
            job_id = str(uuid.uuid4())

        module = self.module_class(job_id, *args, **kwargs)  # type: ignore
        self.modules[job_id] = module
        try:
            await module.start(input_data, setup_data, callback)
            logger.info("Module %s (%s) started successfully", job_id, module.name)
        except Exception:
            # Starting failed: unregister the module before propagating the error.
            self.modules.pop(job_id, None)
            logger.exception("Échec du démarrage du module %s", job_id)
            raise
        return job_id, module

    async def stop_module(self, job_id: str) -> bool:
        """Stop a running module.

        Args:
            job_id: Identifier of the module to stop.

        Returns:
            True if the module was stopped, False if it does not exist.
        """
        async with self._lock:
            return await self._stop_module_unlocked(job_id)

    async def _stop_module_unlocked(self, job_id: str) -> bool:
        """Stop a module; the caller must already hold ``self._lock``.

        asyncio.Lock is not reentrant, so code that already holds the lock
        (stop_all_modules) must call this helper instead of stop_module.
        """
        module = self.modules.get(job_id)
        if module is None:
            logger.warning(f"Module {job_id} introuvable")
            return False
        try:
            await module.stop()
            logger.info(f"Module {job_id} ({module.name}) arrêté avec succès")
        except Exception as e:
            logger.error(f"Erreur lors de l'arrêt du module {job_id}: {e}")
            raise
        return True

    def get_module_status(self, job_id: str) -> ModuleStatus | None:
        """Get the status of a module.

        Args:
            job_id: Identifier of the module.

        Returns:
            The module status, or None if the module does not exist.
        """
        module = self.modules.get(job_id)
        return module.status if module else None

    def get_module(self, job_id: str) -> BaseModule | None:
        """Get a reference to a module.

        Args:
            job_id: Identifier of the module.

        Returns:
            The module, or None if it does not exist.
        """
        return self.modules.get(job_id)

    async def stop_all_modules(self) -> None:
        """Stop all registered modules concurrently, ignoring individual failures."""
        async with self._lock:
            # Use the unlocked helper: stop_module would try to re-acquire
            # self._lock (held here) and deadlock.
            stop_tasks = [self._stop_module_unlocked(job_id) for job_id in list(self.modules)]
            if stop_tasks:
                await asyncio.gather(*stop_tasks, return_exceptions=True)

    def list_modules(self) -> dict[str, dict[str, Any]]:
        """List all modules with their status.

        Returns:
            Mapping of job_id to the module's name, status and class name.
        """
        return {
            job_id: {
                "name": module.name,
                "status": module.status,
                "class": module.__class__.__name__,
            }
            for job_id, module in self.modules.items()
        }
|