enkryptai-sdk 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
enkryptai_sdk/models.py CHANGED
@@ -1,6 +1,6 @@
1
- import urllib3
2
- from .dto import ModelConfig, ModelResponse
1
+ from .base import BaseClient
3
2
  from urllib.parse import urlparse, urlsplit
3
+ from .dto import ModelConfig, ModelResponse, ModelCollection
4
4
 
5
5
 
6
6
  class ModelClientError(Exception):
@@ -11,41 +11,11 @@ class ModelClientError(Exception):
11
11
  pass
12
12
 
13
13
 
14
- class ModelClient:
14
+ class ModelClient(BaseClient):
15
15
  def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com:443"):
16
- self.api_key = api_key
17
- self.base_url = base_url
18
- self.http = urllib3.PoolManager()
19
- self.headers = {"apikey": self.api_key}
20
-
21
- def _request(self, method, endpoint, payload=None, headers=None, **kwargs):
22
- url = self.base_url + endpoint
23
- request_headers = {
24
- "Accept-Encoding": "gzip", # Add required gzip encoding
25
- **self.headers,
26
- }
27
- if headers:
28
- request_headers.update(headers)
29
-
30
- try:
31
- response = self.http.request(method, url, headers=request_headers, **kwargs)
32
-
33
- if response.status >= 400:
34
- error_data = (
35
- response.json()
36
- if response.data
37
- else {"message": f"HTTP {response.status}"}
38
- )
39
- error_message = error_data.get("message", str(error_data))
40
- raise urllib3.exceptions.HTTPError(error_message)
41
- return response.json()
42
- except urllib3.exceptions.HTTPError as e:
43
- return {"error": str(e)}
44
-
45
- def health(self):
46
- return self._request("GET", "/models/health")
16
+ super().__init__(api_key, base_url)
47
17
 
48
- def add_model(self, config: ModelConfig):
18
+ def add_model(self, config: ModelConfig) -> ModelResponse:
49
19
  """
50
20
  Add a new model configuration to the system.
51
21
 
@@ -69,11 +39,15 @@ class ModelClient:
69
39
  base_path = ""
70
40
  remaining_path = ""
71
41
 
72
- # Determine paths based on the endpoint
73
- paths = {
74
- "completions": f"/{remaining_path}" if remaining_path else "",
75
- "chat": "",
76
- }
42
+ if config.model_config.paths:
43
+ paths = config.model_config.paths.to_dict()
44
+ else:
45
+ paths = {
46
+ "completions": (
47
+ f"/{remaining_path.split('/')[-1]}" if remaining_path else ""
48
+ ),
49
+ "chat": f"/{remaining_path}" if remaining_path else "",
50
+ }
77
51
 
78
52
  payload = {
79
53
  "model_saved_name": config.model_saved_name,
@@ -82,11 +56,10 @@ class ModelClient:
82
56
  "model_type": config.model_type,
83
57
  "certifications": config.certifications,
84
58
  "model_config": {
85
- "is_compatible_with": config.model_config.is_compatible_with,
59
+ "model_provider": config.model_config.model_provider,
86
60
  "model_version": config.model_config.model_version,
87
61
  "hosting_type": config.model_config.hosting_type,
88
62
  "model_source": config.model_config.model_source,
89
- "model_provider": config.model_config.model_provider,
90
63
  "system_prompt": config.model_config.system_prompt,
91
64
  "conversation_template": config.model_config.conversation_template,
92
65
  "endpoint": {
@@ -94,7 +67,7 @@ class ModelClient:
94
67
  "host": parsed_url.hostname,
95
68
  "port": parsed_url.port
96
69
  or (443 if parsed_url.scheme == "https" else 80),
97
- "base_path": f"/{base_path}/{paths['completions']}", # Just v1
70
+ "base_path": f"/{base_path}", # Just v1
98
71
  },
99
72
  "paths": paths,
100
73
  "auth_data": {
@@ -118,18 +91,19 @@ class ModelClient:
118
91
  except Exception as e:
119
92
  raise ModelClientError(str(e))
120
93
 
121
- def get_model(self, model_id: str) -> ModelConfig:
94
+ def get_model(self, model_saved_name: str) -> ModelConfig:
122
95
  """
123
- Get model configuration by model ID.
96
+ Get model configuration by model saved name.
124
97
 
125
98
  Args:
126
- model_id (str): ID of the model to retrieve
99
+ model_saved_name (str): Saved name of the model to retrieve
127
100
 
