enkryptai-sdk 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enkryptai_sdk/__init__.py +16 -4
- enkryptai_sdk/ai_proxy.py +70 -0
- enkryptai_sdk/base.py +36 -0
- enkryptai_sdk/datasets.py +142 -0
- enkryptai_sdk/deployments.py +121 -0
- enkryptai_sdk/dto/__init__.py +64 -0
- enkryptai_sdk/dto/ai_proxy.py +325 -0
- enkryptai_sdk/dto/base.py +70 -0
- enkryptai_sdk/dto/datasets.py +152 -0
- enkryptai_sdk/dto/deployments.py +334 -0
- enkryptai_sdk/dto/guardrails.py +1261 -0
- enkryptai_sdk/dto/models.py +211 -45
- enkryptai_sdk/dto/red_team.py +279 -62
- enkryptai_sdk/guardrails.py +219 -70
- enkryptai_sdk/guardrails_old.py +195 -0
- enkryptai_sdk/models.py +136 -54
- enkryptai_sdk/red_team.py +167 -63
- enkryptai_sdk-0.1.7.dist-info/METADATA +1205 -0
- enkryptai_sdk-0.1.7.dist-info/RECORD +25 -0
- {enkryptai_sdk-0.1.5.dist-info → enkryptai_sdk-0.1.7.dist-info}/WHEEL +1 -1
- enkryptai_sdk-0.1.5.dist-info/METADATA +0 -301
- enkryptai_sdk-0.1.5.dist-info/RECORD +0 -15
- {enkryptai_sdk-0.1.5.dist-info → enkryptai_sdk-0.1.7.dist-info/licenses}/LICENSE +0 -0
- {enkryptai_sdk-0.1.5.dist-info → enkryptai_sdk-0.1.7.dist-info}/top_level.txt +0 -0
enkryptai_sdk/models.py
CHANGED
|
@@ -1,44 +1,21 @@
|
|
|
1
|
-
import
|
|
2
|
-
from .dto import ModelConfig, ModelDetailConfig
|
|
1
|
+
from .base import BaseClient
|
|
3
2
|
from urllib.parse import urlparse, urlsplit
|
|
3
|
+
from .dto import ModelConfig, ModelResponse, ModelCollection
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
class
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def _request(self, method, endpoint, payload=None, headers=None, **kwargs):
|
|
14
|
-
url = self.base_url + endpoint
|
|
15
|
-
request_headers = {
|
|
16
|
-
"Accept-Encoding": "gzip", # Add required gzip encoding
|
|
17
|
-
**self.headers,
|
|
18
|
-
}
|
|
19
|
-
if headers:
|
|
20
|
-
request_headers.update(headers)
|
|
6
|
+
class ModelClientError(Exception):
    """Error raised by the model-management client when an API call fails."""
|
|
21
12
|
|
|
22
|
-
try:
|
|
23
|
-
response = self.http.request(method, url, headers=request_headers, **kwargs)
|
|
24
|
-
|
|
25
|
-
if response.status >= 400:
|
|
26
|
-
error_response = (
|
|
27
|
-
response.json()
|
|
28
|
-
if response.data
|
|
29
|
-
else {"message": f"HTTP {response.status}"}
|
|
30
|
-
)
|
|
31
|
-
raise urllib3.exceptions.HTTPError(
|
|
32
|
-
f"HTTP {response.status}: {error_response}"
|
|
33
|
-
)
|
|
34
|
-
return response.json()
|
|
35
|
-
except urllib3.exceptions.HTTPError as e:
|
|
36
|
-
return {"error": str(e)}
|
|
37
13
|
|
|
38
|
-
|
|
39
|
-
|
|
14
|
+
class ModelClient(BaseClient):
|
|
15
|
+
def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com:443"):
    """
    Create a client for the model-management endpoints.

    Args:
        api_key (str): API key handed to the BaseClient initializer.
        base_url (str): Root URL of the Enkrypt AI API.
    """
    super(ModelClient, self).__init__(api_key, base_url)
