enkryptai-sdk 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
enkryptai_sdk/models.py CHANGED
@@ -1,44 +1,21 @@
1
- import urllib3
2
- from .dto import ModelConfig, ModelDetailConfig
1
+ from .base import BaseClient
3
2
  from urllib.parse import urlparse, urlsplit
3
+ from .dto import ModelConfig, ModelResponse, ModelCollection
4
4
 
5
5
 
6
- class ModelClient:
7
- def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com:443"):
8
- self.api_key = api_key
9
- self.base_url = base_url
10
- self.http = urllib3.PoolManager()
11
- self.headers = {"apikey": self.api_key}
12
-
13
- def _request(self, method, endpoint, payload=None, headers=None, **kwargs):
14
- url = self.base_url + endpoint
15
- request_headers = {
16
- "Accept-Encoding": "gzip", # Add required gzip encoding
17
- **self.headers,
18
- }
19
- if headers:
20
- request_headers.update(headers)
6
+ class ModelClientError(Exception):
7
+ """
8
+ A custom exception for ModelClient errors.
9
+ """
10
+
11
+ pass
21
12
 
22
- try:
23
- response = self.http.request(method, url, headers=request_headers, **kwargs)
24
-
25
- if response.status >= 400:
26
- error_response = (
27
- response.json()
28
- if response.data
29
- else {"message": f"HTTP {response.status}"}
30
- )
31
- raise urllib3.exceptions.HTTPError(
32
- f"HTTP {response.status}: {error_response}"
33
- )
34
- return response.json()
35
- except urllib3.exceptions.HTTPError as e:
36
- return {"error": str(e)}
37
13
 
38
- def health(self):
39
- return self._request("GET", "/models/health")
14
+ class ModelClient(BaseClient):
15
+ def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com:443"):
16
+ super().__init__(api_key, base_url)
40
17
 
41
- def add_model(self, config: ModelConfig):
18
+ def add_model(self, config: ModelConfig) -> ModelResponse:
42
19
  """
43
20
  Add a new model configuration to the system.
44
21
 
@@ -62,11 +39,15 @@ class ModelClient:
62
39
  base_path = ""
63
40
  remaining_path = ""
64
41
 
65
- # Determine paths based on the endpoint
66
- paths = {
67
- "completions": f"/{remaining_path}" if remaining_path else "",
68
- "chat": "",
69
- }
42
+ if config.model_config.paths:
43
+ paths = config.model_config.paths.to_dict()
44
+ else:
45
+ paths = {
46
+ "completions": (
47
+ f"/{remaining_path.split('/')[-1]}" if remaining_path else ""
48
+ ),
49
+ "chat": f"/{remaining_path}" if remaining_path else "",
50
+ }
70
51
 
71
52
  payload = {
72
53
  "model_saved_name": config.model_saved_name,
@@ -75,11 +56,10 @@ class ModelClient:
75
56
  "model_type": config.model_type,
76
57
  "certifications": config.certifications,
77
58
  "model_config": {
78
- "is_compatible_with": config.model_config.is_compatible_with,
59
+ "model_provider": config.model_config.model_provider,
79
60
  "model_version": config.model_config.model_version,
80
61
  "hosting_type": config.model_config.hosting_type,
81
62
  "model_source": config.model_config.model_source,
82
- "model_provider": config.model_config.model_provider,
83
63
  "system_prompt": config.model_config.system_prompt,
84
64
  "conversation_template": config.model_config.conversation_template,
85
65
  "endpoint": {
@@ -87,7 +67,7 @@ class ModelClient:
87
67
  "host": parsed_url.hostname,
88
68
  "port": parsed_url.port
89
69
  or (443 if parsed_url.scheme == "https" else 80),
90
- "base_path": f"/{base_path}/{paths['completions']}", # Just v1
70
+ "base_path": f"/{base_path}", # Just v1
91
71
  },
92
72
  "paths": paths,
93
73
  "auth_data": {
@@ -101,21 +81,31 @@ class ModelClient:
101
81
  "default_request_options": config.model_config.default_request_options,
102
82
  },
103
83
  }
104
- print(payload)
105
- return self._request("POST", "/models/add-model", headers=headers, json=payload)
84
+ try:
85
+ response = self._request(
86
+ "POST", "/models/add-model", headers=headers, json=payload
87
+ )
88
+ if response.get("error"):
89
+ raise ModelClientError(response["error"])
90
+ return ModelResponse.from_dict(response)
91
+ except Exception as e:
92
+ raise ModelClientError(str(e))
106
93
 
107
- def get_model(self, model_id: str) -> ModelConfig:
94
+ def get_model(self, model_saved_name: str) -> ModelConfig:
108
95
  """