128
101
  Returns:
129
102
  ModelConfig: Configuration object containing model details
130
103
  """
131
- headers = {"X-Enkrypt-Model": model_id}
104
+ headers = {"X-Enkrypt-Model": model_saved_name}
132
105
  response = self._request("GET", "/models/get-model", headers=headers)
106
+ print(response)
133
107
  if response.get("error"):
134
108
  raise ModelClientError(response["error"])
135
109
  return ModelConfig.from_dict(response)
@@ -142,19 +116,111 @@ class ModelClient:
142
116
  dict: Response from the API containing the list of models
143
117
  """
144
118
  try:
145
- return self._request("GET", "/models/list-models")
119
+ response = self._request("GET", "/models/list-models")
120
+ if response.get("error"):
121
+ raise ModelClientError(response["error"])
122
+ return ModelCollection.from_dict(response)
146
123
  except Exception as e:
147
124
  return {"error": str(e)}
148
125
 
149
- def delete_model(self, model_id: str):
126
+ def modify_model(self, config: ModelConfig, old_model_saved_name=None) -> ModelResponse:
127
+ """
128
+ Modify an existing model in the system.
129
+
130
+ Args:
131
+ model_saved_name (str): The saved name of the model to modify
132
+ config (ModelConfig): Configuration object containing model details
133
+
134
+ Returns:
135
+ dict: Response from the API containing the modified model details
136
+ """
137
+ if old_model_saved_name is None:
138
+ old_model_saved_name = config["model_saved_name"]
139
+
140
+ headers = {"Content-Type": "application/json", "X-Enkrypt-Model": old_model_saved_name}
141
+ # print(config)
142
+ config = ModelConfig.from_dict(config)
143
+ # Parse endpoint_url into components
144
+ parsed_url = urlparse(config.model_config.endpoint_url)
145
+ path_parts = parsed_url.path.strip("/").split("/")
146
+
147
+ # Extract base_path and endpoint path
148
+ if len(path_parts) >= 1:
149
+ base_path = path_parts[0] # Usually 'v1'
150
+ remaining_path = "/".join(path_parts[1:]) # The rest of the path
151
+ else:
152
+ base_path = ""
153
+ remaining_path = ""
154
+
155
+ if config.model_config.paths:
156
+ paths = config.model_config.paths.to_dict()
157
+ else:
158
+ # Determine paths based on the endpoint
159
+ paths = {
160
+ "completions": (
161
+ f"/{remaining_path.split('/')[-1]}" if remaining_path else ""
162
+ ),
163
+ "chat": f"/{remaining_path}" if remaining_path else "",
164
+ }
165
+
166
+ payload = {
167
+ "model_saved_name": config.model_saved_name,
168
+ "testing_for": config.testing_for,
169
+ "model_name": config.model_name,
170
+ "model_type": config.model_type,
171
+ "certifications": config.certifications,
172
+ "model_config": {
173
+ "model_provider": config.model_config.model_provider,
174
+ "model_version": config.model_config.model_version,
175
+ "hosting_type": config.model_config.hosting_type,
176
+ "model_source": config.model_config.model_source,
177
+ "system_prompt": config.model_config.system_prompt,
178
+ "conversation_template": config.model_config.conversation_template,
179
+ "endpoint": {
180
+ "scheme": parsed_url.scheme,
181
+ "host": parsed_url.hostname,
182
+ "port": parsed_url.port
183
+ or (443 if parsed_url.scheme == "https" else 80),
184
+ "base_path": f"/{base_path}", # Just v1
185
+ },
186
+ "paths": paths,
187
+ "auth_data": {
188
+ "header_name": config.model_config.auth_data.header_name,
189
+ "header_prefix": config.model_config.auth_data.header_prefix,
190
+ "space_after_prefix": config.model_config.auth_data.space_after_prefix,
191
+ },
192
+ "apikeys": (
193
+ [config.model_config.apikey] if config.model_config.apikey else []
194
+ ),
195
+ "default_request_options": config.model_config.default_request_options,
196
+ },
197
+ }
198
+ try:
199
+ response = self._request(
200
+ "PATCH", "/models/modify-model", headers=headers, json=payload
201
+ )
202
+ if response.get("error"):
203
+ raise ModelClientError(response["error"])
204
+ return ModelResponse.from_dict(response)
205
+ except Exception as e:
206
+ raise ModelClientError(str(e))
207
+
208
+ def delete_model(self, model_saved_name: str) -> ModelResponse:
150
209
  """
151
210
  Delete a specific model from the system.
152
211
 
153
212
  Args:
154
- model_id (str): The identifier or name of the model to delete
213
+ model_saved_name (str): The saved name of the model to delete
155
214
 
156
215
  Returns:
157
216
  dict: Response from the API containing the deletion status
158
217
  """