|
|
40
17
|
|
|
41
|
-
def add_model(self, config: ModelConfig):
|
|
18
|
+
def add_model(self, config: ModelConfig) -> ModelResponse:
|
|
42
19
|
"""
|
|
43
20
|
Add a new model configuration to the system.
|
|
44
21
|
|
|
@@ -62,11 +39,15 @@ class ModelClient:
|
|
|
62
39
|
base_path = ""
|
|
63
40
|
remaining_path = ""
|
|
64
41
|
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
42
|
+
if config.model_config.paths:
|
|
43
|
+
paths = config.model_config.paths.to_dict()
|
|
44
|
+
else:
|
|
45
|
+
paths = {
|
|
46
|
+
"completions": (
|
|
47
|
+
f"/{remaining_path.split('/')[-1]}" if remaining_path else ""
|
|
48
|
+
),
|
|
49
|
+
"chat": f"/{remaining_path}" if remaining_path else "",
|
|
50
|
+
}
|
|
70
51
|
|
|
71
52
|
payload = {
|
|
72
53
|
"model_saved_name": config.model_saved_name,
|
|
@@ -75,11 +56,10 @@ class ModelClient:
|
|
|
75
56
|
"model_type": config.model_type,
|
|
76
57
|
"certifications": config.certifications,
|
|
77
58
|
"model_config": {
|
|
78
|
-
"
|
|
59
|
+
"model_provider": config.model_config.model_provider,
|
|
79
60
|
"model_version": config.model_config.model_version,
|
|
80
61
|
"hosting_type": config.model_config.hosting_type,
|
|
81
62
|
"model_source": config.model_config.model_source,
|
|
82
|
-
"model_provider": config.model_config.model_provider,
|
|
83
63
|
"system_prompt": config.model_config.system_prompt,
|
|
84
64
|
"conversation_template": config.model_config.conversation_template,
|
|
85
65
|
"endpoint": {
|
|
@@ -87,7 +67,7 @@ class ModelClient:
|
|
|
87
67
|
"host": parsed_url.hostname,
|
|
88
68
|
"port": parsed_url.port
|
|
89
69
|
or (443 if parsed_url.scheme == "https" else 80),
|
|
90
|
-
"base_path": f"/{base_path}
|
|
70
|
+
"base_path": f"/{base_path}", # Just v1
|
|
91
71
|
},
|
|
92
72
|
"paths": paths,
|
|
93
73
|
"auth_data": {
|
|
@@ -101,21 +81,31 @@ class ModelClient:
|
|
|
101
81
|
"default_request_options": config.model_config.default_request_options,
|
|
102
82
|
},
|
|
103
83
|
}
|
|
104
|
-
|
|
105
|
-
|
|
84
|
+
try:
|
|
85
|
+
response = self._request(
|
|
86
|
+
"POST", "/models/add-model", headers=headers, json=payload
|
|
87
|
+
)
|
|
88
|
+
if response.get("error"):
|
|
89
|
+
raise ModelClientError(response["error"])
|
|
90
|
+
return ModelResponse.from_dict(response)
|
|
91
|
+
except Exception as e:
|
|
92
|
+
raise ModelClientError(str(e))
|
|
106
93
|
|
|
107
|
-
def get_model(self,
|
|
94
|
+
def get_model(self, model_saved_name: str) -> ModelConfig:
    """
    Get model configuration by model saved name.

    Args:
        model_saved_name (str): Saved name of the model to retrieve

    Returns:
        ModelConfig: Configuration object containing model details

    Raises:
        ModelClientError: If the API response carries a truthy ``error`` entry.
    """
    headers = {"X-Enkrypt-Model": model_saved_name}
    response = self._request("GET", "/models/get-model", headers=headers)
    # No printing here: a library call must not dump the raw API payload
    # (which may include configuration details) to stdout.
    if response.get("error"):
        raise ModelClientError(response["error"])
    return ModelConfig.from_dict(response)
|
|
120
110
|
|
|
121
111
|
def get_model_list(self):
|
|
@@ -126,19 +116,111 @@ class ModelClient:
|
|
|
126
116
|
dict: Response from the API containing the list of models
|
|
127
117
|
"""
|
|
128
118
|
try:
|
|
129
|
-
|
|
119
|
+
response = self._request("GET", "/models/list-models")
|
|
120
|
+
if response.get("error"):
|
|
121
|
+
raise ModelClientError(response["error"])
|
|
122
|
+
return ModelCollection.from_dict(response)
|
|
130
123
|
except Exception as e:
|
|
131
124
|
return {"error": str(e)}
|
|
132
125
|
|
|
133
|
-
def
|
|
126
|
+
def modify_model(self, config: ModelConfig, old_model_saved_name=None) -> ModelResponse:
    """
    Modify an existing model in the system.

    Args:
        config (ModelConfig | dict): New configuration for the model. A
            ``ModelConfig`` instance is used as-is; any other value (e.g. a
            plain dict) is parsed with ``ModelConfig.from_dict``.
        old_model_saved_name (str, optional): Saved name of the model to
            modify, sent as the ``X-Enkrypt-Model`` header. Defaults to the
            ``model_saved_name`` of ``config``.

    Returns:
        ModelResponse: Parsed API response for the modified model.

    Raises:
        ModelClientError: If the API returns an error or the request fails.
    """
    # The annotation promises a ModelConfig, but dicts are accepted too.
    # Subscripting a ModelConfig (config["..."]) would fail, so normalise first
    # and read every field through attribute access.
    if not isinstance(config, ModelConfig):
        config = ModelConfig.from_dict(config)
    if old_model_saved_name is None:
        old_model_saved_name = config.model_saved_name

    headers = {"Content-Type": "application/json", "X-Enkrypt-Model": old_model_saved_name}

    # Split the endpoint path into the base path (first segment, usually "v1")
    # and the remaining endpoint path. partition() yields "" for absent pieces.
    parsed_url = urlparse(config.model_config.endpoint_url)
    base_path, _, remaining_path = parsed_url.path.strip("/").partition("/")

    if config.model_config.paths:
        paths = config.model_config.paths.to_dict()
    else:
        # No explicit paths configured: derive them from the endpoint URL.
        paths = {
            "completions": (
                f"/{remaining_path.split('/')[-1]}" if remaining_path else ""
            ),
            "chat": f"/{remaining_path}" if remaining_path else "",
        }

    payload = {
        "model_saved_name": config.model_saved_name,
        "testing_for": config.testing_for,
        "model_name": config.model_name,
        "model_type": config.model_type,
        "certifications": config.certifications,
        "model_config": {
            "model_provider": config.model_config.model_provider,
            "model_version": config.model_config.model_version,
            "hosting_type": config.model_config.hosting_type,
            "model_source": config.model_config.model_source,
            "system_prompt": config.model_config.system_prompt,
            "conversation_template": config.model_config.conversation_template,
            "endpoint": {
                "scheme": parsed_url.scheme,
                "host": parsed_url.hostname,
                "port": parsed_url.port
                or (443 if parsed_url.scheme == "https" else 80),
                "base_path": f"/{base_path}",  # Just v1
            },
            "paths": paths,
            "auth_data": {
                "header_name": config.model_config.auth_data.header_name,
                "header_prefix": config.model_config.auth_data.header_prefix,
                "space_after_prefix": config.model_config.auth_data.space_after_prefix,
            },
            "apikeys": (
                [config.model_config.apikey] if config.model_config.apikey else []
            ),
            "default_request_options": config.model_config.default_request_options,
        },
    }
    try:
        response = self._request(
            "PATCH", "/models/modify-model", headers=headers, json=payload
        )
        if response.get("error"):
            raise ModelClientError(response["error"])
        return ModelResponse.from_dict(response)
    except ModelClientError:
        # Already the right type; don't wrap it a second time.
        raise
    except Exception as e:
        raise ModelClientError(str(e)) from e