109
- Get model configuration by model ID.
96
+ Get model configuration by model saved name.
110
97
 
111
98
  Args:
112
- model_id (str): ID of the model to retrieve
99
+ model_saved_name (str): Saved name of the model to retrieve
113
100
 
114
101
  Returns:
115
102
  ModelConfig: Configuration object containing model details
116
103
  """
117
- headers = {"X-Enkrypt-Model": model_id}
104
+ headers = {"X-Enkrypt-Model": model_saved_name}
118
105
  response = self._request("GET", "/models/get-model", headers=headers)
106
+ print(response)
107
+ if response.get("error"):
108
+ raise ModelClientError(response["error"])
119
109
  return ModelConfig.from_dict(response)
120
110
 
121
111
  def get_model_list(self):
@@ -126,19 +116,111 @@ class ModelClient:
126
116
  dict: Response from the API containing the list of models
127
117
  """
128
118
  try:
129
- return self._request("GET", "/models/list-models")
119
+ response = self._request("GET", "/models/list-models")
120
+ if response.get("error"):
121
+ raise ModelClientError(response["error"])
122
+ return ModelCollection.from_dict(response)
130
123
  except Exception as e:
131
124
  return {"error": str(e)}
132
125
 
133
- def delete_model(self, model_id: str):
126
+ def modify_model(self, config: ModelConfig, old_model_saved_name=None) -> ModelResponse:
127
+ """
128
+ Modify an existing model in the system.
129
+
130
+ Args:
131
+ model_saved_name (str): The saved name of the model to modify
132
+ config (ModelConfig): Configuration object containing model details
133
+
134
+ Returns:
135
+ dict: Response from the API containing the modified model details
136
+ """
137
+ if old_model_saved_name is None:
138
+ old_model_saved_name = config["model_saved_name"]
139
+
140
+ headers = {"Content-Type": "application/json", "X-Enkrypt-Model": old_model_saved_name}
141
+ # print(config)
142
+ config = ModelConfig.from_dict(config)
143
+ # Parse endpoint_url into components
144
+ parsed_url = urlparse(config.model_config.endpoint_url)
145
+ path_parts = parsed_url.path.strip("/").split("/")
146
+
147
+ # Extract base_path and endpoint path
148
+ if len(path_parts) >= 1:
149
+ base_path = path_parts[0] # Usually 'v1'
150
+ remaining_path = "/".join(path_parts[1:]) # The rest of the path
151
+ else:
152
+ base_path = ""
153
+ remaining_path = ""
154
+
155
+ if config.model_config.paths:
156
+ paths = config.model_config.paths.to_dict()
157
+ else:
158
+ # Determine paths based on the endpoint
159
+ paths = {
160
+ "completions": (
161
+ f"/{remaining_path.split('/')[-1]}" if remaining_path else ""
162
+ ),
163
+ "chat": f"/{remaining_path}" if remaining_path else "",
164
+ }
165
+
166
+ payload = {
167
+ "model_saved_name": config.model_saved_name,
168
+ "testing_for": config.testing_for,
169
+ "model_name": config.model_name,
170
+ "model_type": config.model_type,
171
+ "certifications": config.certifications,
172
+ "model_config": {
173
+ "model_provider": config.model_config.model_provider,
174
+ "model_version": config.model_config.model_version,
175
+ "hosting_type": config.model_config.hosting_type,
176
+ "model_source": config.model_config.model_source,
177
+ "system_prompt": config.model_config.system_prompt,
178
+ "conversation_template": config.model_config.conversation_template,
179
+ "endpoint": {
180
+ "scheme": parsed_url.scheme,
181
+ "host": parsed_url.hostname,
182
+ "port": parsed_url.port
183
+ or (443 if parsed_url.scheme == "https" else 80),
184
+ "base_path": f"/{base_path}", # Just v1
185
+ },
186
+ "paths": paths,
187
+ "auth_data": {
188
+ "header_name": config.model_config.auth_data.header_name,
189
+ "header_prefix": config.model_config.auth_data.header_prefix,
190
+ "space_after_prefix": config.model_config.auth_data.space_after_prefix,
191
+ },
192
+ "apikeys": (
193
+ [config.model_config.apikey] if config.model_config.apikey else []
194
+ ),
195
+ "default_request_options": config.model_config.default_request_options,
196
+ },
197
+ }
198
+ try:
199
+ response = self._request(
200
+ "PATCH", "/models/modify-model", headers=headers, json=payload
201
+ )
202
+ if response.get("error"):
203
+ raise ModelClientError(response["error"])
204
+ return ModelResponse.from_dict(response)
205
+ except Exception as e:
206
+ raise ModelClientError(str(e))
207
+
208
+ def delete_model(self, model_saved_name: str) -> ModelResponse:
134
209
  """
135
210
  Delete a specific model from the system.
136
211
 
137
212
  Args:
138
- model_id (str): The identifier or name of the model to delete
213
+ model_saved_name (str): The saved name of the model to delete
139
214
 
140
215
  Returns:
141
216
  dict: Response from the API containing the deletion status
142
217
  """