159
- headers = {"X-Enkrypt-Model": model_id}
160
- return self._request("DELETE", "/models/delete-model", headers=headers)
218
+ headers = {"X-Enkrypt-Model": model_saved_name}
219
+
220
+ try:
221
+ response = self._request("DELETE", "/models/delete-model", headers=headers)
222
+ if response.get("error"):
223
+ raise ModelClientError(response["error"])
224
+ return ModelResponse.from_dict(response)
225
+ except Exception as e:
226
+ raise ModelClientError(str(e))
enkryptai_sdk/red_team.py CHANGED
@@ -1,11 +1,17 @@
1
1
  import urllib3
2
+ from .base import BaseClient
2
3
  from .dto import (
4
+ RedteamHealthResponse,
5
+ RedTeamModelHealthConfig,
6
+ RedteamModelHealthResponse,
3
7
  RedTeamConfig,
8
+ # RedTeamConfigWithSavedModel,
4
9
  RedTeamResponse,
5
10
  RedTeamResultSummary,
6
11
  RedTeamResultDetails,
7
12
  RedTeamTaskStatus,
8
13
  RedTeamTaskDetails,
14
+ RedTeamTaskList,
9
15
  )
10
16
 
11
17
 
@@ -17,47 +23,64 @@ class RedTeamClientError(Exception):
17
23
  pass
18
24
 
19
25
 
20
- class RedTeamClient:
26
+ class RedTeamClient(BaseClient):
21
27
  """
22
28
  A client for interacting with the Red Team API.
23
29
  """
24
30
 
25
31
  def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com"):
26
- self.api_key = api_key
27
- self.base_url = base_url
28
- self.http = urllib3.PoolManager()
29
- self.headers = {"apikey": self.api_key}
30
-
31
- def _request(self, method, endpoint, headers=None, **kwargs):
32
- url = self.base_url + endpoint
33
- request_headers = {
34
- "Accept-Encoding": "gzip", # Add required gzip encoding
35
- **self.headers,
36
- }
37
- if headers:
38
- request_headers.update(headers)
32
+ super().__init__(api_key, base_url)
33
+
34
+ # def get_model(self, model):
35
+ # models = self._request("GET", "/models/list-models")
36
+ # models = models["models"]
37
+ # for _model_data in models:
38
+ # if _model_data["model_saved_name"] == model:
39
+ # return _model_data["model_saved_name"]
40
+ # else:
41
+ # return None
39
42
 
43
+ def get_health(self):
44
+ """
45
+ Get the health status of the service.
46
+ """
40
47
  try:
41
- response = self.http.request(method, url, headers=request_headers, **kwargs)
42
-
43
- if response.status >= 400:
44
- error_data = (
45
- response.json()
46
- if response.data
47
- else {"message": f"HTTP {response.status}"}
48
- )
49
- error_message = error_data.get("message", str(error_data))
50
- raise urllib3.exceptions.HTTPError(error_message)
51
- return response.json()
52
- except urllib3.exceptions.HTTPError as e:
53
- return {"error": str(e)}
54
-
55
- def get_model(self, model):
56
- models = self._request("GET", "/models/list-models")
57
- if model in models:
58
- return model
59
- else:
60
- return None
48
+ response = self._request("GET", "/redteam/health")
49
+ if response.get("error"):
50
+ raise RedTeamClientError(response["error"])
51
+ return RedteamHealthResponse.from_dict(response)
52
+ except Exception as e:
53
+ raise RedTeamClientError(str(e))
54
+
55
+ def check_model_health(self, config: RedTeamModelHealthConfig):
56
+ """
57
+ Get the health status of a model.
58
+ """
59
+ try:
60
+ config = RedTeamModelHealthConfig.from_dict(config)
61
+ response = self._request("POST", "/redteam/model-health", json=config.to_dict())
62
+ # if response.get("error"):
63
+ if response.get("error") not in [None, ""]:
64
+ raise RedTeamClientError(response["error"])
65
+ return RedteamModelHealthResponse.from_dict(response)
66
+ except Exception as e:
67
+ raise RedTeamClientError(str(e))
68
+
69
+ def check_saved_model_health(self, model_saved_name: str):
70
+ """
71
+ Get the health status of a saved model.
72
+ """
73
+ try:
74
+ headers = {
75
+ "X-Enkrypt-Model": model_saved_name,
76
+ }
77
+ response = self._request("POST", "/redteam/model/model-health", headers=headers)
78
+ # if response.get("error"):
79
+ if response.get("error") not in [None, ""]:
80
+ raise RedTeamClientError(response["error"])
81
+ return RedteamModelHealthResponse.from_dict(response)
82
+ except Exception as e:
83
+ raise RedTeamClientError(str(e))
61
84
 