|
|
207
|
+
|
|
208
|
+
def delete_model(self, model_saved_name: str) -> ModelResponse:
    """
    Delete a specific model from the system.

    Args:
        model_saved_name (str): The saved name of the model to delete,
            sent as the ``X-Enkrypt-Model`` header.

    Returns:
        ModelResponse: Parsed API response describing the deletion.

    Raises:
        ModelClientError: If the API reports an error or the request fails.
    """
    try:
        result = self._request(
            "DELETE",
            "/models/delete-model",
            headers={"X-Enkrypt-Model": model_saved_name},
        )
        api_error = result.get("error")
        if api_error:
            raise ModelClientError(api_error)
        return ModelResponse.from_dict(result)
    except Exception as exc:
        raise ModelClientError(str(exc))
|
enkryptai_sdk/red_team.py
CHANGED
|
@@ -1,15 +1,21 @@
|
|
|
1
1
|
import urllib3
|
|
2
|
+
from .base import BaseClient
|
|
2
3
|
from .dto import (
|
|
4
|
+
RedteamHealthResponse,
|
|
5
|
+
RedTeamModelHealthConfig,
|
|
6
|
+
RedteamModelHealthResponse,
|
|
3
7
|
RedTeamConfig,
|
|
8
|
+
# RedTeamConfigWithSavedModel,
|
|
4
9
|
RedTeamResponse,
|
|
5
10
|
RedTeamResultSummary,
|
|
6
11
|
RedTeamResultDetails,
|
|
7
12
|
RedTeamTaskStatus,
|
|
8
13
|
RedTeamTaskDetails,
|
|
14
|
+
RedTeamTaskList,
|
|
9
15
|
)
|
|
10
16
|
|
|
11
17
|
|
|
12
|
-
class
|
|
18
|
+
class RedTeamClientError(Exception):
|
|
13
19
|
"""
|
|
14
20
|
A custom exception for Red Team errors.
|
|
15
21
|
"""
|
|
@@ -17,48 +23,64 @@ class RedTeamError(Exception):
|
|
|
17
23
|
pass
|
|
18
24
|
|
|
19
25
|
|
|
20
|
-
class RedTeamClient:
|
|
26
|
+
class RedTeamClient(BaseClient):
|
|
21
27
|
"""
|
|
22
28
|
A client for interacting with the Red Team API.
|
|
23
29
|
"""
|
|
24
30
|
|
|
25
31
|
def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com"):
    """
    Create a Red Team API client.

    Args:
        api_key (str): API key handed to the BaseClient initializer.
        base_url (str): Root URL of the Enkrypt AI API.
    """
    super(RedTeamClient, self).__init__(api_key, base_url)