143
- headers = {"X-Enkrypt-Model": model_id}
144
- return self._request("DELETE", "/models/delete-model", headers=headers)
218
+ headers = {"X-Enkrypt-Model": model_saved_name}
219
+
220
+ try:
221
+ response = self._request("DELETE", "/models/delete-model", headers=headers)
222
+ if response.get("error"):
223
+ raise ModelClientError(response["error"])
224
+ return ModelResponse.from_dict(response)
225
+ except Exception as e:
226
+ raise ModelClientError(str(e))
enkryptai_sdk/red_team.py CHANGED
@@ -1,15 +1,21 @@
1
1
  import urllib3
2
+ from .base import BaseClient
2
3
  from .dto import (
4
+ RedteamHealthResponse,
5
+ RedTeamModelHealthConfig,
6
+ RedteamModelHealthResponse,
3
7
  RedTeamConfig,
8
+ # RedTeamConfigWithSavedModel,
4
9
  RedTeamResponse,
5
10
  RedTeamResultSummary,
6
11
  RedTeamResultDetails,
7
12
  RedTeamTaskStatus,
8
13
  RedTeamTaskDetails,
14
+ RedTeamTaskList,
9
15
  )
10
16
 
11
17
 
12
- class RedTeamError(Exception):
18
+ class RedTeamClientError(Exception):
13
19
  """
14
20
  A custom exception for Red Team errors.
15
21
  """
@@ -17,48 +23,64 @@ class RedTeamError(Exception):
17
23
  pass
18
24
 
19
25
 
20
- class RedTeamClient:
26
+ class RedTeamClient(BaseClient):
21
27
  """
22
28
  A client for interacting with the Red Team API.
23
29
  """
24
30
 
25
31
  def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com"):
