enkryptai-sdk 1.0.10__tar.gz → 1.0.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {enkryptai_sdk-1.0.10/src/enkryptai_sdk.egg-info → enkryptai_sdk-1.0.12}/PKG-INFO +1 -1
  2. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/setup.py +1 -1
  3. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/base.py +10 -7
  4. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/dto/base.py +12 -1
  5. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/dto/models.py +17 -0
  6. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/dto/red_team.py +8 -0
  7. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/models.py +12 -6
  8. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/red_team.py +8 -13
  9. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12/src/enkryptai_sdk.egg-info}/PKG-INFO +1 -1
  10. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_all_v2.py +137 -42
  11. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_model.py +29 -4
  12. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/LICENSE +0 -0
  13. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/README.md +0 -0
  14. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/setup.cfg +0 -0
  15. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/__init__.py +0 -0
  16. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/ai_proxy.py +0 -0
  17. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/coc.py +0 -0
  18. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/config.py +0 -0
  19. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/datasets.py +0 -0
  20. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/deployments.py +0 -0
  21. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/dto/__init__.py +0 -0
  22. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/dto/ai_proxy.py +0 -0
  23. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/dto/coc.py +0 -0
  24. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/dto/datasets.py +0 -0
  25. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/dto/deployments.py +0 -0
  26. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/dto/guardrails.py +0 -0
  27. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/evals.py +0 -0
  28. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/guardrails.py +0 -0
  29. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/guardrails_old.py +0 -0
  30. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk/response.py +0 -0
  31. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk.egg-info/SOURCES.txt +0 -0
  32. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk.egg-info/dependency_links.txt +0 -0
  33. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/src/enkryptai_sdk.egg-info/top_level.txt +0 -0
  34. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_ai_proxy.py +0 -0
  35. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_all.py +0 -0
  36. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_basic.py +0 -0
  37. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_coc.py +0 -0
  38. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_datasets.py +0 -0
  39. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_deployments.py +0 -0
  40. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_detect_policy.py +0 -0
  41. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_guardrails.py +0 -0
  42. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_injection_attack.py +0 -0
  43. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_openai.py +0 -0
  44. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_policy_violation.py +0 -0
  45. {enkryptai_sdk-1.0.10 → enkryptai_sdk-1.0.12}/tests/test_redteam.py +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: enkryptai-sdk
-Version: 1.0.10
+Version: 1.0.12
 Summary: A Python SDK with guardrails and red teaming functionality for API interactions
 Home-page: https://github.com/enkryptai/enkryptai-sdk
 Author: Enkrypt AI Team

setup.py
@@ -9,7 +9,7 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as fh:
 setup(
     name="enkryptai-sdk",  # This is the name of your package on PyPI
     # NOTE: Also change this in .github/workflows/test.yaml
-    version="1.0.10",  # Update this for new versions
+    version="1.0.12",  # Update this for new versions
     description="A Python SDK with guardrails and red teaming functionality for API interactions",
     long_description=long_description,
     long_description_content_type="text/markdown",

src/enkryptai_sdk/base.py
@@ -72,14 +72,17 @@ class BaseClient:
             )

             if response.status >= 400:
-                error_data = (
-                    response.json()
-                    if response.data
-                    else {"message": f"HTTP {response.status}"}
-                )
-                error_message = error_data.get("message", str(error_data))
+                try:
+                    error_data = response.json()
+                    error_message = error_data.get("message", str(error_data))
+                except:
+                    error_message = response.data.decode('utf-8') if response.data else f"HTTP {response.status}"
                 raise urllib3.exceptions.HTTPError(error_message)
-            return response.json()
+
+            try:
+                return response.json()
+            except:
+                return {"error": response.data.decode('utf-8') if response.data else "Invalid JSON response"}
         except urllib3.exceptions.HTTPError as e:
             return {"error": str(e)}

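For reference, the reworked error handling above follows one simple pattern: prefer the JSON body, fall back to the raw bytes, and never let a non-JSON response crash the client. Below is a minimal standalone sketch of that pattern, assuming a urllib3 2.x response object (which exposes .status, .data and .json()); parse_response is an illustrative name, not an SDK function.

    import json
    import urllib3

    def parse_response(response) -> dict:
        # response is assumed to be a urllib3 2.x HTTPResponse (.status, .data, .json()).
        if response.status >= 400:
            try:
                error_data = response.json()
                # Guard against valid JSON that is not an object (e.g. a bare string).
                message = error_data.get("message", str(error_data)) if isinstance(error_data, dict) else str(error_data)
            except (json.JSONDecodeError, UnicodeDecodeError):
                # Non-JSON error body: fall back to the raw payload or a generic HTTP message.
                message = response.data.decode("utf-8") if response.data else f"HTTP {response.status}"
            raise urllib3.exceptions.HTTPError(message)
        try:
            return response.json()
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Successful status but a non-JSON body: surface the raw text instead of failing.
            return {"error": response.data.decode("utf-8") if response.data else "Invalid JSON response"}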

src/enkryptai_sdk/dto/base.py
@@ -52,7 +52,18 @@ class BaseDTO:

     def to_dict(self) -> Dict[str, Any]:
         """Convert the instance to a dictionary."""
-        d = {k: v for k, v in self.__dict__.items() if k != "_extra_fields"}
+        d = {}
+        for k, v in self.__dict__.items():
+            if k == "_extra_fields":
+                continue
+            if hasattr(v, "to_dict"):
+                d[k] = v.to_dict()
+            elif isinstance(v, list):
+                d[k] = [item.to_dict() if hasattr(item, "to_dict") else item for item in v]
+            elif isinstance(v, dict):
+                d[k] = {key: val.to_dict() if hasattr(val, "to_dict") else val for key, val in v.items()}
+            else:
+                d[k] = v
         d.update(self._extra_fields)
         return d

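For reference, this change makes BaseDTO.to_dict recursive: any attribute, list element, or dict value that exposes its own to_dict is converted as well. Below is a small self-contained sketch of the same idea with hypothetical dataclasses (Header and Config are illustrative names, not SDK classes).

    from dataclasses import dataclass, field
    from typing import Any, Dict, List

    @dataclass
    class Header:
        key: str
        value: str

        def to_dict(self) -> Dict[str, Any]:
            return {"key": self.key, "value": self.value}

    @dataclass
    class Config:
        name: str
        headers: List[Header] = field(default_factory=list)

        def to_dict(self) -> Dict[str, Any]:
            # Recurse into nested objects and lists, mirroring the pattern above.
            d: Dict[str, Any] = {}
            for k, v in self.__dict__.items():
                if hasattr(v, "to_dict"):
                    d[k] = v.to_dict()
                elif isinstance(v, list):
                    d[k] = [i.to_dict() if hasattr(i, "to_dict") else i for i in v]
                else:
                    d[k] = v
            return d

    print(Config("demo", [Header("Authorization", "Bearer <token>")]).to_dict())
    # {'name': 'demo', 'headers': [{'key': 'Authorization', 'value': 'Bearer <token>'}]}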

src/enkryptai_sdk/dto/models.py
@@ -37,6 +37,8 @@ class ModelProviders(str, Enum):
     OPENAI_COMPATIBLE = "openai_compatible"
     COHERE_COMPATIBLE = "cohere_compatible"
     ANTHROPIC_COMPATIBLE = "anthropic_compatible"
+    CUSTOM = "custom"
+    HR = "hr"


 @dataclass
@@ -247,6 +249,14 @@ class ModelConfigDetails(BaseDTO):
         d = super().to_dict()
         # Handle AuthData specifically
         d["auth_data"] = self.auth_data.to_dict()
+        # Handle CustomHeader list
+        d["custom_headers"] = [header.to_dict() for header in self.custom_headers]
+        # Handle ModelProviders enum
+        if isinstance(d["model_provider"], ModelProviders):
+            d["model_provider"] = d["model_provider"].value
+        # Handle input/output modalities
+        d["input_modalities"] = [m.value for m in self.input_modalities]
+        d["output_modalities"] = [m.value for m in self.output_modalities]
         return d

     def to_json(self):
@@ -289,6 +299,13 @@ class ModelConfig(BaseDTO):

         return cls(**data, model_config=model_config)

+    def to_dict(self) -> dict:
+        """Convert the ModelConfig instance to a dictionary."""
+        d = super().to_dict()
+        # Handle nested ModelConfigDetails
+        d["model_config"] = self.model_config.to_dict()
+        return d
+
     @classmethod
     def __str__(self):
         """String representation of the ModelConfig."""

src/enkryptai_sdk/dto/red_team.py
@@ -367,7 +367,15 @@ class TargetModelConfiguration(BaseDTO):

     @classmethod
     def from_dict(cls, data: dict):
+        data = data.copy()
+        if "custom_headers" in data:
+            data["custom_headers"] = [CustomHeader.from_dict(header) for header in data["custom_headers"]]
         return cls(**data)
+
+    def to_dict(self) -> dict:
+        d = asdict(self)
+        d["custom_headers"] = [header.to_dict() for header in self.custom_headers]
+        return d


 @dataclass

src/enkryptai_sdk/models.py
@@ -49,6 +49,9 @@ class ModelClient(BaseClient):
             "chat": f"/{remaining_path}" if remaining_path else "",
         }

+        # Convert custom_headers to list of dictionaries
+        custom_headers = [header.to_dict() for header in config.model_config.custom_headers]
+
         payload = {
             "testing_for": config.testing_for,
             "model_name": config.model_name,
@@ -72,9 +75,9 @@
             },
             "apikeys": [config.model_config.apikey] if config.model_config.apikey else [],
             "tools": config.model_config.tools,
-            "input_modalities": config.model_config.input_modalities,
-            "output_modalities": config.model_config.output_modalities,
-            "custom_headers": config.model_config.custom_headers,
+            "input_modalities": [m.value if hasattr(m, 'value') else m for m in config.model_config.input_modalities],
+            "output_modalities": [m.value if hasattr(m, 'value') else m for m in config.model_config.output_modalities],
+            "custom_headers": custom_headers,
             "custom_payload": config.model_config.custom_payload,
             "custom_response_content_type": config.model_config.custom_response_content_type,
             "custom_response_format": config.model_config.custom_response_format,
@@ -189,6 +192,9 @@
             "chat": f"/{remaining_path}" if remaining_path else "",
         }

+        # Convert custom_headers to list of dictionaries
+        custom_headers = [header.to_dict() for header in config.model_config.custom_headers]
+
         payload = {
             "model_saved_name": config.model_saved_name,
             "model_version": config.model_version,
@@ -217,9 +223,9 @@
                 [config.model_config.apikey] if config.model_config.apikey else []
             ),
             "tools": config.model_config.tools,
-            "input_modalities": config.model_config.input_modalities,
-            "output_modalities": config.model_config.output_modalities,
-            "custom_headers": config.model_config.custom_headers,
+            "input_modalities": [m.value if hasattr(m, 'value') else m for m in config.model_config.input_modalities],
+            "output_modalities": [m.value if hasattr(m, 'value') else m for m in config.model_config.output_modalities],
+            "custom_headers": custom_headers,
             "custom_payload": config.model_config.custom_payload,
             "custom_response_content_type": config.model_config.custom_response_content_type,
             "custom_response_format": config.model_config.custom_response_format,

src/enkryptai_sdk/red_team.py
@@ -1,4 +1,5 @@
-import urllib3
+# import json
+# import urllib3
 from .base import BaseClient
 from .models import ModelClient
 from .datasets import DatasetClient
@@ -62,6 +63,8 @@ class RedTeamClient(BaseClient):
         """
         try:
             config = RedTeamModelHealthConfig.from_dict(config)
+            # Print the config as json string
+            # print(f"Config: {json.dumps(config.to_dict(), indent=4)}")
             response = self._request("POST", "/redteam/model-health", json=config.to_dict())
             # if response.get("error"):
             if response.get("error") not in [None, ""]:
@@ -103,9 +106,7 @@
             # "async": config.async_enabled,
             "dataset_name": config.dataset_name,
             "test_name": config.test_name,
-            "redteam_test_configurations": {
-                k: v.to_dict() for k, v in test_configs.items()
-            },
+            "redteam_test_configurations": test_configs,
         }

         if config.target_model_configuration:
@@ -150,9 +151,7 @@
             # "async": config.async_enabled,
             "dataset_name": config.dataset_name,
             "test_name": config.test_name,
-            "redteam_test_configurations": {
-                k: v.to_dict() for k, v in test_configs.items()
-            },
+            "redteam_test_configurations": test_configs,
         }

         headers = {
@@ -193,9 +192,7 @@
         payload = {
             # "async": config.async_enabled,
             "test_name": config.test_name,
-            "redteam_test_configurations": {
-                k: v.to_dict() for k, v in test_configs.items()
-            },
+            "redteam_test_configurations": test_configs,
         }

         if config.dataset_configuration:
@@ -249,9 +246,7 @@
         payload = {
             # "async": config.async_enabled,
             "test_name": config.test_name,
-            "redteam_test_configurations": {
-                k: v.to_dict() for k, v in test_configs.items()
-            },
+            "redteam_test_configurations": test_configs,
         }

         if config.dataset_configuration:

src/enkryptai_sdk.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: enkryptai-sdk
-Version: 1.0.10
+Version: 1.0.12
 Summary: A Python SDK with guardrails and red teaming functionality for API interactions
 Home-page: https://github.com/enkryptai/enkryptai-sdk
 Author: Enkrypt AI Team

tests/test_all_v2.py
@@ -19,6 +19,7 @@ test_guardrails_policy_name = "Test Guardrails Policy"
 model_saved_name = None
 model_version = None
 test_model_saved_name = "Test Model"
+test_custom_model_saved_name = "Test Custom Model"
 test_model_version = "v1"
 test_deployment_name = "test-deployment"

@@ -155,6 +156,7 @@ def sample_detectors():
         },
     }

+
 @pytest.fixture
 def sample_model_config():
     return {
@@ -171,6 +173,46 @@ def sample_model_config():
         },
     }

+
+@pytest.fixture
+def sample_custom_model_config():
+    return {
+        "model_saved_name": test_custom_model_saved_name,
+        "model_version": test_model_version,
+        "testing_for": "foundationModels",
+        "model_name": model_name,
+        "model_config": {
+            # "model_provider": model_provider,
+            "model_provider": "custom",
+            "endpoint_url": model_endpoint_url,
+            # "apikey": OPENAI_API_KEY,
+            "input_modalities": ["text"],
+            "output_modalities": ["text"],
+            "custom_headers": [
+                {
+                    "key": "Content-Type",
+                    "value": "application/json"
+                },
+                {
+                    "key": "Authorization",
+                    "value": "Bearer " + OPENAI_API_KEY
+                }
+            ],
+            "custom_payload": {
+                "model": model_name,
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "{prompt}"
+                    }
+                ]
+            },
+            "custom_response_content_type": "json",
+            "custom_response_format": ".choices[0].message.content",
+        },
+    }
+
+
 @pytest.fixture
 def sample_deployment_config():
     return {
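For reference, the new "custom" provider fixture describes the target purely as data: explicit custom_headers, a custom_payload template whose "{prompt}" placeholder is filled in at request time, and a jq-style custom_response_format path for extracting the answer. How the backend consumes these fields is not shown in this diff; the sketch below only illustrates one plausible way such a template could be rendered, and render_custom_payload is a hypothetical helper, not an SDK function.

    import json

    def render_custom_payload(template: dict, prompt: str) -> dict:
        # Recursively replace the "{prompt}" placeholder wherever it appears in the template.
        def walk(node):
            if isinstance(node, str):
                return node.replace("{prompt}", prompt)
            if isinstance(node, list):
                return [walk(item) for item in node]
            if isinstance(node, dict):
                return {key: walk(value) for key, value in node.items()}
            return node
        return walk(template)

    payload = {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "{prompt}"}]}
    print(json.dumps(render_custom_payload(payload, "Hello"), indent=2))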
@@ -235,9 +277,31 @@ def sample_redteam_model_health_config():
         "testing_for": "foundationModels",
         "model_version": "v1",
         "model_source": "https://openai.com",
-        "model_provider": model_provider,
+        # "model_provider": model_provider,
+        "model_provider": "custom",
         "model_endpoint_url": model_endpoint_url,
-        "model_api_key": OPENAI_API_KEY,
+        # "model_api_key": OPENAI_API_KEY,
+        "custom_headers": [
+            {
+                "key": "Content-Type",
+                "value": "application/json"
+            },
+            {
+                "key": "Authorization",
+                "value": "Bearer " + OPENAI_API_KEY
+            }
+        ],
+        "custom_payload": {
+            "model": model_name,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": "{prompt}"
+                }
+            ]
+        },
+        "custom_response_content_type": "json",
+        "custom_response_format": ".choices[0].message.content",
         "system_prompt": "",
         "rate_per_min": 20,
         "input_modalities": ["text"],
@@ -659,18 +723,19 @@ def test_adherence(guardrails_client):
     assert summary.adherence_score == 0


-def test_relevancy(guardrails_client):
-    print("\n\nTesting relevancy")
-    # Test the relevancy method
-    response = guardrails_client.relevancy(llm_answer="Hello! How can I help you today? If you have any questions or need assistance with something, feel free to ask. I'm here to provide information and support. Is there something specific you'd like to know or discuss?", question="Hi")
-    print("\nResponse from relevancy: ", response)
-    print("\nResponse data type: ", type(response))
-    assert response is not None
-    assert hasattr(response, "summary")
-    summary = response.summary
-    assert summary.relevancy_score == 0
-    print("\nSleeping for 60 seconds after guardrails tests...")
-    time.sleep(60)
+# Not being used in Prod at this time
+# def test_relevancy(guardrails_client):
+#     print("\n\nTesting relevancy")
+#     # Test the relevancy method
+#     response = guardrails_client.relevancy(llm_answer="Hello! How can I help you today? If you have any questions or need assistance with something, feel free to ask. I'm here to provide information and support. Is there something specific you'd like to know or discuss?", question="Hi")
+#     print("\nResponse from relevancy: ", response)
+#     print("\nResponse data type: ", type(response))
+#     assert response is not None
+#     assert hasattr(response, "summary")
+#     summary = response.summary
+#     assert summary.relevancy_score == 0
+#     print("\nSleeping for 60 seconds after guardrails tests...")
+#     time.sleep(60)


 def test_list_policies(guardrails_client):
@@ -900,6 +965,20 @@ def test_add_model(model_client, sample_model_config):
     assert model_info["model_saved_name"] == test_model_saved_name
     assert model_info["model_version"] == test_model_version

+def test_add_custom_model(model_client, sample_custom_model_config):
+    print("\n\nTesting adding a new custom model")
+    response = model_client.add_model(config=sample_custom_model_config)
+    print("\nResponse from adding a new custom model: ", response)
+    print("\nResponse data type: ", type(response))
+    assert response is not None
+    # assert hasattr(response, "message")
+    assert response.message == "Model details added successfully"
+    # assert hasattr(response, "data")
+    model_info = response.data
+    assert model_info is not None
+    assert model_info["model_saved_name"] == test_custom_model_saved_name
+    assert model_info["model_version"] == test_model_version
+
 def test_list_models(model_client):
     print("\n\nTesting list_models")
     # Test the list_models method
@@ -917,7 +996,7 @@ def test_list_models(model_client):
     # assert model_info["model_version"] == test_model_version

 def test_get_model(model_client):
-    print("\n\nTesting get_model")
+    print("\n\nTesting get_model with custom model")
     # global model_saved_name
     # if model_saved_name is None:
     #     print("\nModel saved name is None, fetching it from list_models")
@@ -927,7 +1006,7 @@ def test_get_model(model_client):
     #     model_saved_name = model_info.model_saved_name
     #     assert model_saved_name is not None
     #     print("\nPicked model in get_model: ", model_saved_name)
-    model_saved_name = test_model_saved_name
+    model_saved_name = test_custom_model_saved_name
     model_version = test_model_version

     # Now test the get_model method
@@ -939,10 +1018,10 @@
     # assert hasattr(model, "status")


-def test_modify_model(model_client, sample_model_config):
-    print("\n\nTesting modifying a new model")
+def test_modify_model(model_client, sample_custom_model_config):
+    print("\n\nTesting modifying custom model")
     # Test creating a new model
-    response = model_client.modify_model(config=sample_model_config, old_model_saved_name=None, old_model_version=None)
+    response = model_client.modify_model(config=sample_custom_model_config, old_model_saved_name=None, old_model_version=None)
     print("\nResponse from modifying a new model: ", response)
     print("\nResponse data type: ", type(response))
     assert response is not None
@@ -951,7 +1030,7 @@ def test_modify_model(model_client, sample_model_config):
     # assert hasattr(response, "data")
     model_info = response.data
     assert model_info is not None
-    assert model_info["model_saved_name"] == test_model_saved_name
+    assert model_info["model_saved_name"] == test_custom_model_saved_name
     assert model_info["model_version"] == test_model_version
     print("\nSleeping for 60 seconds after model tests...")
     time.sleep(60)
@@ -1174,25 +1253,29 @@ def test_get_health(redteam_client):
     assert response.status == "healthy"


-def test_model_health(redteam_client, sample_redteam_model_health_config):
-    print("\n\nTesting check_model_health")
-    response = redteam_client.check_model_health(config=sample_redteam_model_health_config)
-    print("\nResponse from check_model_health: ", response)
-    assert response is not None
-    assert hasattr(response, "status")
-    assert response.status == "healthy"
+# ---------------------------------------------------------
+# Commenting as we already test for saved model health
+# ---------------------------------------------------------
+# def test_model_health(redteam_client, sample_redteam_model_health_config):
+#     print("\n\nTesting check_model_health")
+#     response = redteam_client.check_model_health(config=sample_redteam_model_health_config)
+#     print("\nResponse from check_model_health: ", response)
+#     assert response is not None
+#     assert hasattr(response, "status")
+#     assert response.status == "healthy"


 def test_saved_model_health(redteam_client):
     print("\n\nTesting check_saved_model_health")
-    response = redteam_client.check_saved_model_health(model_saved_name=test_model_saved_name, model_version=test_model_version)
+    response = redteam_client.check_saved_model_health(model_saved_name=test_custom_model_saved_name, model_version=test_model_version)
     print("\nResponse from check_saved_model_health: ", response)
     assert response is not None
     assert hasattr(response, "status")
     assert response.status == "healthy"


-# # Testing only via saved model as it should be sufficient
+# ---------------------------------------------------------
+# Commenting as we are testing with add custom task with saved model
 # ---------------------------------------------------------
 # def test_add_task_with_target_model(redteam_client, sample_redteam_target_config):
 #     print("\n\nTesting adding a new redteam task with target model")
@@ -1210,20 +1293,24 @@ def test_saved_model_health(redteam_client):
 #     time.sleep(60)


-def test_add_task_with_saved_model(redteam_client, sample_redteam_model_config):
-    print("\n\nTesting adding a new redteam task with saved model")
-    response = redteam_client.add_task_with_saved_model(config=sample_redteam_model_config,model_saved_name=test_model_saved_name, model_version=test_model_version)
-    print("\nResponse from adding a new redteam task with saved model: ", response)
-    assert response is not None
-    assert hasattr(response, "task_id")
-    assert hasattr(response, "message")
-    response.message == "Redteam task has been added successfully"
-    # Sleep for a while to let the task complete
-    # This is also useful to avoid rate limiting issues
-    print("\nSleeping for 60 seconds to let the task complete if possible ...")
-    time.sleep(60)
+# ---------------------------------------------------------
+# Commenting as we are testing with add custom task with saved model
+# ---------------------------------------------------------
+# def test_add_task_with_saved_model(redteam_client, sample_redteam_model_config):
+#     print("\n\nTesting adding a new redteam task with saved model")
+#     response = redteam_client.add_task_with_saved_model(config=sample_redteam_model_config,model_saved_name=test_model_saved_name, model_version=test_model_version)
+#     print("\nResponse from adding a new redteam task with saved model: ", response)
+#     assert response is not None
+#     assert hasattr(response, "task_id")
+#     assert hasattr(response, "message")
+#     response.message == "Redteam task has been added successfully"
+#     # Sleep for a while to let the task complete
+#     # This is also useful to avoid rate limiting issues
+#     print("\nSleeping for 60 seconds to let the task complete if possible ...")
+#     time.sleep(60)


+# ---------------------------------------------------------
 # # Testing only via saved model as it should be sufficient
 # ---------------------------------------------------------
 # def test_add_custom_task_with_target_model(redteam_client, sample_custom_redteam_target_config):
@@ -1244,7 +1331,7 @@ def test_add_task_with_saved_model(redteam_client, sample_redteam_model_config):

 def test_add_custom_task_with_saved_model(redteam_client, sample_custom_redteam_model_config):
     print("\n\nTesting adding a new custom redteam task with saved model")
-    response = redteam_client.add_custom_task_with_saved_model(config=sample_custom_redteam_model_config, model_saved_name=test_model_saved_name, model_version=test_model_version)
+    response = redteam_client.add_custom_task_with_saved_model(config=sample_custom_redteam_model_config, model_saved_name=test_custom_model_saved_name, model_version=test_model_version)
     print("\nResponse from adding a new custom redteam task with saved model: ", response)
     assert response is not None
     assert hasattr(response, "task_id")
@@ -1386,6 +1473,14 @@ def test_delete_model(model_client):
     # assert hasattr(response, "message")
     assert response.message == "Model details deleted successfully"

+def test_delete_custom_model(model_client):
+    print("\n\nTesting delete_custom_model")
+    response = model_client.delete_model(model_saved_name=test_custom_model_saved_name, model_version=test_model_version)
+    print("\nResponse from delete_custom_model: ", response)
+    assert response is not None
+    # assert hasattr(response, "message")
+    assert response.message == "Model details deleted successfully"
+
 def test_delete_deployment(deployment_client):
     print("\n\nTesting delete_deployment")
     response = deployment_client.delete_deployment(deployment_name=test_deployment_name)

tests/test_model.py
@@ -14,6 +14,9 @@ model_saved_name = None
 test_model_saved_name = "Test Model"
 model_version = None
 test_model_version = "v1"
+model_provider = "openai"
+model_name = "gpt-4o-mini"
+model_endpoint_url = "https://api.openai.com/v1/chat/completions"

 @pytest.fixture
 def model_client():
@@ -27,13 +30,35 @@ def sample_model_config():
         "model_saved_name": test_model_saved_name,
         "model_version": test_model_version,
         "testing_for": "foundationModels",
-        "model_name": "gpt-4o-mini",
+        "model_name": model_name,
         "model_config": {
-            "model_provider": "openai",
-            "endpoint_url": "https://api.openai.com/v1/chat/completions",
-            "apikey": OPENAI_API_KEY,
+            # "model_provider": model_provider,
+            "model_provider": "custom",
+            "endpoint_url": model_endpoint_url,
+            # "apikey": OPENAI_API_KEY,
             "input_modalities": ["text"],
             "output_modalities": ["text"],
+            "custom_headers": [
+                {
+                    "key": "Content-Type",
+                    "value": "application/json"
+                },
+                {
+                    "key": "Authorization",
+                    "value": "Bearer " + OPENAI_API_KEY
+                }
+            ],
+            "custom_payload": {
+                "model": model_name,
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "{prompt}"
+                    }
+                ]
+            },
+            "custom_response_content_type": "json",
+            "custom_response_format": ".choices[0].message.content",
         },
     }