|
|
33
|
+
|
|
34
|
+
# def get_model(self, model):
|
|
35
|
+
# models = self._request("GET", "/models/list-models")
|
|
36
|
+
# models = models["models"]
|
|
37
|
+
# for _model_data in models:
|
|
38
|
+
# if _model_data["model_saved_name"] == model:
|
|
39
|
+
# return _model_data["model_saved_name"]
|
|
40
|
+
# else:
|
|
41
|
+
# return None
|
|
39
42
|
|
|
43
|
+
def get_health(self):
    """
    Query the Red Team service health endpoint (GET /redteam/health).

    Returns:
        RedteamHealthResponse: Parsed health payload.

    Raises:
        RedTeamClientError: If the API reports an error or the call fails.
    """
    try:
        payload = self._request("GET", "/redteam/health")
        problem = payload.get("error")
        if problem:
            raise RedTeamClientError(problem)
        return RedteamHealthResponse.from_dict(payload)
    except Exception as exc:
        raise RedTeamClientError(str(exc))
|
|
54
|
+
|
|
55
|
+
def check_model_health(self, config: RedTeamModelHealthConfig):
    """
    Check the health of a model described by an inline configuration.

    Args:
        config (RedTeamModelHealthConfig | dict): Model configuration. An
            instance is used as-is; any other value (e.g. a plain dict) is
            parsed with ``RedTeamModelHealthConfig.from_dict``.

    Returns:
        RedteamModelHealthResponse: Parsed health-check result.

    Raises:
        RedTeamClientError: If the API returns a non-empty ``error`` or the
            call fails.
    """
    try:
        # Honour the annotation: only parse when we were not already given
        # a RedTeamModelHealthConfig instance.
        if not isinstance(config, RedTeamModelHealthConfig):
            config = RedTeamModelHealthConfig.from_dict(config)
        response = self._request("POST", "/redteam/model-health", json=config.to_dict())
        # A missing key and an empty string both mean "no error".
        if response.get("error") not in [None, ""]:
            raise RedTeamClientError(response["error"])
        return RedteamModelHealthResponse.from_dict(response)
    except RedTeamClientError:
        # Already the right type; don't wrap it a second time.
        raise
    except Exception as e:
        raise RedTeamClientError(str(e)) from e
|
|
68
|
+
|
|
69
|
+
def check_saved_model_health(self, model_saved_name: str):
    """
    Check the health of a model previously saved on the platform.

    Args:
        model_saved_name (str): Saved name of the model, sent as the
            ``X-Enkrypt-Model`` header.

    Returns:
        RedteamModelHealthResponse: Parsed health-check result.

    Raises:
        RedTeamClientError: If the API returns a non-empty ``error`` or the
            call fails.
    """
    try:
        result = self._request(
            "POST",
            "/redteam/model/model-health",
            headers={"X-Enkrypt-Model": model_saved_name},
        )
        err = result.get("error")
        # Both None and "" count as "no error".
        if err not in (None, ""):
            raise RedTeamClientError(err)
        return RedteamModelHealthResponse.from_dict(result)
    except Exception as exc:
        raise RedTeamClientError(str(exc))