26
- self.api_key = api_key
27
- self.base_url = base_url
28
- self.http = urllib3.PoolManager()
29
- self.headers = {"apikey": self.api_key}
30
-
31
- def _request(self, method, endpoint, headers=None, **kwargs):
32
- url = self.base_url + endpoint
33
- request_headers = {
34
- "Accept-Encoding": "gzip", # Add required gzip encoding
35
- **self.headers,
36
- }
37
- if headers:
38
- request_headers.update(headers)
32
+ super().__init__(api_key, base_url)
33
+
34
+ # def get_model(self, model):
35
+ # models = self._request("GET", "/models/list-models")
36
+ # models = models["models"]
37
+ # for _model_data in models:
38
+ # if _model_data["model_saved_name"] == model:
39
+ # return _model_data["model_saved_name"]
40
+ # else:
41
+ # return None
39
42
 
43
+ def get_health(self):
44
+ """
45
+ Get the health status of the service.
46
+ """
40
47
  try:
41
- response = self.http.request(method, url, headers=request_headers, **kwargs)
42
- if response.status >= 400:
43
- error_response = (
44
- response.json()
45
- if response.data
46
- else {"message": f"HTTP {response.status}"}
47
- )
48
- raise urllib3.exceptions.HTTPError(
49
- f"HTTP {response.status}: {error_response}"
50
- )
51
- return response.json()
52
- except urllib3.exceptions.HTTPError as e:
53
- print(f"Request failed: {e}")
54
- return {"error": str(e)}
55
-
56
- def get_model(self, model):
57
- models = self._request("GET", "/models/list-models")
58
- if model in models:
59
- return model
60
- else:
61
- return None
48
+ response = self._request("GET", "/redteam/health")
49
+ if response.get("error"):
50
+ raise RedTeamClientError(response["error"])
51
+ return RedteamHealthResponse.from_dict(response)
52
+ except Exception as e:
53
+ raise RedTeamClientError(str(e))
54
+
55
+ def check_model_health(self, config: RedTeamModelHealthConfig):
56
+ """
57
+ Get the health status of a model.
58
+ """
59
+ try:
60
+ config = RedTeamModelHealthConfig.from_dict(config)
61
+ response = self._request("POST", "/redteam/model-health", json=config.to_dict())
62
+ # if response.get("error"):
63
+ if response.get("error") not in [None, ""]:
64
+ raise RedTeamClientError(response["error"])
65
+ return RedteamModelHealthResponse.from_dict(response)
66
+ except Exception as e:
67
+ raise RedTeamClientError(str(e))
68
+
69
+ def check_saved_model_health(self, model_saved_name: str):
70
+ """
71
+ Get the health status of a saved model.
72
+ """
73
+ try:
74
+ headers = {
75
+ "X-Enkrypt-Model": model_saved_name,
76
+ }
77
+ response = self._request("POST", "/redteam/model/model-health", headers=headers)
78
+ # if response.get("error"):
79
+ if response.get("error") not in [None, ""]:
80
+ raise RedTeamClientError(response["error"])
81
+ return RedteamModelHealthResponse.from_dict(response)
82
+ except Exception as e:
83
+ raise RedTeamClientError(str(e))
62
84
 
63
85
  def add_task(
64
86
  self,
@@ -76,25 +98,26 @@ class RedTeamClient:
76
98
  # "async": config.async_enabled,
77
99
  "dataset_name": config.dataset_name,
78
100
  "test_name": config.test_name,
79
- "redteam_test_configurations": test_configs,
101
+ "redteam_test_configurations": {
102
+ k: v.to_dict() for k, v in test_configs.items()
103
+ },
80
104
  }
81
105
 
82
- model = config.model_name
83
- saved_model = self.get_model(model)
84
-
106
+ saved_model = config.model_saved_name
85
107
  if saved_model:
86
- print("saved model found")
87
108
  headers = {
88
109
  "X-Enkrypt-Model": saved_model,
89
110
  "Content-Type": "application/json",
90
111
  }
