enkryptai-sdk 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enkryptai_sdk/__init__.py +11 -2
- enkryptai_sdk/{guardrails_config.py → config.py} +107 -46
- enkryptai_sdk/dto/__init__.py +18 -0
- enkryptai_sdk/dto/models.py +202 -0
- enkryptai_sdk/dto/red_team.py +196 -0
- enkryptai_sdk/evals.py +84 -0
- enkryptai_sdk/guardrails.py +11 -4
- enkryptai_sdk/models.py +144 -0
- enkryptai_sdk/red_team.py +185 -0
- enkryptai_sdk/response.py +135 -0
- {enkryptai_sdk-0.1.3.dist-info → enkryptai_sdk-0.1.5.dist-info}/METADATA +111 -1
- enkryptai_sdk-0.1.5.dist-info/RECORD +15 -0
- enkryptai_sdk-0.1.3.dist-info/RECORD +0 -9
- {enkryptai_sdk-0.1.3.dist-info → enkryptai_sdk-0.1.5.dist-info}/LICENSE +0 -0
- {enkryptai_sdk-0.1.3.dist-info → enkryptai_sdk-0.1.5.dist-info}/WHEEL +0 -0
- {enkryptai_sdk-0.1.3.dist-info → enkryptai_sdk-0.1.5.dist-info}/top_level.txt +0 -0
enkryptai_sdk/__init__.py
CHANGED
|
@@ -1,4 +1,13 @@
|
|
|
1
1
|
from .guardrails import GuardrailsClient
|
|
2
|
-
from .
|
|
2
|
+
from .config import GuardrailsConfig
|
|
3
|
+
from .evals import EvalsClient
|
|
4
|
+
from .models import ModelClient
|
|
5
|
+
from .red_team import RedTeamClient
|
|
3
6
|
|
|
4
|
-
__all__ = [
|
|
7
|
+
__all__ = [
|
|
8
|
+
"GuardrailsClient",
|
|
9
|
+
"GuardrailsConfig",
|
|
10
|
+
"EvalsClient",
|
|
11
|
+
"ModelClient",
|
|
12
|
+
"RedTeamClient",
|
|
13
|
+
]
|
|
@@ -1,17 +1,21 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
|
|
3
3
|
# Base default configuration for all detectors.
|
|
4
|
-
|
|
4
|
+
DEFAULT_GUARDRAILS_CONFIG = {
|
|
5
5
|
"topic_detector": {"enabled": False, "topic": []},
|
|
6
6
|
"nsfw": {"enabled": False},
|
|
7
7
|
"toxicity": {"enabled": False},
|
|
8
8
|
"pii": {"enabled": False, "entities": []},
|
|
9
9
|
"injection_attack": {"enabled": False},
|
|
10
10
|
"keyword_detector": {"enabled": False, "banned_keywords": []},
|
|
11
|
-
"policy_violation": {
|
|
11
|
+
"policy_violation": {
|
|
12
|
+
"enabled": False,
|
|
13
|
+
"policy_text": "",
|
|
14
|
+
"need_explanation": False,
|
|
15
|
+
},
|
|
12
16
|
"bias": {"enabled": False},
|
|
13
17
|
"copyright_ip": {"enabled": False},
|
|
14
|
-
"system_prompt": {"enabled": False, "index": "system"}
|
|
18
|
+
"system_prompt": {"enabled": False, "index": "system"},
|
|
15
19
|
}
|
|
16
20
|
|
|
17
21
|
|
|
@@ -24,29 +28,30 @@ class GuardrailsConfig:
|
|
|
24
28
|
|
|
25
29
|
def __init__(self, config=None):
|
|
26
30
|
# Use a deep copy of the default to avoid accidental mutation.
|
|
27
|
-
self.config =
|
|
31
|
+
self.config = (
|
|
32
|
+
copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG) if config is None else config
|
|
33
|
+
)
|
|
28
34
|
|
|
29
35
|
@classmethod
|
|
30
36
|
def injection_attack(cls):
|
|
31
37
|
"""
|
|
32
38
|
Returns a configuration instance pre-configured for injection attack detection.
|
|
33
39
|
"""
|
|
34
|
-
config = copy.deepcopy(
|
|
40
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
35
41
|
config["injection_attack"] = {"enabled": True}
|
|
36
42
|
return cls(config)
|
|
37
43
|
|
|
38
44
|
@classmethod
|
|
39
|
-
def policy_violation(cls,
|
|
40
|
-
policy_text: str,
|
|
41
|
-
need_explanation: bool = False):
|
|
45
|
+
def policy_violation(cls, policy_text: str, need_explanation: bool = False):
|
|
42
46
|
"""
|
|
43
47
|
Returns a configuration instance pre-configured for policy violation detection.
|
|
44
48
|
"""
|
|
45
|
-
config = copy.deepcopy(
|
|
46
|
-
config["policy_violation"] = {
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
49
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
50
|
+
config["policy_violation"] = {
|
|
51
|
+
"enabled": True,
|
|
52
|
+
"policy_text": policy_text,
|
|
53
|
+
"need_explanation": need_explanation,
|
|
54
|
+
}
|
|
50
55
|
return cls(config)
|
|
51
56
|
|
|
52
57
|
@classmethod
|
|
@@ -54,7 +59,7 @@ class GuardrailsConfig:
|
|
|
54
59
|
"""
|
|
55
60
|
Returns a configuration instance pre-configured for toxicity detection.
|
|
56
61
|
"""
|
|
57
|
-
config = copy.deepcopy(
|
|
62
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
58
63
|
config["toxicity"] = {"enabled": True}
|
|
59
64
|
return cls(config)
|
|
60
65
|
|
|
@@ -63,7 +68,7 @@ class GuardrailsConfig:
|
|
|
63
68
|
"""
|
|
64
69
|
Returns a configuration instance pre-configured for NSFW content detection.
|
|
65
70
|
"""
|
|
66
|
-
config = copy.deepcopy(
|
|
71
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
67
72
|
config["nsfw"] = {"enabled": True}
|
|
68
73
|
return cls(config)
|
|
69
74
|
|
|
@@ -72,7 +77,7 @@ class GuardrailsConfig:
|
|
|
72
77
|
"""
|
|
73
78
|
Returns a configuration instance pre-configured for bias detection.
|
|
74
79
|
"""
|
|
75
|
-
config = copy.deepcopy(
|
|
80
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
76
81
|
config["bias"] = {"enabled": True}
|
|
77
82
|
return cls(config)
|
|
78
83
|
|
|
@@ -80,14 +85,14 @@ class GuardrailsConfig:
|
|
|
80
85
|
def pii(cls, entities=None):
|
|
81
86
|
"""
|
|
82
87
|
Returns a configuration instance pre-configured for PII detection.
|
|
83
|
-
|
|
88
|
+
|
|
84
89
|
Args:
|
|
85
90
|
entities (list, optional): List of PII entity types to detect.
|
|
86
91
|
"""
|
|
87
|
-
config = copy.deepcopy(
|
|
92
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
88
93
|
config["pii"] = {
|
|
89
94
|
"enabled": True,
|
|
90
|
-
"entities": entities if entities is not None else []
|
|
95
|
+
"entities": entities if entities is not None else [],
|
|
91
96
|
}
|
|
92
97
|
return cls(config)
|
|
93
98
|
|
|
@@ -95,14 +100,14 @@ class GuardrailsConfig:
|
|
|
95
100
|
def topic(cls, topics=None):
|
|
96
101
|
"""
|
|
97
102
|
Returns a configuration instance pre-configured for topic detection.
|
|
98
|
-
|
|
103
|
+
|
|
99
104
|
Args:
|
|
100
105
|
topics (list, optional): List of topics to detect.
|
|
101
106
|
"""
|
|
102
|
-
config = copy.deepcopy(
|
|
107
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
103
108
|
config["topic_detector"] = {
|
|
104
109
|
"enabled": True,
|
|
105
|
-
"topic": topics if topics is not None else []
|
|
110
|
+
"topic": topics if topics is not None else [],
|
|
106
111
|
}
|
|
107
112
|
return cls(config)
|
|
108
113
|
|
|
@@ -110,14 +115,14 @@ class GuardrailsConfig:
|
|
|
110
115
|
def keyword(cls, keywords=None):
|
|
111
116
|
"""
|
|
112
117
|
Returns a configuration instance pre-configured for keyword detection.
|
|
113
|
-
|
|
118
|
+
|
|
114
119
|
Args:
|
|
115
120
|
keywords (list, optional): List of banned keywords to detect.
|
|
116
121
|
"""
|
|
117
|
-
config = copy.deepcopy(
|
|
122
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
118
123
|
config["keyword_detector"] = {
|
|
119
124
|
"enabled": True,
|
|
120
|
-
"banned_keywords": keywords if keywords is not None else []
|
|
125
|
+
"banned_keywords": keywords if keywords is not None else [],
|
|
121
126
|
}
|
|
122
127
|
return cls(config)
|
|
123
128
|
|
|
@@ -126,7 +131,7 @@ class GuardrailsConfig:
|
|
|
126
131
|
"""
|
|
127
132
|
Returns a configuration instance pre-configured for copyright/IP detection.
|
|
128
133
|
"""
|
|
129
|
-
config = copy.deepcopy(
|
|
134
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
130
135
|
config["copyright_ip"] = {"enabled": True}
|
|
131
136
|
return cls(config)
|
|
132
137
|
|
|
@@ -134,21 +139,18 @@ class GuardrailsConfig:
|
|
|
134
139
|
def system_prompt(cls, index="system"):
|
|
135
140
|
"""
|
|
136
141
|
Returns a configuration instance pre-configured for system prompt detection.
|
|
137
|
-
|
|
142
|
+
|
|
138
143
|
Args:
|
|
139
144
|
index (str, optional): Index name for system prompt detection. Defaults to "system".
|
|
140
145
|
"""
|
|
141
|
-
config = copy.deepcopy(
|
|
142
|
-
config["system_prompt"] = {
|
|
143
|
-
"enabled": True,
|
|
144
|
-
"index": index
|
|
145
|
-
}
|
|
146
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
147
|
+
config["system_prompt"] = {"enabled": True, "index": index}
|
|
146
148
|
return cls(config)
|
|
147
149
|
|
|
148
150
|
def update(self, **kwargs):
|
|
149
151
|
"""
|
|
150
152
|
Update the configuration with custom values.
|
|
151
|
-
|
|
153
|
+
|
|
152
154
|
Only keys that exist in the default configuration can be updated.
|
|
153
155
|
For example:
|
|
154
156
|
config.update(nsfw={"enabled": True}, toxicity={"enabled": True})
|
|
@@ -170,16 +172,16 @@ class GuardrailsConfig:
|
|
|
170
172
|
def from_custom_config(cls, config_dict: dict):
|
|
171
173
|
"""
|
|
172
174
|
Configure guardrails from a dictionary input.
|
|
173
|
-
|
|
175
|
+
|
|
174
176
|
Validates that the input dictionary matches the expected schema structure.
|
|
175
177
|
Each key must exist in the default configuration, and its value must be a dictionary.
|
|
176
|
-
|
|
178
|
+
|
|
177
179
|
Args:
|
|
178
180
|
config_dict (dict): Dictionary containing guardrails configuration
|
|
179
|
-
|
|
181
|
+
|
|
180
182
|
Returns:
|
|
181
183
|
GuardrailsConfig: Returns a new GuardrailsConfig instance
|
|
182
|
-
|
|
184
|
+
|
|
183
185
|
Raises:
|
|
184
186
|
ValueError: If the input dictionary contains invalid keys or malformed values
|
|
185
187
|
"""
|
|
@@ -189,33 +191,92 @@ class GuardrailsConfig:
|
|
|
189
191
|
raise ValueError(f"Unknown detector config: {key}")
|
|
190
192
|
if not isinstance(value, dict):
|
|
191
193
|
raise ValueError(f"Config value for {key} must be a dictionary")
|
|
192
|
-
|
|
194
|
+
|
|
193
195
|
# Validate that all required fields exist in the default config
|
|
194
|
-
default_fields = set(
|
|
196
|
+
default_fields = set(DEFAULT_GUARDRAILS_CONFIG[key].keys())
|
|
195
197
|
provided_fields = set(value.keys())
|
|
196
|
-
|
|
198
|
+
|
|
197
199
|
if not provided_fields.issubset(default_fields):
|
|
198
200
|
invalid_fields = provided_fields - default_fields
|
|
199
201
|
raise ValueError(f"Invalid fields for {key}: {invalid_fields}")
|
|
200
|
-
|
|
202
|
+
|
|
201
203
|
instance.config[key] = value
|
|
202
|
-
|
|
204
|
+
|
|
203
205
|
return instance
|
|
204
206
|
|
|
205
207
|
def get_config(self, detector_name: str) -> dict:
|
|
206
208
|
"""
|
|
207
209
|
Get the configuration for a specific detector.
|
|
208
|
-
|
|
210
|
+
|
|
209
211
|
Args:
|
|
210
212
|
detector_name (str): Name of the detector to get configuration for
|
|
211
|
-
|
|
213
|
+
|
|
212
214
|
Returns:
|
|
213
215
|
dict: Configuration dictionary for the specified detector
|
|
214
|
-
|
|
216
|
+
|
|
215
217
|
Raises:
|
|
216
218
|
ValueError: If the detector name doesn't exist in the configuration
|
|
217
219
|
"""
|
|
218
220
|
if detector_name not in self.config:
|
|
219
221
|
raise ValueError(f"Unknown detector: {detector_name}")
|
|
220
|
-
|
|
222
|
+
|
|
221
223
|
return copy.deepcopy(self.config[detector_name])
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class RedTeamConfig:
    """
    A helper class to manage RedTeam configuration.

    Wraps a plain configuration dict. When no config is given, a copy of the
    package default ``DEFAULT_REDTEAM_CONFIG`` (exported by ``.dto``) is used.
    Unless the config's ``dataset_name`` is ``"standard"``, the entries of
    ``ADVANCED_REDTEAM_TESTS`` are merged into ``redteam_test_configurations``.

    The caller's object is never mutated: the config is deep-copied (objects
    exposing ``to_dict()`` are converted through it first).
    """

    def __init__(self, config=None):
        if config is None:
            # Imported lazily: config.py's module-level imports (only ``copy``)
            # do not bring the dto defaults into scope.
            from .dto import DEFAULT_REDTEAM_CONFIG

            config = DEFAULT_REDTEAM_CONFIG
        # Always work on a private plain-dict copy so the merge below cannot
        # leak into the caller's object or the shared package default.
        config = self._as_plain_dict(config)

        # Only include advanced tests if dataset is not standard
        if config.get("dataset_name") != "standard":
            from .dto import ADVANCED_REDTEAM_TESTS

            config.setdefault("redteam_test_configurations", {}).update(
                self._as_plain_dict(ADVANCED_REDTEAM_TESTS)
            )
        self.config = config

    @staticmethod
    def _as_plain_dict(value):
        """Return a deep copy of ``value``, converting via ``to_dict()`` when available."""
        # NOTE(review): the dto defaults may be dataclass instances (as
        # DETAIL_MODEL_CONFIG is); plain dicts pass through unchanged.
        if hasattr(value, "to_dict"):
            value = value.to_dict()
        return copy.deepcopy(value)

    def as_dict(self):
        """
        Return the underlying configuration dictionary.
        """
        return self.config
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class ModelConfig:
    """
    A small fluent builder around a model configuration dict.

    When no config is given, starts from a deep copy of the package default
    ``DETAIL_MODEL_CONFIG`` (exported by ``.dto``), converted to a plain dict.
    A config passed in is used as-is (the setters modify it in place).
    """

    def __init__(self, config=None):
        if config is None:
            # Imported lazily: config.py's module-level imports (only ``copy``)
            # do not bring this name into scope.
            from .dto import DETAIL_MODEL_CONFIG

            default = DETAIL_MODEL_CONFIG
            # The dto default is a dataclass instance, but the setters below
            # index into a dict, so convert it first.
            if hasattr(default, "to_dict"):
                default = default.to_dict()
            config = copy.deepcopy(default)
        self.config = config

    # The setters are plain instance methods: they were previously declared
    # ``@classmethod``, which bound ``self`` to the class (which has no
    # ``config`` attribute) and made every call raise AttributeError.

    def model_name(self, model_name: str):
        """
        Set the model name and return ``self`` for chaining.
        """
        self.config["model_name"] = model_name
        return self

    def testing_for(self, testing_for: str):
        """
        Set the ``testing_for`` value and return ``self`` for chaining.
        """
        self.config["testing_for"] = testing_for
        return self

    def model_config(self, model_config: dict):
        """
        Set the nested model config dict and return ``self`` for chaining.
        """
        self.config["model_config"] = model_config
        return self

    def as_dict(self):
        """
        Return the underlying configuration dictionary.
        """
        return self.config
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from .models import *
|
|
2
|
+
from .red_team import *
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"DetailModelConfig",
|
|
6
|
+
"ModelConfig",
|
|
7
|
+
"RedTeamConfig",
|
|
8
|
+
"AdvancedRedTeamTests",
|
|
9
|
+
"TestConfig",
|
|
10
|
+
"AttackMethods",
|
|
11
|
+
"RedTeamTestConfigurations",
|
|
12
|
+
"TargetModelConfiguration",
|
|
13
|
+
"Location",
|
|
14
|
+
"Metadata",
|
|
15
|
+
"DEFAULT_REDTEAM_CONFIG",
|
|
16
|
+
"ADVANCED_REDTEAM_TESTS",
|
|
17
|
+
"DETAIL_MODEL_CONFIG",
|
|
18
|
+
]
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
from dataclasses import dataclass, field, asdict
|
|
2
|
+
from typing import Optional, List, Set, Dict, Any
|
|
3
|
+
from enum import Enum
|
|
4
|
+
import json
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Modality(Enum):
    """Kinds of content a model configuration can declare."""

    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"

    def to_dict(self):
        """Serialise the member as its raw string value (e.g. ``"text"``)."""
        raw = self.value
        return raw
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class EndpointConfig:
    """Components of an HTTP endpoint, defaulting to the OpenAI API."""

    # URL scheme, e.g. "https".
    scheme: str = "https"
    # Host name without scheme or port.
    host: str = "api.openai.com"
    # TCP port.
    port: int = 443
    # Path prefix appended after the host.
    base_path: str = "v1"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class PathsConfig:
    """Relative request paths for completion-style endpoints."""

    # NOTE(review): ``completions`` has a leading slash but ``chat`` does not;
    # confirm how callers join these onto a base URL.
    completions: str = "/chat/completions"
    chat: str = "chat/completions"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class AuthData:
    """Authentication header settings: header name, value prefix, and whether a space follows the prefix."""

    header_name: str = "Authorization"
    header_prefix: str = "Bearer"
    space_after_prefix: bool = True

    def to_dict(self):
        """Return the fields as a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict):
        """Build an AuthData from a dict.

        ``None`` or an empty dict yields the defaults. Keys that are not fields
        of AuthData are ignored, matching how ``ModelConfigDetails.from_dict``
        filters its input. Previously, extra keys raised TypeError.

        Args:
            data: Mapping of field names to values, or None.

        Returns:
            AuthData: the parsed instance.
        """
        if not data:
            return cls()
        known = cls.__dataclass_fields__.keys()
        return cls(**{key: value for key, value in data.items() if key in known})
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class ModelDetailConfig:
    """Nested model details held by DetailModelConfig: provider info, prompt, endpoint, auth and keys."""

    model_version: Optional[str] = None
    model_source: str = ""
    model_provider: str = "openai"
    system_prompt: str = ""

    endpoint_url: str = "https://api.openai.com/v1/chat/completions"
    auth_data: AuthData = field(default_factory=AuthData)
    # Default is a one-element set containing None (presumably a "no key yet"
    # placeholder — confirm). NOTE(review): sets are not JSON-serialisable.
    api_keys: Set[Optional[str]] = field(default_factory=lambda: set([None]))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class DetailModelConfig:
    """Named model entry: saved name, ``testing_for`` tag, model id, modality and nested details."""

    model_saved_name: str = "Model Name"
    testing_for: str = "LLM"
    model_name: str = "gpt-4o-mini"
    modality: Modality = Modality.TEXT
    # Each instance gets its own ModelDetailConfig.
    model_config: ModelDetailConfig = field(default_factory=ModelDetailConfig)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
class ModelConfigDetails:
    """Connection settings for a model: provider metadata, endpoint, auth and request options.

    ``from_dict`` accepts a raw payload and normalises it:
    - ``queryParams`` and ``paths`` are dropped;
    - ``apikeys`` is collapsed to a single ``apikey``;
    - an ``endpoint`` dict becomes ``endpoint_url``;
    - unknown keys are ignored.
    """

    model_version: Optional[str] = None
    model_source: str = ""
    model_provider: str = "openai"
    system_prompt: str = ""
    conversation_template: str = ""
    is_compatible_with: str = "openai"
    hosting_type: str = "External"
    endpoint_url: str = "https://api.openai.com/v1/chat/completions"
    auth_data: AuthData = field(default_factory=AuthData)
    apikey: Optional[str] = None
    default_request_options: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict):
        """Create an instance from a (possibly raw) dictionary.

        Args:
            data: Source mapping; it is not modified. ``None`` yields defaults.

        Returns:
            ModelConfigDetails: the parsed configuration.
        """
        # Shallow copy so the caller's dict is left untouched.
        data = dict(data or {})

        # Keys present in raw payloads that this model does not keep.
        for unwanted in ("queryParams", "paths"):
            data.pop(unwanted, None)

        # Collapse a collection of keys to its first element, unless an
        # explicit apikey is already present.
        if "apikeys" in data:
            apikeys = data.pop("apikeys")
            if apikeys and not data.get("apikey"):
                data["apikey"] = next(iter(apikeys))

        # A non-empty endpoint dict overrides any endpoint_url given alongside it.
        if "endpoint" in data:
            endpoint = data.pop("endpoint")
            if endpoint:
                data["endpoint_url"] = cls._endpoint_to_url(endpoint)

        # Accept an AuthData instance as-is; otherwise parse the dict (None -> defaults).
        auth_data = data.pop("auth_data", None)
        if not isinstance(auth_data, AuthData):
            auth_data = AuthData.from_dict(auth_data or {})

        # Only keep fields that are defined in the dataclass
        valid_fields = cls.__dataclass_fields__.keys()
        filtered_data = {k: v for k, v in data.items() if k in valid_fields}

        return cls(**filtered_data, auth_data=auth_data)

    @staticmethod
    def _endpoint_to_url(endpoint: dict) -> str:
        """Build ``scheme://host[:port][/base_path]`` from an endpoint dict.

        The port is omitted only when it is the default for the scheme
        (80 for http, 443 for https). Previously 80 and 443 were dropped for
        any scheme, so e.g. ``http`` on port 443 silently pointed at port 80.
        """
        scheme = endpoint.get("scheme", "https")
        host = endpoint.get("host", "")
        port = endpoint.get("port", "")
        base_path = endpoint.get("base_path", "")

        url = f"{scheme}://{host}"
        default_port = {"http": 80, "https": 443}.get(str(scheme).lower())
        # Compare as strings so "443" and 443 are treated alike.
        if port and str(port) != str(default_port):
            url += f":{port}"
        if base_path:
            url += "/" + base_path.strip("/")
        return url

    def to_dict(self):
        """Return the fields as a plain (recursively converted) dict."""
        d = asdict(self)
        # Handle AuthData specifically
        d["auth_data"] = self.auth_data.to_dict()
        return d

    def to_json(self):
        """Return the instance serialised as a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str):
        """Create an instance from a JSON string (see ``from_dict``)."""
        return cls.from_dict(json.loads(json_str))
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@dataclass
class ModelConfig:
    """A saved model configuration: identity/metadata fields plus connection details.

    Serialises to and from plain dicts and JSON. ``modality`` is held as a
    :class:`Modality` and emitted as its string value; ``model_config`` is a
    :class:`ModelConfigDetails`.
    """

    created_at: str = ""
    updated_at: str = ""
    model_id: str = ""
    model_saved_name: str = "Model Name"
    testing_for: str = "LLM"
    model_name: str = "gpt-4o-mini"
    model_type: str = "text_2_text"
    modality: Modality = Modality.TEXT
    certifications: List[str] = field(default_factory=list)
    model_config: ModelConfigDetails = field(default_factory=ModelConfigDetails)

    def to_dict(self) -> dict:
        """Convert the ModelConfig instance to a dictionary.

        ``modality`` becomes its string value and ``model_config`` a nested
        dict. Lists are copied so the result does not alias this instance's
        state.
        """
        result = {}
        for name in self.__dataclass_fields__:
            value = getattr(self, name)
            if isinstance(value, Modality):
                value = value.value
            elif isinstance(value, ModelConfigDetails):
                value = value.to_dict()
            elif isinstance(value, list):
                value = list(value)
            result[name] = value
        return result

    def to_json(self) -> str:
        """Convert the ModelConfig instance to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_dict(cls, data: dict):
        """Create a ModelConfig instance from a dictionary.

        The input dict is not modified; previously its ``model_config`` and
        ``modality`` keys were popped out of the caller's dict.

        ``model_config`` may be:
        - a dict, parsed via ``ModelConfigDetails.from_dict``;
        - an existing ``ModelConfigDetails``;
        - missing or None, giving defaults.

        ``modality`` may be a string value or a ``Modality`` member;
        missing/None means ``"text"``.

        Raises:
            ValueError: if ``modality`` is not a valid Modality value.
            TypeError: if ``data`` contains keys that are not ModelConfig fields.
        """
        # Shallow copy so the pops below do not touch the caller's dict.
        data = dict(data or {})

        # Handle nested ModelConfigDetails
        model_config_data = data.pop("model_config", None)
        if isinstance(model_config_data, ModelConfigDetails):
            model_config = model_config_data
        else:
            model_config = ModelConfigDetails.from_dict(model_config_data or {})

        # Handle Modality enum
        modality = Modality(data.pop("modality", None) or "text")

        return cls(**data, modality=modality, model_config=model_config)

    @classmethod
    def from_json(cls, json_str: str):
        """Create a ModelConfig instance from a JSON string."""
        data = json.loads(json_str)
        return cls.from_dict(data)

    def __str__(self):
        """String representation of the ModelConfig."""
        return f"ModelConfig(name={self.model_saved_name}, model={self.model_name})"

    def __repr__(self):
        """Detailed string representation of the ModelConfig."""
        return (
            f"ModelConfig({', '.join(f'{k}={v!r}' for k, v in self.to_dict().items())})"
        )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
# Default configuration: a shared ModelConfig with every field at its dataclass default.
# It is a single mutable module-level instance, so copy it before modifying.
# NOTE(review): the name suggests DetailModelConfig, but this is a ModelConfig — confirm intended.
DETAIL_MODEL_CONFIG: ModelConfig = ModelConfig()
|