|
|
62
84
|
|
|
63
85
|
def add_task(
|
|
64
86
|
self,
|
|
@@ -76,25 +98,26 @@ class RedTeamClient:
|
|
|
76
98
|
# "async": config.async_enabled,
|
|
77
99
|
"dataset_name": config.dataset_name,
|
|
78
100
|
"test_name": config.test_name,
|
|
79
|
-
"redteam_test_configurations":
|
|
101
|
+
"redteam_test_configurations": {
|
|
102
|
+
k: v.to_dict() for k, v in test_configs.items()
|
|
103
|
+
},
|
|
80
104
|
}
|
|
81
105
|
|
|
82
|
-
|
|
83
|
-
saved_model = self.get_model(model)
|
|
84
|
-
|
|
106
|
+
saved_model = config.model_saved_name
|
|
85
107
|
if saved_model:
|
|
86
|
-
print("saved model found")
|
|
87
108
|
headers = {
|
|
88
109
|
"X-Enkrypt-Model": saved_model,
|
|
89
110
|
"Content-Type": "application/json",
|
|
90
111
|
}
|
|
91
|
-
|
|
92
|
-
return self._request(
|
|
112
|
+
response = self._request(
|
|
93
113
|
"POST",
|
|
94
114
|
"/redteam/v2/model/add-task",
|
|
95
115
|
headers=headers,
|
|
96
116
|
json=payload,
|
|
97
117
|
)
|
|
118
|
+
if response.get("error"):
|
|
119
|
+
raise RedTeamClientError(response["error"])
|
|
120
|
+
return RedTeamResponse.from_dict(response)
|
|
98
121
|
elif config.target_model_configuration:
|
|
99
122
|
payload["target_model_configuration"] = (
|
|
100
123
|
config.target_model_configuration.to_dict()
|
|
@@ -105,81 +128,162 @@ class RedTeamClient:
|
|
|
105
128
|
"/redteam/v2/add-task",
|
|
106
129
|
json=payload,
|
|
107
130
|
)
|
|
131
|
+
if response.get("error"):
|
|
132
|
+
raise RedTeamClientError(response["error"])
|
|
108
133
|
return RedTeamResponse.from_dict(response)
|
|
109
134
|
else:
|
|
110
|
-
raise
|
|
135
|
+
raise RedTeamClientError(
|
|
111
136
|
"Please use a saved model or provide a target model configuration"
|
|
112
137
|
)
|
|
113
138
|
|
|
114
|
-
def status(self, task_id: str):
|
|
139
|
+
def status(self, task_id: str = None, test_name: str = None):
    """
    Fetch the current status of a red teaming task.

    Args:
        task_id (str, optional): ID of the task, sent as ``X-Enkrypt-Task-ID``.
        test_name (str, optional): Name of the test, sent as
            ``X-Enkrypt-Test-Name``.

    Returns:
        RedTeamTaskStatus: Parsed task-status payload.

    Raises:
        RedTeamClientError: If neither identifier is given, or if the API
            returns an error.
    """
    if not (task_id or test_name):
        raise RedTeamClientError("Either task_id or test_name must be provided")

    # Only send the identifiers that were actually supplied.
    identity_headers = {
        name: value
        for name, value in (
            ("X-Enkrypt-Task-ID", task_id),
            ("X-Enkrypt-Test-Name", test_name),
        )
        if value
    }

    reply = self._request("GET", "/redteam/task-status", headers=identity_headers)
    if reply.get("error"):
        raise RedTeamClientError(reply["error"])
    return RedTeamTaskStatus.from_dict(reply)
|
|
128
166
|
|
|
129
|
-
def cancel_task(self, task_id: str):
|
|
167
|
+
def cancel_task(self, task_id: str = None, test_name: str = None):
    """
    Cancel a specific red teaming task.

    Cancellation is not available yet: this method always raises
    ``RedTeamClientError`` pointing the caller to the documentation. The
    parameters are accepted so the signature stays stable once the feature
    ships.

    Args:
        task_id (str, optional): The ID of the task to cancel
        test_name (str, optional): The name of the test to cancel

    Raises:
        RedTeamClientError: Always, until the feature is released.
    """
    # TODO: when the endpoint is released, require task_id or test_name, send
    # them as X-Enkrypt-Task-ID / X-Enkrypt-Test-Name headers in a POST to
    # /redteam/cancel-task, and raise RedTeamClientError on an "error" entry.
    raise RedTeamClientError(
        "This feature is currently under development. Please check our documentation "
        "at https://docs.enkrypt.ai for updates or contact support@enkrypt.ai for assistance."
    )
|
|
199
|
+
|
|
200
|
+
def get_task(self, task_id: str = None, test_name: str = None):
    """
    Retrieve the status and details of a specific red teaming task.

    Args:
        task_id (str, optional): ID of the task, sent as ``X-Enkrypt-Task-ID``.
        test_name (str, optional): Name of the test, sent as
            ``X-Enkrypt-Test-Name``.

    Returns:
        RedTeamTaskDetails: Parsed from the ``data`` entry of the response.

    Raises:
        RedTeamClientError: If neither identifier is given, or if the API
            returns an error.
    """
    if not (task_id or test_name):
        raise RedTeamClientError("Either task_id or test_name must be provided")

    headers = {}
    for header_name, header_value in (
        ("X-Enkrypt-Task-ID", task_id),
        ("X-Enkrypt-Test-Name", test_name),
    ):
        if header_value:
            headers[header_name] = header_value

    reply = self._request("GET", "/redteam/get-task", headers=headers)
    if reply.get("error"):
        raise RedTeamClientError(reply["error"])
    return RedTeamTaskDetails.from_dict(reply["data"])
|
|
156
227
|
|
|
157
|
-
def get_result_summary(self, task_id: str):
|
|
228
|
+
def get_result_summary(self, task_id: str = None, test_name: str = None):
    """
    Get the summary of results for a specific red teaming task.

    Args:
        task_id (str, optional): The ID of the task to get results for
        test_name (str, optional): The name of the test to get results for

    Returns:
        RedTeamResultSummary: Parsed summary of the task results

    Raises:
        RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
    """
    if not task_id and not test_name:
        raise RedTeamClientError("Either task_id or test_name must be provided")

    headers = {}
    if task_id:
        headers["X-Enkrypt-Task-ID"] = task_id
    if test_name:
        headers["X-Enkrypt-Test-Name"] = test_name

    response = self._request("GET", "/redteam/results/summary", headers=headers)
    if response.get("error"):
        raise RedTeamClientError(response["error"])
    # No printing here: a library method must not dump API payloads to stdout.
    return RedTeamResultSummary.from_dict(response)
|
|
171
256
|
|
|
172
|
-
def get_result_details(self, task_id: str):
|
|
257
|
+
def get_result_details(self, task_id: str = None, test_name: str = None):
    """
    Fetch the detailed results for a specific red teaming task.

    Args:
        task_id (str, optional): ID of the task, sent as ``X-Enkrypt-Task-ID``.
        test_name (str, optional): Name of the test, sent as
            ``X-Enkrypt-Test-Name``.

    Returns:
        RedTeamResultDetails: Parsed detailed task results.

    Raises:
        RedTeamClientError: If neither identifier is given, or if the API
            returns an error.
    """
    if not (task_id or test_name):
        raise RedTeamClientError("Either task_id or test_name must be provided")

    lookup = {}
    if task_id:
        lookup["X-Enkrypt-Task-ID"] = task_id
    if test_name:
        lookup["X-Enkrypt-Test-Name"] = test_name

    reply = self._request("GET", "/redteam/results/details", headers=lookup)
    failure = reply.get("error")
    if failure:
        raise RedTeamClientError(failure)
    return RedTeamResultDetails.from_dict(reply)
|
|
284
|
+
|
|
285
|
+
def get_task_list(self):
    """
    List red teaming tasks via GET /redteam/list-tasks.

    Returns:
        RedTeamTaskList: Parsed list of tasks.

    Raises:
        RedTeamClientError: If the API response carries a truthy ``error``.
    """
    listing = self._request("GET", "/redteam/list-tasks")
    failure = listing.get("error")
    if failure:
        raise RedTeamClientError(failure)
    return RedTeamTaskList.from_dict(listing)
|