enkryptai-sdk 0.1.4__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {enkryptai_sdk-0.1.4/src/enkryptai_sdk.egg-info → enkryptai_sdk-0.1.5}/PKG-INFO +1 -1
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/setup.py +1 -1
- enkryptai_sdk-0.1.5/src/enkryptai_sdk/__init__.py +13 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk/config.py +107 -46
- enkryptai_sdk-0.1.5/src/enkryptai_sdk/dto/__init__.py +18 -0
- enkryptai_sdk-0.1.5/src/enkryptai_sdk/dto/models.py +202 -0
- enkryptai_sdk-0.1.5/src/enkryptai_sdk/dto/red_team.py +196 -0
- enkryptai_sdk-0.1.5/src/enkryptai_sdk/models.py +144 -0
- enkryptai_sdk-0.1.5/src/enkryptai_sdk/red_team.py +185 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5/src/enkryptai_sdk.egg-info}/PKG-INFO +1 -1
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk.egg-info/SOURCES.txt +3 -0
- enkryptai_sdk-0.1.4/src/enkryptai_sdk/__init__.py +0 -5
- enkryptai_sdk-0.1.4/src/enkryptai_sdk/models.py +0 -0
- enkryptai_sdk-0.1.4/src/enkryptai_sdk/red_team.py +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/LICENSE +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/README.md +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/setup.cfg +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk/evals.py +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk/guardrails.py +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk/response.py +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk.egg-info/dependency_links.txt +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk.egg-info/top_level.txt +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/tests/test_all.py +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/tests/test_basic.py +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/tests/test_detect_policy.py +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/tests/test_injection_attack.py +0 -0
- {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/tests/test_policy_violation.py +0 -0
|
@@ -8,7 +8,7 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as fh:
|
|
|
8
8
|
|
|
9
9
|
setup(
|
|
10
10
|
name="enkryptai-sdk", # This is the name of your package on PyPI
|
|
11
|
-
version="0.1.
|
|
11
|
+
version="0.1.5",
|
|
12
12
|
description="A Python SDK with guardrails and red teaming functionality for API interactions",
|
|
13
13
|
long_description=long_description,
|
|
14
14
|
long_description_content_type="text/markdown",
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from .guardrails import GuardrailsClient
|
|
2
|
+
from .config import GuardrailsConfig
|
|
3
|
+
from .evals import EvalsClient
|
|
4
|
+
from .models import ModelClient
|
|
5
|
+
from .red_team import RedTeamClient
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"GuardrailsClient",
|
|
9
|
+
"GuardrailsConfig",
|
|
10
|
+
"EvalsClient",
|
|
11
|
+
"ModelClient",
|
|
12
|
+
"RedTeamClient",
|
|
13
|
+
]
|
|
@@ -1,17 +1,21 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
|
|
3
3
|
# Base default configuration for all detectors.
|
|
4
|
-
|
|
4
|
+
DEFAULT_GUARDRAILS_CONFIG = {
|
|
5
5
|
"topic_detector": {"enabled": False, "topic": []},
|
|
6
6
|
"nsfw": {"enabled": False},
|
|
7
7
|
"toxicity": {"enabled": False},
|
|
8
8
|
"pii": {"enabled": False, "entities": []},
|
|
9
9
|
"injection_attack": {"enabled": False},
|
|
10
10
|
"keyword_detector": {"enabled": False, "banned_keywords": []},
|
|
11
|
-
"policy_violation": {
|
|
11
|
+
"policy_violation": {
|
|
12
|
+
"enabled": False,
|
|
13
|
+
"policy_text": "",
|
|
14
|
+
"need_explanation": False,
|
|
15
|
+
},
|
|
12
16
|
"bias": {"enabled": False},
|
|
13
17
|
"copyright_ip": {"enabled": False},
|
|
14
|
-
"system_prompt": {"enabled": False, "index": "system"}
|
|
18
|
+
"system_prompt": {"enabled": False, "index": "system"},
|
|
15
19
|
}
|
|
16
20
|
|
|
17
21
|
|
|
@@ -24,29 +28,30 @@ class GuardrailsConfig:
|
|
|
24
28
|
|
|
25
29
|
def __init__(self, config=None):
|
|
26
30
|
# Use a deep copy of the default to avoid accidental mutation.
|
|
27
|
-
self.config =
|
|
31
|
+
self.config = (
|
|
32
|
+
copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG) if config is None else config
|
|
33
|
+
)
|
|
28
34
|
|
|
29
35
|
@classmethod
|
|
30
36
|
def injection_attack(cls):
|
|
31
37
|
"""
|
|
32
38
|
Returns a configuration instance pre-configured for injection attack detection.
|
|
33
39
|
"""
|
|
34
|
-
config = copy.deepcopy(
|
|
40
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
35
41
|
config["injection_attack"] = {"enabled": True}
|
|
36
42
|
return cls(config)
|
|
37
43
|
|
|
38
44
|
@classmethod
|
|
39
|
-
def policy_violation(cls,
|
|
40
|
-
policy_text: str,
|
|
41
|
-
need_explanation: bool = False):
|
|
45
|
+
def policy_violation(cls, policy_text: str, need_explanation: bool = False):
|
|
42
46
|
"""
|
|
43
47
|
Returns a configuration instance pre-configured for policy violation detection.
|
|
44
48
|
"""
|
|
45
|
-
config = copy.deepcopy(
|
|
46
|
-
config["policy_violation"] = {
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
49
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
50
|
+
config["policy_violation"] = {
|
|
51
|
+
"enabled": True,
|
|
52
|
+
"policy_text": policy_text,
|
|
53
|
+
"need_explanation": need_explanation,
|
|
54
|
+
}
|
|
50
55
|
return cls(config)
|
|
51
56
|
|
|
52
57
|
@classmethod
|
|
@@ -54,7 +59,7 @@ class GuardrailsConfig:
|
|
|
54
59
|
"""
|
|
55
60
|
Returns a configuration instance pre-configured for toxicity detection.
|
|
56
61
|
"""
|
|
57
|
-
config = copy.deepcopy(
|
|
62
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
58
63
|
config["toxicity"] = {"enabled": True}
|
|
59
64
|
return cls(config)
|
|
60
65
|
|
|
@@ -63,7 +68,7 @@ class GuardrailsConfig:
|
|
|
63
68
|
"""
|
|
64
69
|
Returns a configuration instance pre-configured for NSFW content detection.
|
|
65
70
|
"""
|
|
66
|
-
config = copy.deepcopy(
|
|
71
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
67
72
|
config["nsfw"] = {"enabled": True}
|
|
68
73
|
return cls(config)
|
|
69
74
|
|
|
@@ -72,7 +77,7 @@ class GuardrailsConfig:
|
|
|
72
77
|
"""
|
|
73
78
|
Returns a configuration instance pre-configured for bias detection.
|
|
74
79
|
"""
|
|
75
|
-
config = copy.deepcopy(
|
|
80
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
76
81
|
config["bias"] = {"enabled": True}
|
|
77
82
|
return cls(config)
|
|
78
83
|
|
|
@@ -80,14 +85,14 @@ class GuardrailsConfig:
|
|
|
80
85
|
def pii(cls, entities=None):
|
|
81
86
|
"""
|
|
82
87
|
Returns a configuration instance pre-configured for PII detection.
|
|
83
|
-
|
|
88
|
+
|
|
84
89
|
Args:
|
|
85
90
|
entities (list, optional): List of PII entity types to detect.
|
|
86
91
|
"""
|
|
87
|
-
config = copy.deepcopy(
|
|
92
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
88
93
|
config["pii"] = {
|
|
89
94
|
"enabled": True,
|
|
90
|
-
"entities": entities if entities is not None else []
|
|
95
|
+
"entities": entities if entities is not None else [],
|
|
91
96
|
}
|
|
92
97
|
return cls(config)
|
|
93
98
|
|
|
@@ -95,14 +100,14 @@ class GuardrailsConfig:
|
|
|
95
100
|
def topic(cls, topics=None):
|
|
96
101
|
"""
|
|
97
102
|
Returns a configuration instance pre-configured for topic detection.
|
|
98
|
-
|
|
103
|
+
|
|
99
104
|
Args:
|
|
100
105
|
topics (list, optional): List of topics to detect.
|
|
101
106
|
"""
|
|
102
|
-
config = copy.deepcopy(
|
|
107
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
103
108
|
config["topic_detector"] = {
|
|
104
109
|
"enabled": True,
|
|
105
|
-
"topic": topics if topics is not None else []
|
|
110
|
+
"topic": topics if topics is not None else [],
|
|
106
111
|
}
|
|
107
112
|
return cls(config)
|
|
108
113
|
|
|
@@ -110,14 +115,14 @@ class GuardrailsConfig:
|
|
|
110
115
|
def keyword(cls, keywords=None):
|
|
111
116
|
"""
|
|
112
117
|
Returns a configuration instance pre-configured for keyword detection.
|
|
113
|
-
|
|
118
|
+
|
|
114
119
|
Args:
|
|
115
120
|
keywords (list, optional): List of banned keywords to detect.
|
|
116
121
|
"""
|
|
117
|
-
config = copy.deepcopy(
|
|
122
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
118
123
|
config["keyword_detector"] = {
|
|
119
124
|
"enabled": True,
|
|
120
|
-
"banned_keywords": keywords if keywords is not None else []
|
|
125
|
+
"banned_keywords": keywords if keywords is not None else [],
|
|
121
126
|
}
|
|
122
127
|
return cls(config)
|
|
123
128
|
|
|
@@ -126,7 +131,7 @@ class GuardrailsConfig:
|
|
|
126
131
|
"""
|
|
127
132
|
Returns a configuration instance pre-configured for copyright/IP detection.
|
|
128
133
|
"""
|
|
129
|
-
config = copy.deepcopy(
|
|
134
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
130
135
|
config["copyright_ip"] = {"enabled": True}
|
|
131
136
|
return cls(config)
|
|
132
137
|
|
|
@@ -134,21 +139,18 @@ class GuardrailsConfig:
|
|
|
134
139
|
def system_prompt(cls, index="system"):
|
|
135
140
|
"""
|
|
136
141
|
Returns a configuration instance pre-configured for system prompt detection.
|
|
137
|
-
|
|
142
|
+
|
|
138
143
|
Args:
|
|
139
144
|
index (str, optional): Index name for system prompt detection. Defaults to "system".
|
|
140
145
|
"""
|
|
141
|
-
config = copy.deepcopy(
|
|
142
|
-
config["system_prompt"] = {
|
|
143
|
-
"enabled": True,
|
|
144
|
-
"index": index
|
|
145
|
-
}
|
|
146
|
+
config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
|
|
147
|
+
config["system_prompt"] = {"enabled": True, "index": index}
|
|
146
148
|
return cls(config)
|
|
147
149
|
|
|
148
150
|
def update(self, **kwargs):
|
|
149
151
|
"""
|
|
150
152
|
Update the configuration with custom values.
|
|
151
|
-
|
|
153
|
+
|
|
152
154
|
Only keys that exist in the default configuration can be updated.
|
|
153
155
|
For example:
|
|
154
156
|
config.update(nsfw={"enabled": True}, toxicity={"enabled": True})
|
|
@@ -170,16 +172,16 @@ class GuardrailsConfig:
|
|
|
170
172
|
def from_custom_config(cls, config_dict: dict):
|
|
171
173
|
"""
|
|
172
174
|
Configure guardrails from a dictionary input.
|
|
173
|
-
|
|
175
|
+
|
|
174
176
|
Validates that the input dictionary matches the expected schema structure.
|
|
175
177
|
Each key must exist in the default configuration, and its value must be a dictionary.
|
|
176
|
-
|
|
178
|
+
|
|
177
179
|
Args:
|
|
178
180
|
config_dict (dict): Dictionary containing guardrails configuration
|
|
179
|
-
|
|
181
|
+
|
|
180
182
|
Returns:
|
|
181
183
|
GuardrailsConfig: Returns a new GuardrailsConfig instance
|
|
182
|
-
|
|
184
|
+
|
|
183
185
|
Raises:
|
|
184
186
|
ValueError: If the input dictionary contains invalid keys or malformed values
|
|
185
187
|
"""
|
|
@@ -189,33 +191,92 @@ class GuardrailsConfig:
|
|
|
189
191
|
raise ValueError(f"Unknown detector config: {key}")
|
|
190
192
|
if not isinstance(value, dict):
|
|
191
193
|
raise ValueError(f"Config value for {key} must be a dictionary")
|
|
192
|
-
|
|
194
|
+
|
|
193
195
|
# Validate that all required fields exist in the default config
|
|
194
|
-
default_fields = set(
|
|
196
|
+
default_fields = set(DEFAULT_GUARDRAILS_CONFIG[key].keys())
|
|
195
197
|
provided_fields = set(value.keys())
|
|
196
|
-
|
|
198
|
+
|
|
197
199
|
if not provided_fields.issubset(default_fields):
|
|
198
200
|
invalid_fields = provided_fields - default_fields
|
|
199
201
|
raise ValueError(f"Invalid fields for {key}: {invalid_fields}")
|
|
200
|
-
|
|
202
|
+
|
|
201
203
|
instance.config[key] = value
|
|
202
|
-
|
|
204
|
+
|
|
203
205
|
return instance
|
|
204
206
|
|
|
205
207
|
def get_config(self, detector_name: str) -> dict:
|
|
206
208
|
"""
|
|
207
209
|
Get the configuration for a specific detector.
|
|
208
|
-
|
|
210
|
+
|
|
209
211
|
Args:
|
|
210
212
|
detector_name (str): Name of the detector to get configuration for
|
|
211
|
-
|
|
213
|
+
|
|
212
214
|
Returns:
|
|
213
215
|
dict: Configuration dictionary for the specified detector
|
|
214
|
-
|
|
216
|
+
|
|
215
217
|
Raises:
|
|
216
218
|
ValueError: If the detector name doesn't exist in the configuration
|
|
217
219
|
"""
|
|
218
220
|
if detector_name not in self.config:
|
|
219
221
|
raise ValueError(f"Unknown detector: {detector_name}")
|
|
220
|
-
|
|
222
|
+
|
|
221
223
|
return copy.deepcopy(self.config[detector_name])
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
class RedTeamConfig:
|
|
227
|
+
"""
|
|
228
|
+
A helper class to manage RedTeam configuration.
|
|
229
|
+
"""
|
|
230
|
+
|
|
231
|
+
def __init__(self, config=None):
|
|
232
|
+
if config is None:
|
|
233
|
+
config = copy.deepcopy(DEFAULT_REDTEAM_CONFIG)
|
|
234
|
+
# Only include advanced tests if dataset is not standard
|
|
235
|
+
if config.get("dataset_name") != "standard":
|
|
236
|
+
config["redteam_test_configurations"].update(
|
|
237
|
+
copy.deepcopy(ADVANCED_REDTEAM_TESTS)
|
|
238
|
+
)
|
|
239
|
+
self.config = config
|
|
240
|
+
|
|
241
|
+
def as_dict(self):
|
|
242
|
+
"""
|
|
243
|
+
Return the underlying configuration dictionary.
|
|
244
|
+
"""
|
|
245
|
+
return self.config
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class ModelConfig:
|
|
249
|
+
def __init__(self, config=None):
|
|
250
|
+
if config is None:
|
|
251
|
+
config = copy.deepcopy(DETAIL_MODEL_CONFIG)
|
|
252
|
+
self.config = config
|
|
253
|
+
|
|
254
|
+
@classmethod
|
|
255
|
+
def model_name(self, model_name: str):
|
|
256
|
+
"""
|
|
257
|
+
Set the model name.
|
|
258
|
+
"""
|
|
259
|
+
self.config["model_name"] = model_name
|
|
260
|
+
return self
|
|
261
|
+
|
|
262
|
+
@classmethod
|
|
263
|
+
def testing_for(self, testing_for: str):
|
|
264
|
+
"""
|
|
265
|
+
Set the testing for.
|
|
266
|
+
"""
|
|
267
|
+
self.config["testing_for"] = testing_for
|
|
268
|
+
return self
|
|
269
|
+
|
|
270
|
+
@classmethod
|
|
271
|
+
def model_config(self, model_config: dict):
|
|
272
|
+
"""
|
|
273
|
+
Set the model config.
|
|
274
|
+
"""
|
|
275
|
+
self.config["model_config"] = model_config
|
|
276
|
+
return self
|
|
277
|
+
|
|
278
|
+
def as_dict(self):
|
|
279
|
+
"""
|
|
280
|
+
Return the underlying configuration dictionary.
|
|
281
|
+
"""
|
|
282
|
+
return self.config
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from .models import *
|
|
2
|
+
from .red_team import *
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"DetailModelConfig",
|
|
6
|
+
"ModelConfig",
|
|
7
|
+
"RedTeamConfig",
|
|
8
|
+
"AdvancedRedTeamTests",
|
|
9
|
+
"TestConfig",
|
|
10
|
+
"AttackMethods",
|
|
11
|
+
"RedTeamTestConfigurations",
|
|
12
|
+
"TargetModelConfiguration",
|
|
13
|
+
"Location",
|
|
14
|
+
"Metadata",
|
|
15
|
+
"DEFAULT_REDTEAM_CONFIG",
|
|
16
|
+
"ADVANCED_REDTEAM_TESTS",
|
|
17
|
+
"DETAIL_MODEL_CONFIG",
|
|
18
|
+
]
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
from dataclasses import dataclass, field, asdict
|
|
2
|
+
from typing import Optional, List, Set, Dict, Any
|
|
3
|
+
from enum import Enum
|
|
4
|
+
import json
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Modality(Enum):
|
|
8
|
+
TEXT = "text"
|
|
9
|
+
IMAGE = "image"
|
|
10
|
+
AUDIO = "audio"
|
|
11
|
+
VIDEO = "video"
|
|
12
|
+
|
|
13
|
+
def to_dict(self):
|
|
14
|
+
return self.value
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
|
|
18
|
+
class EndpointConfig:
|
|
19
|
+
scheme: str = "https"
|
|
20
|
+
host: str = "api.openai.com"
|
|
21
|
+
port: int = 443
|
|
22
|
+
base_path: str = "v1"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
|
|
26
|
+
class PathsConfig:
|
|
27
|
+
completions: str = "/chat/completions"
|
|
28
|
+
chat: str = "chat/completions"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
|
|
32
|
+
class AuthData:
|
|
33
|
+
header_name: str = "Authorization"
|
|
34
|
+
header_prefix: str = "Bearer"
|
|
35
|
+
space_after_prefix: bool = True
|
|
36
|
+
|
|
37
|
+
def to_dict(self):
|
|
38
|
+
return asdict(self)
|
|
39
|
+
|
|
40
|
+
@classmethod
|
|
41
|
+
def from_dict(cls, data: dict):
|
|
42
|
+
return cls(**data)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
|
|
46
|
+
class ModelDetailConfig:
|
|
47
|
+
model_version: Optional[str] = None
|
|
48
|
+
model_source: str = ""
|
|
49
|
+
model_provider: str = "openai"
|
|
50
|
+
system_prompt: str = ""
|
|
51
|
+
|
|
52
|
+
endpoint_url: str = "https://api.openai.com/v1/chat/completions"
|
|
53
|
+
auth_data: AuthData = field(default_factory=AuthData)
|
|
54
|
+
api_keys: Set[Optional[str]] = field(default_factory=lambda: {None})
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
|
|
58
|
+
class DetailModelConfig:
|
|
59
|
+
model_saved_name: str = "Model Name"
|
|
60
|
+
testing_for: str = "LLM"
|
|
61
|
+
model_name: str = "gpt-4o-mini"
|
|
62
|
+
modality: Modality = Modality.TEXT
|
|
63
|
+
model_config: ModelDetailConfig = field(default_factory=ModelDetailConfig)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
|
|
67
|
+
class ModelConfigDetails:
|
|
68
|
+
model_version: Optional[str] = None
|
|
69
|
+
model_source: str = ""
|
|
70
|
+
model_provider: str = "openai"
|
|
71
|
+
system_prompt: str = ""
|
|
72
|
+
conversation_template: str = ""
|
|
73
|
+
is_compatible_with: str = "openai"
|
|
74
|
+
hosting_type: str = "External"
|
|
75
|
+
endpoint_url: str = "https://api.openai.com/v1/chat/completions"
|
|
76
|
+
auth_data: AuthData = field(default_factory=AuthData)
|
|
77
|
+
apikey: Optional[str] = None
|
|
78
|
+
default_request_options: Dict[str, Any] = field(default_factory=dict)
|
|
79
|
+
|
|
80
|
+
@classmethod
|
|
81
|
+
def from_dict(cls, data: dict):
|
|
82
|
+
# Create a copy of the data to avoid modifying the original
|
|
83
|
+
data = data.copy()
|
|
84
|
+
|
|
85
|
+
# Remove known fields that we don't want in our model
|
|
86
|
+
unwanted_fields = ["queryParams", "paths"]
|
|
87
|
+
for field in unwanted_fields:
|
|
88
|
+
data.pop(field, None)
|
|
89
|
+
|
|
90
|
+
# Handle apikeys to apikey conversion
|
|
91
|
+
if "apikeys" in data:
|
|
92
|
+
apikeys = data.pop("apikeys")
|
|
93
|
+
if apikeys and not data.get("apikey"):
|
|
94
|
+
data["apikey"] = apikeys[0]
|
|
95
|
+
|
|
96
|
+
# Convert endpoint dict to endpoint_url if present
|
|
97
|
+
if "endpoint" in data:
|
|
98
|
+
endpoint = data.pop("endpoint")
|
|
99
|
+
scheme = endpoint.get("scheme", "https")
|
|
100
|
+
host = endpoint.get("host", "")
|
|
101
|
+
port = endpoint.get("port", "")
|
|
102
|
+
base_path = endpoint.get("base_path", "")
|
|
103
|
+
|
|
104
|
+
endpoint_url = f"{scheme}://{host}"
|
|
105
|
+
if port and port not in [80, 443]:
|
|
106
|
+
endpoint_url += f":{port}"
|
|
107
|
+
if base_path:
|
|
108
|
+
base_path = "/" + base_path.strip("/")
|
|
109
|
+
endpoint_url += base_path
|
|
110
|
+
|
|
111
|
+
data["endpoint_url"] = endpoint_url
|
|
112
|
+
|
|
113
|
+
# Handle nested AuthData
|
|
114
|
+
auth_data = data.pop("auth_data", {})
|
|
115
|
+
auth_data_obj = AuthData.from_dict(auth_data)
|
|
116
|
+
|
|
117
|
+
# Only keep fields that are defined in the dataclass
|
|
118
|
+
valid_fields = cls.__dataclass_fields__.keys()
|
|
119
|
+
filtered_data = {k: v for k, v in data.items() if k in valid_fields}
|
|
120
|
+
|
|
121
|
+
return cls(**filtered_data, auth_data=auth_data_obj)
|
|
122
|
+
|
|
123
|
+
def to_dict(self):
|
|
124
|
+
d = asdict(self)
|
|
125
|
+
# Handle AuthData specifically
|
|
126
|
+
d["auth_data"] = self.auth_data.to_dict()
|
|
127
|
+
return d
|
|
128
|
+
|
|
129
|
+
def to_json(self):
|
|
130
|
+
return json.dumps(self.to_dict())
|
|
131
|
+
|
|
132
|
+
@classmethod
|
|
133
|
+
def from_json(cls, json_str: str):
|
|
134
|
+
return cls.from_dict(json.loads(json_str))
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@dataclass
|
|
138
|
+
class ModelConfig:
|
|
139
|
+
created_at: str = ""
|
|
140
|
+
updated_at: str = ""
|
|
141
|
+
model_id: str = ""
|
|
142
|
+
model_saved_name: str = "Model Name"
|
|
143
|
+
testing_for: str = "LLM"
|
|
144
|
+
model_name: str = "gpt-4o-mini"
|
|
145
|
+
model_type: str = "text_2_text"
|
|
146
|
+
modality: Modality = Modality.TEXT
|
|
147
|
+
certifications: List[str] = field(default_factory=list)
|
|
148
|
+
model_config: ModelConfigDetails = field(default_factory=ModelConfigDetails)
|
|
149
|
+
|
|
150
|
+
def to_dict(self) -> dict:
|
|
151
|
+
"""Convert the ModelConfig instance to a dictionary."""
|
|
152
|
+
# First create a shallow copy of self as dict
|
|
153
|
+
d = {}
|
|
154
|
+
for field in self.__dataclass_fields__:
|
|
155
|
+
value = getattr(self, field)
|
|
156
|
+
if field == "modality":
|
|
157
|
+
d[field] = value.value
|
|
158
|
+
elif field == "model_config":
|
|
159
|
+
if isinstance(value, ModelConfigDetails):
|
|
160
|
+
d[field] = value.to_dict()
|
|
161
|
+
else:
|
|
162
|
+
d[field] = value
|
|
163
|
+
else:
|
|
164
|
+
d[field] = value
|
|
165
|
+
return d
|
|
166
|
+
|
|
167
|
+
def to_json(self) -> str:
|
|
168
|
+
"""Convert the ModelConfig instance to a JSON string."""
|
|
169
|
+
return json.dumps(self.to_dict())
|
|
170
|
+
|
|
171
|
+
@classmethod
|
|
172
|
+
def from_dict(cls, data: dict):
|
|
173
|
+
"""Create a ModelConfig instance from a dictionary."""
|
|
174
|
+
# Handle nested ModelConfigDetails
|
|
175
|
+
model_config_data = data.pop("model_config", {})
|
|
176
|
+
model_config = ModelConfigDetails.from_dict(model_config_data)
|
|
177
|
+
|
|
178
|
+
# Handle Modality enum
|
|
179
|
+
modality_value = data.pop("modality", "text")
|
|
180
|
+
modality = Modality(modality_value)
|
|
181
|
+
|
|
182
|
+
return cls(**data, modality=modality, model_config=model_config)
|
|
183
|
+
|
|
184
|
+
@classmethod
|
|
185
|
+
def from_json(cls, json_str: str):
|
|
186
|
+
"""Create a ModelConfig instance from a JSON string."""
|
|
187
|
+
data = json.loads(json_str)
|
|
188
|
+
return cls.from_dict(data)
|
|
189
|
+
|
|
190
|
+
def __str__(self):
|
|
191
|
+
"""String representation of the ModelConfig."""
|
|
192
|
+
return f"ModelConfig(name={self.model_saved_name}, model={self.model_name})"
|
|
193
|
+
|
|
194
|
+
def __repr__(self):
|
|
195
|
+
"""Detailed string representation of the ModelConfig."""
|
|
196
|
+
return (
|
|
197
|
+
f"ModelConfig({', '.join(f'{k}={v!r}' for k, v in self.to_dict().items())})"
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
# Default configuration
|
|
202
|
+
DETAIL_MODEL_CONFIG = ModelConfig()
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
from dataclasses import dataclass, field, asdict
|
|
2
|
+
from typing import Dict, List, Optional
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
|
|
7
|
+
class RedTeamResponse:
|
|
8
|
+
task_id: Optional[str] = None
|
|
9
|
+
|
|
10
|
+
def to_dict(self) -> dict:
|
|
11
|
+
return asdict(self)
|
|
12
|
+
|
|
13
|
+
@classmethod
|
|
14
|
+
def from_dict(cls, data: dict):
|
|
15
|
+
return cls(**data)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
|
|
19
|
+
class RedTeamTaskStatus:
|
|
20
|
+
status: Optional[str] = None
|
|
21
|
+
|
|
22
|
+
def to_dict(self) -> dict:
|
|
23
|
+
return asdict(self)
|
|
24
|
+
|
|
25
|
+
@classmethod
|
|
26
|
+
def from_dict(cls, data: dict):
|
|
27
|
+
return cls(**data)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
|
|
31
|
+
class RedTeamTaskDetails:
|
|
32
|
+
created_at: Optional[str] = None
|
|
33
|
+
model_name: Optional[str] = None
|
|
34
|
+
status: Optional[str] = None
|
|
35
|
+
task_id: Optional[str] = None
|
|
36
|
+
|
|
37
|
+
def to_dict(self) -> dict:
|
|
38
|
+
return asdict(self)
|
|
39
|
+
|
|
40
|
+
@classmethod
|
|
41
|
+
def from_dict(cls, data: dict):
|
|
42
|
+
return cls(**data)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
|
|
46
|
+
class RedTeamResultSummary:
|
|
47
|
+
test_date: Optional[str] = None
|
|
48
|
+
test_name: Optional[str] = None
|
|
49
|
+
dataset_name: Optional[str] = None
|
|
50
|
+
model_name: Optional[str] = None
|
|
51
|
+
model_endpoint_url: Optional[str] = None
|
|
52
|
+
model_source: Optional[str] = None
|
|
53
|
+
model_provider: Optional[str] = None
|
|
54
|
+
risk_score: Optional[float] = None
|
|
55
|
+
test_type: Optional[List] = None
|
|
56
|
+
nist_category: Optional[List] = None
|
|
57
|
+
scenario: Optional[List] = None
|
|
58
|
+
category: Optional[List] = None
|
|
59
|
+
attack_method: Optional[List] = None
|
|
60
|
+
|
|
61
|
+
def to_dict(self) -> dict:
|
|
62
|
+
return asdict(self)
|
|
63
|
+
|
|
64
|
+
@classmethod
|
|
65
|
+
def from_dict(cls, data: dict):
|
|
66
|
+
return cls(**data)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
|
|
70
|
+
class RedTeamResultDetails: # To Be Updated
|
|
71
|
+
details: Optional[Dict] = None
|
|
72
|
+
|
|
73
|
+
def to_dict(self) -> dict:
|
|
74
|
+
return asdict(self)
|
|
75
|
+
|
|
76
|
+
@classmethod
|
|
77
|
+
def from_dict(cls, data: dict):
|
|
78
|
+
return cls(**data)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
|
|
82
|
+
class AttackMethods:
|
|
83
|
+
basic: List[str] = field(default_factory=lambda: ["basic"])
|
|
84
|
+
advanced: Dict[str, List[str]] = field(
|
|
85
|
+
default_factory=lambda: {"static": ["single_shot"], "dynamic": ["iterative"]}
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
def to_dict(self) -> dict:
|
|
89
|
+
return asdict(self)
|
|
90
|
+
|
|
91
|
+
@classmethod
|
|
92
|
+
def from_dict(cls, data: dict):
|
|
93
|
+
return cls(**data)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@dataclass
|
|
97
|
+
class TestConfig:
|
|
98
|
+
sample_percentage: int = 100
|
|
99
|
+
attack_methods: AttackMethods = field(default_factory=AttackMethods)
|
|
100
|
+
|
|
101
|
+
def to_dict(self) -> dict:
|
|
102
|
+
return {
|
|
103
|
+
"sample_percentage": self.sample_percentage,
|
|
104
|
+
"attack_methods": self.attack_methods.to_dict(),
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
@classmethod
|
|
108
|
+
def from_dict(cls, data: dict):
|
|
109
|
+
attack_methods = AttackMethods.from_dict(data.pop("attack_methods", {}))
|
|
110
|
+
return cls(**data, attack_methods=attack_methods)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
|
|
114
|
+
class RedTeamTestConfigurations:
|
|
115
|
+
# Basic tests
|
|
116
|
+
bias_test: TestConfig = field(default=None)
|
|
117
|
+
cbrn_test: TestConfig = field(default=None)
|
|
118
|
+
insecure_code_test: TestConfig = field(default=None)
|
|
119
|
+
toxicity_test: TestConfig = field(default=None)
|
|
120
|
+
harmful_test: TestConfig = field(default=None)
|
|
121
|
+
# Advanced tests
|
|
122
|
+
adv_info_test: TestConfig = field(default=None)
|
|
123
|
+
adv_bias_test: TestConfig = field(default=None)
|
|
124
|
+
adv_command_test: TestConfig = field(default=None)
|
|
125
|
+
|
|
126
|
+
def to_dict(self) -> dict:
|
|
127
|
+
return asdict(self)
|
|
128
|
+
|
|
129
|
+
@classmethod
|
|
130
|
+
def from_dict(cls, data: dict):
|
|
131
|
+
return cls(**{k: TestConfig.from_dict(v) for k, v in data.items()})
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@dataclass
|
|
135
|
+
class TargetModelConfiguration:
|
|
136
|
+
testing_for: str = "LLM"
|
|
137
|
+
model_name: str = "gpt-4o-mini"
|
|
138
|
+
model_version: Optional[str] = None
|
|
139
|
+
system_prompt: str = ""
|
|
140
|
+
conversation_template: str = ""
|
|
141
|
+
model_source: str = ""
|
|
142
|
+
model_provider: str = "openai"
|
|
143
|
+
model_endpoint_url: str = "https://api.openai.com/v1/chat/completions"
|
|
144
|
+
model_api_key: Optional[str] = None
|
|
145
|
+
|
|
146
|
+
def to_dict(self) -> dict:
|
|
147
|
+
return asdict(self)
|
|
148
|
+
|
|
149
|
+
@classmethod
|
|
150
|
+
def from_dict(cls, data: dict):
|
|
151
|
+
return cls(**data)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@dataclass
|
|
155
|
+
class RedTeamConfig:
|
|
156
|
+
test_name: str = "Test Name"
|
|
157
|
+
dataset_name: str = "standard"
|
|
158
|
+
model_name: str = "gpt-4o-mini"
|
|
159
|
+
redteam_test_configurations: RedTeamTestConfigurations = field(
|
|
160
|
+
default_factory=RedTeamTestConfigurations
|
|
161
|
+
)
|
|
162
|
+
target_model_configuration: TargetModelConfiguration = field(
|
|
163
|
+
default_factory=TargetModelConfiguration
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
def to_dict(self) -> dict:
|
|
167
|
+
d = asdict(self)
|
|
168
|
+
d["redteam_test_configurations"] = self.redteam_test_configurations.to_dict()
|
|
169
|
+
d["target_model_configuration"] = self.target_model_configuration.to_dict()
|
|
170
|
+
return d
|
|
171
|
+
|
|
172
|
+
def to_json(self) -> str:
|
|
173
|
+
return json.dumps(self.to_dict())
|
|
174
|
+
|
|
175
|
+
@classmethod
|
|
176
|
+
def from_dict(cls, data: dict):
|
|
177
|
+
data = data.copy()
|
|
178
|
+
test_configs = RedTeamTestConfigurations.from_dict(
|
|
179
|
+
data.pop("redteam_test_configurations", {})
|
|
180
|
+
)
|
|
181
|
+
target_config = TargetModelConfiguration.from_dict(
|
|
182
|
+
data.pop("target_model_configuration", {})
|
|
183
|
+
)
|
|
184
|
+
return cls(
|
|
185
|
+
**data,
|
|
186
|
+
redteam_test_configurations=test_configs,
|
|
187
|
+
target_model_configuration=target_config,
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
@classmethod
|
|
191
|
+
def from_json(cls, json_str: str):
|
|
192
|
+
return cls.from_dict(json.loads(json_str))
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# Default configurations
|
|
196
|
+
DEFAULT_REDTEAM_CONFIG = RedTeamConfig()
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
import urllib3
|
|
2
|
+
from .dto import ModelConfig, ModelDetailConfig
|
|
3
|
+
from urllib.parse import urlparse, urlsplit
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ModelClient:
|
|
7
|
+
def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com:443"):
|
|
8
|
+
self.api_key = api_key
|
|
9
|
+
self.base_url = base_url
|
|
10
|
+
self.http = urllib3.PoolManager()
|
|
11
|
+
self.headers = {"apikey": self.api_key}
|
|
12
|
+
|
|
13
|
+
def _request(self, method, endpoint, payload=None, headers=None, **kwargs):
|
|
14
|
+
url = self.base_url + endpoint
|
|
15
|
+
request_headers = {
|
|
16
|
+
"Accept-Encoding": "gzip", # Add required gzip encoding
|
|
17
|
+
**self.headers,
|
|
18
|
+
}
|
|
19
|
+
if headers:
|
|
20
|
+
request_headers.update(headers)
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
response = self.http.request(method, url, headers=request_headers, **kwargs)
|
|
24
|
+
|
|
25
|
+
if response.status >= 400:
|
|
26
|
+
error_response = (
|
|
27
|
+
response.json()
|
|
28
|
+
if response.data
|
|
29
|
+
else {"message": f"HTTP {response.status}"}
|
|
30
|
+
)
|
|
31
|
+
raise urllib3.exceptions.HTTPError(
|
|
32
|
+
f"HTTP {response.status}: {error_response}"
|
|
33
|
+
)
|
|
34
|
+
return response.json()
|
|
35
|
+
except urllib3.exceptions.HTTPError as e:
|
|
36
|
+
return {"error": str(e)}
|
|
37
|
+
|
|
38
|
+
def health(self):
|
|
39
|
+
return self._request("GET", "/models/health")
|
|
40
|
+
|
|
41
|
+
def add_model(self, config: ModelConfig):
|
|
42
|
+
"""
|
|
43
|
+
Add a new model configuration to the system.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
config (ModelConfig): Configuration object containing model details
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
dict: Response from the API containing the added model details
|
|
50
|
+
"""
|
|
51
|
+
headers = {"Content-Type": "application/json"}
|
|
52
|
+
config = ModelConfig.from_dict(config)
|
|
53
|
+
# Parse endpoint_url into components
|
|
54
|
+
parsed_url = urlparse(config.model_config.endpoint_url)
|
|
55
|
+
path_parts = parsed_url.path.strip("/").split("/")
|
|
56
|
+
|
|
57
|
+
# Extract base_path and endpoint path
|
|
58
|
+
if len(path_parts) >= 1:
|
|
59
|
+
base_path = path_parts[0] # Usually 'v1'
|
|
60
|
+
remaining_path = "/".join(path_parts[1:]) # The rest of the path
|
|
61
|
+
else:
|
|
62
|
+
base_path = ""
|
|
63
|
+
remaining_path = ""
|
|
64
|
+
|
|
65
|
+
# Determine paths based on the endpoint
|
|
66
|
+
paths = {
|
|
67
|
+
"completions": f"/{remaining_path}" if remaining_path else "",
|
|
68
|
+
"chat": "",
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
payload = {
|
|
72
|
+
"model_saved_name": config.model_saved_name,
|
|
73
|
+
"testing_for": config.testing_for,
|
|
74
|
+
"model_name": config.model_name,
|
|
75
|
+
"model_type": config.model_type,
|
|
76
|
+
"certifications": config.certifications,
|
|
77
|
+
"model_config": {
|
|
78
|
+
"is_compatible_with": config.model_config.is_compatible_with,
|
|
79
|
+
"model_version": config.model_config.model_version,
|
|
80
|
+
"hosting_type": config.model_config.hosting_type,
|
|
81
|
+
"model_source": config.model_config.model_source,
|
|
82
|
+
"model_provider": config.model_config.model_provider,
|
|
83
|
+
"system_prompt": config.model_config.system_prompt,
|
|
84
|
+
"conversation_template": config.model_config.conversation_template,
|
|
85
|
+
"endpoint": {
|
|
86
|
+
"scheme": parsed_url.scheme,
|
|
87
|
+
"host": parsed_url.hostname,
|
|
88
|
+
"port": parsed_url.port
|
|
89
|
+
or (443 if parsed_url.scheme == "https" else 80),
|
|
90
|
+
"base_path": f"/{base_path}/{paths['completions']}", # Just v1
|
|
91
|
+
},
|
|
92
|
+
"paths": paths,
|
|
93
|
+
"auth_data": {
|
|
94
|
+
"header_name": config.model_config.auth_data.header_name,
|
|
95
|
+
"header_prefix": config.model_config.auth_data.header_prefix,
|
|
96
|
+
"space_after_prefix": config.model_config.auth_data.space_after_prefix,
|
|
97
|
+
},
|
|
98
|
+
"apikeys": (
|
|
99
|
+
[config.model_config.apikey] if config.model_config.apikey else []
|
|
100
|
+
),
|
|
101
|
+
"default_request_options": config.model_config.default_request_options,
|
|
102
|
+
},
|
|
103
|
+
}
|
|
104
|
+
print(payload)
|
|
105
|
+
return self._request("POST", "/models/add-model", headers=headers, json=payload)
|
|
106
|
+
|
|
107
|
+
def get_model(self, model_id: str) -> ModelConfig:
    """
    Fetch the stored configuration of a single model.

    Args:
        model_id (str): Model identifier, sent in the ``X-Enkrypt-Model`` header.

    Returns:
        ModelConfig: The ``/models/get-model`` response parsed with
        ``ModelConfig.from_dict``.
    """
    return ModelConfig.from_dict(
        self._request(
            "GET",
            "/models/get-model",
            headers={"X-Enkrypt-Model": model_id},
        )
    )
|
|
120
|
+
|
|
121
|
+
def get_model_list(self):
    """
    List every model known to the API.

    Returns:
        dict: The ``/models/list-models`` response, or ``{"error": <message>}``
        when the request raised any exception.
    """
    try:
        model_listing = self._request("GET", "/models/list-models")
    except Exception as exc:
        # Best-effort by design: failures are reported in-band, not raised.
        return {"error": str(exc)}
    return model_listing
|
|
132
|
+
|
|
133
|
+
def delete_model(self, model_id: str):
    """
    Remove a model from the system.

    Args:
        model_id (str): Identifier or name of the model, sent in the
            ``X-Enkrypt-Model`` header.

    Returns:
        dict: Whatever ``_request`` returns for ``DELETE /models/delete-model``.
    """
    model_header = {"X-Enkrypt-Model": model_id}
    return self._request("DELETE", "/models/delete-model", headers=model_header)
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
import urllib3
|
|
2
|
+
from .dto import (
|
|
3
|
+
RedTeamConfig,
|
|
4
|
+
RedTeamResponse,
|
|
5
|
+
RedTeamResultSummary,
|
|
6
|
+
RedTeamResultDetails,
|
|
7
|
+
RedTeamTaskStatus,
|
|
8
|
+
RedTeamTaskDetails,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RedTeamError(Exception):
    """Raised when a Red Team operation cannot be completed."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RedTeamClient:
    """
    A client for interacting with the Red Team API.

    Every call goes through ``_request``. It adds the ``apikey`` header and
    reports HTTP or transport failures as an ``{"error": "<message>"}`` dict
    instead of raising.
    """

    def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com"):
        """
        Args:
            api_key (str): Key sent in the ``apikey`` header of every request.
            base_url (str): API root. Any trailing "/" is stripped, because every
                endpoint passed to ``_request`` already starts with "/".
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.http = urllib3.PoolManager()
        self.headers = {"apikey": self.api_key}

    @staticmethod
    def _describe_error(response):
        """Return a printable description of an HTTP error response body."""
        if not response.data:
            return {"message": f"HTTP {response.status}"}
        try:
            return response.json()
        except ValueError:
            # Non-JSON error bodies (e.g. an HTML page from a proxy) used to
            # escape as an uncaught JSONDecodeError; report them as text instead.
            return response.data.decode("utf-8", errors="replace")

    @staticmethod
    def _extract(response, key, action):
        """Return ``response[key]``, or raise RedTeamError carrying the API error."""
        if isinstance(response, dict) and key in response:
            return response[key]
        raise RedTeamError(f"Failed to {action}: {response}")

    def _request(self, method, endpoint, headers=None, **kwargs):
        """
        Send an HTTP request and return the decoded JSON body.

        Args:
            method (str): HTTP method, e.g. "GET" or "POST".
            endpoint (str): Path appended to ``base_url``; should start with "/".
            headers (dict, optional): Extra headers merged over the defaults.
            **kwargs: Passed to ``urllib3.PoolManager.request`` (e.g. ``json=``).

        Returns:
            The parsed JSON response. If the request failed or the status was
            >= 400, returns ``{"error": "<message>"}`` instead.
        """
        url = self.base_url + endpoint
        request_headers = {
            "Accept-Encoding": "gzip",  # Add required gzip encoding
            **self.headers,
        }
        if headers:
            request_headers.update(headers)

        try:
            response = self.http.request(method, url, headers=request_headers, **kwargs)
            if response.status >= 400:
                raise urllib3.exceptions.HTTPError(
                    f"HTTP {response.status}: {self._describe_error(response)}"
                )
            return response.json()
        except urllib3.exceptions.HTTPError as e:
            print(f"Request failed: {e}")
            return {"error": str(e)}

    def get_model(self, model):
        """
        Return ``model`` if it appears in the saved-model listing, else None.

        Args:
            model: Model name to look for in the ``/models/list-models`` response.
        """
        models = self._request("GET", "/models/list-models")
        if model in models:
            return model
        return None

    def add_task(
        self,
        config: RedTeamConfig,
    ):
        """
        Add a new red teaming task.

        Args:
            config (RedTeamConfig | dict): Task configuration. A
                ``RedTeamConfig`` instance is used as-is; anything else is
                converted with ``RedTeamConfig.from_dict``.

        Returns:
            If ``config.model_name`` matches a saved model, the raw API response.
            Otherwise, a ``RedTeamResponse`` built from the response.

        Raises:
            RedTeamError: If there is neither a saved model nor a
                ``target_model_configuration``.
        """
        # The signature advertises a RedTeamConfig, so only convert other inputs;
        # calling from_dict on an existing instance does not work.
        if not isinstance(config, RedTeamConfig):
            config = RedTeamConfig.from_dict(config)
        test_configs = config.redteam_test_configurations.to_dict()
        # Remove None or empty test configurations
        test_configs = {k: v for k, v in test_configs.items() if v is not None}

        payload = {
            # "async": config.async_enabled,
            "dataset_name": config.dataset_name,
            "test_name": config.test_name,
            "redteam_test_configurations": test_configs,
        }

        saved_model = self.get_model(config.model_name)

        if saved_model:
            headers = {
                "X-Enkrypt-Model": saved_model,
                "Content-Type": "application/json",
            }
            payload["location"] = {"storage": "supabase", "container_name": "supabase"}
            return self._request(
                "POST",
                "/redteam/v2/model/add-task",
                headers=headers,
                json=payload,
            )
        if config.target_model_configuration:
            payload["target_model_configuration"] = (
                config.target_model_configuration.to_dict()
            )
            response = self._request(
                "POST",
                "/redteam/v2/add-task",
                json=payload,
            )
            return RedTeamResponse.from_dict(response)
        raise RedTeamError(
            "Please use a saved model or provide a target model configuration"
        )

    def status(self, task_id: str):
        """
        Get the status of a specific red teaming task.

        Args:
            task_id (str): The ID of the task to check.

        Returns:
            RedTeamTaskStatus: The parsed ``/redteam/task-status`` response.
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/task-status", headers=headers)
        return RedTeamTaskStatus.from_dict(response)

    def cancel_task(self, task_id: str):
        """
        Cancel a specific red teaming task (not yet supported).

        Args:
            task_id (str): The ID of the task to cancel.

        Raises:
            RedTeamError: Always; the feature is under development.
        """
        raise RedTeamError(
            "This feature is currently under development. Please check our documentation "
            "at https://docs.enkrypt.ai for updates or contact support@enkrypt.ai for assistance."
        )

    def get_task(self, task_id: str):
        """
        Get the status and details of a specific red teaming task.

        Args:
            task_id (str): The ID of the task to retrieve.

        Returns:
            RedTeamTaskDetails: Built from the response's ``data`` object.

        Raises:
            RedTeamError: If the response has no ``data`` (e.g. the request failed).
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/get-task", headers=headers)
        data = self._extract(response, "data", f"get task {task_id}")
        # The API returns the identifier as "job_id"; the DTO expects "task_id".
        if "job_id" in data:
            data["task_id"] = data.pop("job_id")
        return RedTeamTaskDetails.from_dict(data)

    def get_result_summary(self, task_id: str):
        """
        Get the summary of results for a specific red teaming task.

        Args:
            task_id (str): The ID of the task to get results for.

        Returns:
            RedTeamResultSummary: Built from the response's ``summary`` object.

        Raises:
            RedTeamError: If the response has no ``summary`` (e.g. the request failed).
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/results/summary", headers=headers)
        summary = self._extract(response, "summary", f"get result summary for {task_id}")
        return RedTeamResultSummary.from_dict(summary)

    def get_result_details(self, task_id: str):
        """
        Get the detailed results for a specific red teaming task.

        Args:
            task_id (str): The ID of the task to get detailed results for.

        Returns:
            RedTeamResultDetails: Built from the response's ``details`` object.

        Raises:
            RedTeamError: If the response has no ``details`` (e.g. the request failed).
        """
        # TODO: Update the response to be updated
        headers = {"X-Enkrypt-Task-ID": task_id}
        response = self._request("GET", "/redteam/results/details", headers=headers)
        details = self._extract(response, "details", f"get result details for {task_id}")
        return RedTeamResultDetails.from_dict(details)
|
|
@@ -12,6 +12,9 @@ src/enkryptai_sdk.egg-info/PKG-INFO
|
|
|
12
12
|
src/enkryptai_sdk.egg-info/SOURCES.txt
|
|
13
13
|
src/enkryptai_sdk.egg-info/dependency_links.txt
|
|
14
14
|
src/enkryptai_sdk.egg-info/top_level.txt
|
|
15
|
+
src/enkryptai_sdk/dto/__init__.py
|
|
16
|
+
src/enkryptai_sdk/dto/models.py
|
|
17
|
+
src/enkryptai_sdk/dto/red_team.py
|
|
15
18
|
tests/test_all.py
|
|
16
19
|
tests/test_basic.py
|
|
17
20
|
tests/test_detect_policy.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|