91
- payload["location"] = {"storage": "supabase", "container_name": "supabase"}
92
- return self._request(
112
+ response = self._request(
93
113
  "POST",
94
114
  "/redteam/v2/model/add-task",
95
115
  headers=headers,
96
116
  json=payload,
97
117
  )
118
+ if response.get("error"):
119
+ raise RedTeamClientError(response["error"])
120
+ return RedTeamResponse.from_dict(response)
98
121
  elif config.target_model_configuration:
99
122
  payload["target_model_configuration"] = (
100
123
  config.target_model_configuration.to_dict()
@@ -105,81 +128,162 @@ class RedTeamClient:
105
128
  "/redteam/v2/add-task",
106
129
  json=payload,
107
130
  )
131
+ if response.get("error"):
132
+ raise RedTeamClientError(response["error"])
108
133
  return RedTeamResponse.from_dict(response)
109
134
  else:
110
- raise RedTeamError(
135
+ raise RedTeamClientError(
111
136
  "Please use a saved model or provide a target model configuration"
112
137
  )
113
138
 
114
- def status(self, task_id: str):
139
+ def status(self, task_id: str = None, test_name: str = None):
115
140
  """
116
141
  Get the status of a specific red teaming task.
117
142
 
118
143
  Args:
119
- task_id (str): The ID of the task to check status
144
+ task_id (str, optional): The ID of the task to check status
145
+ test_name (str, optional): The name of the test to check status
120
146
 
121
147
  Returns:
122
148
  dict: The task status information
149
+
150
+ Raises:
151
+ RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
123
152
  """
124
- headers = {"X-Enkrypt-Task-ID": task_id}
153
+ if not task_id and not test_name:
154
+ raise RedTeamClientError("Either task_id or test_name must be provided")
155
+
156
+ headers = {}
157
+ if task_id:
158
+ headers["X-Enkrypt-Task-ID"] = task_id
159
+ if test_name:
160
+ headers["X-Enkrypt-Test-Name"] = test_name
125
161
 
126
162
  response = self._request("GET", "/redteam/task-status", headers=headers)
163
+ if response.get("error"):
164
+ raise RedTeamClientError(response["error"])
127
165
  return RedTeamTaskStatus.from_dict(response)
128
166
 
129
- def cancel_task(self, task_id: str):
167
+ def cancel_task(self, task_id: str = None, test_name: str = None):
130
168
  """
131
169
  Cancel a specific red teaming task.
132
170
 
133
171
  Args:
134
- task_id (str): The ID of the task to cancel
172
+ task_id (str, optional): The ID of the task to cancel
173
+ test_name (str, optional): The name of the test to cancel
174
+
175
+ Returns:
176
+ dict: The cancellation response
177
+
178
+ Raises:
179
+ RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
135
180
  """
136
- raise RedTeamError(
181
+ raise RedTeamClientError(
137
182
  "This feature is currently under development. Please check our documentation "
138
183
  "at https://docs.enkrypt.ai for updates or contact support@enkrypt.ai for assistance."
139
184
  )
140
185
 
141
- def get_task(self, task_id: str):
186
+ if not task_id and not test_name:
187
+ raise RedTeamClientError("Either task_id or test_name must be provided")
188
+
189
+ headers = {}
190
+ if task_id:
191
+ headers["X-Enkrypt-Task-ID"] = task_id
192
+ if test_name:
193
+ headers["X-Enkrypt-Test-Name"] = test_name
194
+
195
+ response = self._request("POST", "/redteam/cancel-task", headers=headers)
196
+ if response.get("error"):
197
+ raise RedTeamClientError(response["error"])
198
+ return response
199
+
200
+ def get_task(self, task_id: str = None, test_name: str = None):
142
201
  """
143
202
  Get the status and details of a specific red teaming task.
144
203
 
145
204
  Args:
146
- task_id (str): The ID of the task to retrieve
205
+ task_id (str, optional): The ID of the task to retrieve
206
+ test_name (str, optional): The name of the test to retrieve
147
207
 
148
208
  Returns:
149
209
  dict: The task details and status
210
+
211
+ Raises:
212
+ RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
150
213
  """
151
- headers = {"X-Enkrypt-Task-ID": task_id}
214
+ if not task_id and not test_name:
215
+ raise RedTeamClientError("Either task_id or test_name must be provided")
216
+
217
+ headers = {}
218
+ if task_id:
219
+ headers["X-Enkrypt-Task-ID"] = task_id
220
+ if test_name:
221
+ headers["X-Enkrypt-Test-Name"] = test_name
152
222
 
153
223
  response = self._request("GET", "/redteam/get-task", headers=headers)
154
- response["data"]["task_id"] = response["data"].pop("job_id")
224
+ if response.get("error"):
225
+ raise RedTeamClientError(response["error"])
155
226
  return RedTeamTaskDetails.from_dict(response["data"])
156
227
 
157
- def get_result_summary(self, task_id: str):
228
+ def get_result_summary(self, task_id: str = None, test_name: str = None):
158
229
  """
159
230
  Get the summary of results for a specific red teaming task.
160
231
 
161
232
  Args:
162
- task_id (str): The ID of the task to get results for
233
+ task_id (str, optional): The ID of the task to get results for
234
+ test_name (str, optional): The name of the test to get results for
163
235
 
164
236
  Returns:
165
237
  dict: The summary of the task results
238
+
239
+ Raises:
240
+ RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
166
241
  """
167
- headers = {"X-Enkrypt-Task-ID": task_id}
242
+ if not task_id and not test_name:
243
+ raise RedTeamClientError("Either task_id or test_name must be provided")
244
+
245
+ headers = {}
246
+ if task_id:
247
+ headers["X-Enkrypt-Task-ID"] = task_id
248
+ if test_name:
249
+ headers["X-Enkrypt-Test-Name"] = test_name
168
250
 
169
251
  response = self._request("GET", "/redteam/results/summary", headers=headers)
170
- return RedTeamResultSummary.from_dict(response["summary"])
252
+ if response.get("error"):
253
+ raise RedTeamClientError(response["error"])
254
+ print(f"Response: {response}")
255
+ return RedTeamResultSummary.from_dict(response)
171
256
 
172
- def get_result_details(self, task_id: str):
257
+ def get_result_details(self, task_id: str = None, test_name: str = None):
173
258
  """
174
259
  Get the detailed results for a specific red teaming task.
175
260
 
176
261
  Args:
177
- task_id (str): The ID of the task to get detailed results for
262
+ task_id (str, optional): The ID of the task to get detailed results for
263
+ test_name (str, optional): The name of the test to get detailed results for
178
264
 
179
265
  Returns:
180
266
  dict: The detailed task results
267
+
268
+ Raises:
269
+ RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
181
270
  """
182
- # TODO: Update the response to be updated
183
- headers = {"X-Enkrypt-Task-ID": task_id}
271
+ if not task_id and not test_name:
272
+ raise RedTeamClientError("Either task_id or test_name must be provided")
273
+
274
+ headers = {}
275
+ if task_id:
276
+ headers["X-Enkrypt-Task-ID"] = task_id
277
+ if test_name:
278
+ headers["X-Enkrypt-Test-Name"] = test_name
279
+
184
280
  response = self._request("GET", "/redteam/results/details", headers=headers)
185
- return RedTeamResultDetails.from_dict(response["details"])
281
+ if response.get("error"):
282
+ raise RedTeamClientError(response["error"])
283
+ return RedTeamResultDetails.from_dict(response)
284
+
285
+ def get_task_list(self):
286
+ response = self._request("GET", "/redteam/list-tasks")
287
+ if response.get("error"):
288
+ raise RedTeamClientError(response["error"])
289
+ return RedTeamTaskList.from_dict(response)