62
85
  def add_task(
63
86
  self,
@@ -75,25 +98,26 @@ class RedTeamClient:
75
98
  # "async": config.async_enabled,
76
99
  "dataset_name": config.dataset_name,
77
100
  "test_name": config.test_name,
78
- "redteam_test_configurations": test_configs,
101
+ "redteam_test_configurations": {
102
+ k: v.to_dict() for k, v in test_configs.items()
103
+ },
79
104
  }
80
105
 
81
- model = config.model_name
82
- saved_model = self.get_model(model)
83
-
106
+ saved_model = config.model_saved_name
84
107
  if saved_model:
85
- print("saved model found")
86
108
  headers = {
87
109
  "X-Enkrypt-Model": saved_model,
88
110
  "Content-Type": "application/json",
89
111
  }
90
- payload["location"] = {"storage": "supabase", "container_name": "supabase"}
91
- return self._request(
112
+ response = self._request(
92
113
  "POST",
93
114
  "/redteam/v2/model/add-task",
94
115
  headers=headers,
95
116
  json=payload,
96
117
  )
118
+ if response.get("error"):
119
+ raise RedTeamClientError(response["error"])
120
+ return RedTeamResponse.from_dict(response)
97
121
  elif config.target_model_configuration:
98
122
  payload["target_model_configuration"] = (
99
123
  config.target_model_configuration.to_dict()
@@ -112,84 +136,154 @@ class RedTeamClient:
112
136
  "Please use a saved model or provide a target model configuration"
113
137
  )
114
138
 
115
- def status(self, task_id: str):
139
+ def status(self, task_id: str = None, test_name: str = None):
116
140
  """
117
141
  Get the status of a specific red teaming task.
118
142
 
119
143
  Args:
120
- task_id (str): The ID of the task to check status
144
+ task_id (str, optional): The ID of the task to check status
145
+ test_name (str, optional): The name of the test to check status
121
146
 
122
147
  Returns:
123
148
  dict: The task status information
149
+
150
+ Raises:
151
+ RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
124
152
  """
125
- headers = {"X-Enkrypt-Task-ID": task_id}
153
+ if not task_id and not test_name:
154
+ raise RedTeamClientError("Either task_id or test_name must be provided")
155
+
156
+ headers = {}
157
+ if task_id:
158
+ headers["X-Enkrypt-Task-ID"] = task_id
159
+ if test_name:
160
+ headers["X-Enkrypt-Test-Name"] = test_name
126
161
 
127
162
  response = self._request("GET", "/redteam/task-status", headers=headers)
128
163
  if response.get("error"):
129
164
  raise RedTeamClientError(response["error"])
130
165
  return RedTeamTaskStatus.from_dict(response)
131
166
 
132
- def cancel_task(self, task_id: str):
167
+ def cancel_task(self, task_id: str = None, test_name: str = None):
133
168
  """
134
169
  Cancel a specific red teaming task.
135
170
 
136
171
  Args:
137
- task_id (str): The ID of the task to cancel
172
+ task_id (str, optional): The ID of the task to cancel
173
+ test_name (str, optional): The name of the test to cancel
174
+
175
+ Returns:
176
+ dict: The cancellation response
177
+
178
+ Raises:
179
+ RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
138
180
  """
139
181
  raise RedTeamClientError(
140
182
  "This feature is currently under development. Please check our documentation "
141
183
  "at https://docs.enkrypt.ai for updates or contact support@enkrypt.ai for assistance."
142
184
  )
143
185
 
144
- def get_task(self, task_id: str):
186
+ if not task_id and not test_name:
187
+ raise RedTeamClientError("Either task_id or test_name must be provided")
188
+
189
+ headers = {}
190
+ if task_id:
191
+ headers["X-Enkrypt-Task-ID"] = task_id
192
+ if test_name:
193
+ headers["X-Enkrypt-Test-Name"] = test_name
194
+
195
+ response = self._request("POST", "/redteam/cancel-task", headers=headers)
196
+ if response.get("error"):
197
+ raise RedTeamClientError(response["error"])
198
+ return response
199
+
200
+ def get_task(self, task_id: str = None, test_name: str = None):
145
201
  """
146
202
  Get the status and details of a specific red teaming task.
147
203
 
148
204
  Args:
149
- task_id (str): The ID of the task to retrieve
205
+ task_id (str, optional): The ID of the task to retrieve
206
+ test_name (str, optional): The name of the test to retrieve
150
207
 
151
208
  Returns:
152
209
  dict: The task details and status
210
+
211
+ Raises:
212
+ RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
153
213
  """
154
- headers = {"X-Enkrypt-Task-ID": task_id}
214
+ if not task_id and not test_name:
215
+ raise RedTeamClientError("Either task_id or test_name must be provided")
216
+
217
+ headers = {}
218
+ if task_id:
219
+ headers["X-Enkrypt-Task-ID"] = task_id
220
+ if test_name:
221
+ headers["X-Enkrypt-Test-Name"] = test_name
155
222
 
156
223
  response = self._request("GET", "/redteam/get-task", headers=headers)
157
224
  if response.get("error"):
158
225
  raise RedTeamClientError(response["error"])
159
- if response.get("data").get("job_id "):
160
- response["data"]["task_id"] = response["data"].pop("job_id")
161
226
  return RedTeamTaskDetails.from_dict(response["data"])
162
227
 
163
- def get_result_summary(self, task_id: str):
228
+ def get_result_summary(self, task_id: str = None, test_name: str = None):
164
229
  """
165
230
  Get the summary of results for a specific red teaming task.
166
231
 
167
232
  Args:
168
- task_id (str): The ID of the task to get results for
233
+ task_id (str, optional): The ID of the task to get results for
234
+ test_name (str, optional): The name of the test to get results for
169
235
 
170
236
  Returns:
171
237
  dict: The summary of the task results
238
+
239
+ Raises:
240
+ RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
172
241
  """
173
- headers = {"X-Enkrypt-Task-ID": task_id}
242
+ if not task_id and not test_name:
243
+ raise RedTeamClientError("Either task_id or test_name must be provided")
244
+
245
+ headers = {}
246
+ if task_id:
247
+ headers["X-Enkrypt-Task-ID"] = task_id
248
+ if test_name:
249
+ headers["X-Enkrypt-Test-Name"] = test_name
174
250
 
175
251
  response = self._request("GET", "/redteam/results/summary", headers=headers)
176
252
  if response.get("error"):
177
253
  raise RedTeamClientError(response["error"])
178
- return RedTeamResultSummary.from_dict(response["summary"])
254
+ print(f"Response: {response}")
255
+ return RedTeamResultSummary.from_dict(response)
179
256
 
180
- def get_result_details(self, task_id: str):
257
+ def get_result_details(self, task_id: str = None, test_name: str = None):
181
258
  """
182
259
  Get the detailed results for a specific red teaming task.
183
260
 
184
261
  Args:
185
- task_id (str): The ID of the task to get detailed results for
262
+ task_id (str, optional): The ID of the task to get detailed results for
263
+ test_name (str, optional): The name of the test to get detailed results for
186
264
 
187
265
  Returns:
188
266
  dict: The detailed task results
267
+
268
+ Raises:
269
+ RedTeamClientError: If neither task_id nor test_name is provided, or if there's an error from the API
189
270
  """
190
- # TODO: Update the response to be updated
191
- headers = {"X-Enkrypt-Task-ID": task_id}
271
+ if not task_id and not test_name:
272
+ raise RedTeamClientError("Either task_id or test_name must be provided")
273
+
274
+ headers = {}
275
+ if task_id:
276
+ headers["X-Enkrypt-Task-ID"] = task_id
277
+ if test_name:
278
+ headers["X-Enkrypt-Test-Name"] = test_name
279
+
192
280
  response = self._request("GET", "/redteam/results/details", headers=headers)
193
281
  if response.get("error"):
194
282
  raise RedTeamClientError(response["error"])
195
- return RedTeamResultDetails.from_dict(response["details"])
283
+ return RedTeamResultDetails.from_dict(response)
284
+
285
+ def get_task_list(self):
286
+ response = self._request("GET", "/redteam/list-tasks")
287
+ if response.get("error"):
288
+ raise RedTeamClientError(response["error"])
289
+ return RedTeamTaskList.from_dict(response)