enkryptai-sdk 1.0.4__py3-none-any.whl → 1.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
enkryptai_sdk/datasets.py CHANGED
@@ -36,9 +36,23 @@ class DatasetClient(BaseClient):
36
36
 
37
37
  payload = config.to_dict()
38
38
 
39
+ # If payload["tools"] is None or is an empty list [] or [{}], remove it from the payload
40
+ if (payload.get("tools") is None or
41
+ payload["tools"] == [] or
42
+ payload["tools"] == [{}] or
43
+ payload["tools"] == [{"name": "", "description": ""}]):
44
+ del payload["tools"]
45
+
46
+ # Print payload
47
+ # print(f"\nAdd Dataset Payload: {payload}")
48
+
39
49
  response = self._request(
40
50
  "POST", "/datasets/add-task", headers=headers, json=payload
41
51
  )
52
+
53
+ # Print response
54
+ # print(f"\nAdd Dataset Response: {response}")
55
+
42
56
  if response.get("error"):
43
57
  raise DatasetClientError(response["error"])
44
58
  return DatasetAddTaskResponse.from_dict(response)
@@ -62,10 +62,13 @@ __all__ = [
62
62
  "GuardrailsModelsResponse",
63
63
  "GuardrailDetectors",
64
64
  "GuardrailsDetectRequest",
65
+ "GuardrailsBatchDetectRequest",
65
66
  "GuardrailsPolicyDetectRequest",
66
67
  "DetectResponseSummary",
67
68
  "DetectResponseDetails",
68
69
  "GuardrailsDetectResponse",
70
+ "BatchDetectResponseItem",
71
+ "GuardrailsBatchDetectResponse",
69
72
  "GuardrailsPIIRequest",
70
73
  "GuardrailsPIIResponse",
71
74
  "GuardrailsHallucinationRequest",
@@ -3,17 +3,35 @@ from dataclasses import dataclass, field
3
3
  from typing import Dict, List, Optional, Any
4
4
 
5
5
 
6
+ @dataclass
7
+ class Tool(BaseDTO):
8
+ name: str
9
+ description: str
10
+
11
+ @classmethod
12
+ def from_dict(cls, data: Dict[str, Any]) -> "Tool":
13
+ return cls(
14
+ name=data.get("name", ""),
15
+ description=data.get("description", "")
16
+ )
17
+
18
+ def to_dict(self) -> Dict[str, Any]:
19
+ return {
20
+ "name": self.name,
21
+ "description": self.description
22
+ }
23
+
6
24
  @dataclass
7
25
  class DatasetConfig(BaseDTO):
8
26
  dataset_name: str
9
27
  system_description: str
10
28
  policy_description: str = ""
11
- tools: List[str] = field(default_factory=list)
29
+ tools: List[Tool] = field(default_factory=list)
12
30
  info_pdf_url: str = ""
13
31
  max_prompts: int = 100
14
- # scenarios: int = 2
15
- # categories: int = 2
16
- # depth: int = 2
32
+ scenarios: int = 2
33
+ categories: int = 2
34
+ depth: int = 2
17
35
 
18
36
 
19
37
  @dataclass
@@ -141,6 +141,7 @@ class OutputGuardrailsPolicy(BaseDTO):
141
141
  class DeploymentInput(BaseDTO):
142
142
  name: str
143
143
  model_saved_name: str
144
+ model_version: str
144
145
  input_guardrails_policy: InputGuardrailsPolicy
145
146
  output_guardrails_policy: OutputGuardrailsPolicy
146
147
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
@@ -153,6 +154,7 @@ class DeploymentInput(BaseDTO):
153
154
  return cls(
154
155
  name=data.get("name", ""),
155
156
  model_saved_name=data.get("model_saved_name", ""),
157
+ model_version=data.get("model_version", ""),
156
158
  input_guardrails_policy=InputGuardrailsPolicy.from_dict(input_policy_data),
157
159
  output_guardrails_policy=OutputGuardrailsPolicy.from_dict(output_policy_data)
158
160
  )
@@ -161,6 +163,7 @@ class DeploymentInput(BaseDTO):
161
163
  result = {
162
164
  "name": self.name,
163
165
  "model_saved_name": self.model_saved_name,
166
+ "model_version": self.model_version,
164
167
  "input_guardrails_policy": self.input_guardrails_policy.to_dict(),
165
168
  "output_guardrails_policy": self.output_guardrails_policy.to_dict()
166
169
  }
@@ -172,6 +175,7 @@ class DeploymentInput(BaseDTO):
172
175
  class GetDeploymentResponse(BaseDTO):
173
176
  name: str
174
177
  model_saved_name: str
178
+ model_version: str
175
179
  input_guardrails_policy: InputGuardrailsPolicy
176
180
  output_guardrails_policy: OutputGuardrailsPolicy
177
181
  created_at: str
@@ -186,6 +190,7 @@ class GetDeploymentResponse(BaseDTO):
186
190
  return cls(
187
191
  name=data.get("name", ""),
188
192
  model_saved_name=data.get("model_saved_name", ""),
193
+ model_version=data.get("model_version", ""),
189
194
  input_guardrails_policy=InputGuardrailsPolicy.from_dict(input_policy_data),
190
195
  output_guardrails_policy=OutputGuardrailsPolicy.from_dict(output_policy_data),
191
196
  updated_at=data.get("updated_at", ""),
@@ -197,6 +202,7 @@ class GetDeploymentResponse(BaseDTO):
197
202
  return {
198
203
  "name": self.name,
199
204
  "model_saved_name": self.model_saved_name,
205
+ "model_version": self.model_version,
200
206
  "input_guardrails_policy": self.input_guardrails_policy.to_dict(),
201
207
  "output_guardrails_policy": self.output_guardrails_policy.to_dict(),
202
208
  "updated_at": self.updated_at,
@@ -293,6 +299,7 @@ class DeploymentSummary(BaseDTO):
293
299
  created_at: str
294
300
  updated_at: str
295
301
  model_saved_name: str
302
+ model_version: str
296
303
  project_name: str
297
304
 
298
305
  @classmethod
@@ -303,6 +310,7 @@ class DeploymentSummary(BaseDTO):
303
310
  created_at=data.get("created_at", ""),
304
311
  updated_at=data.get("updated_at", ""),
305
312
  model_saved_name=data.get("model_saved_name", ""),
313
+ model_version=data.get("model_version", ""),
306
314
  project_name=data.get("project_name", "")
307
315
  )
308
316
 
@@ -313,6 +321,7 @@ class DeploymentSummary(BaseDTO):
313
321
  "created_at": self.created_at,
314
322
  "updated_at": self.updated_at,
315
323
  "model_saved_name": self.model_saved_name,
324
+ "model_version": self.model_version,
316
325
  "project_name": self.project_name
317
326
  }
318
327
 
@@ -301,6 +301,27 @@ class GuardrailsDetectRequest(BaseDTO):
301
301
  }
302
302
 
303
303
 
304
+ @dataclass
305
+ class GuardrailsBatchDetectRequest(BaseDTO):
306
+ texts: List[str] = field(default_factory=list)
307
+ detectors: GuardrailDetectors = field(default_factory=GuardrailDetectors)
308
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
309
+
310
+ @classmethod
311
+ def from_dict(cls, data: Dict[str, Any]) -> "GuardrailsBatchDetectRequest":
312
+ detectors_data = data.get("detectors", {})
313
+ return cls(
314
+ texts=data.get("texts", []),
315
+ detectors=GuardrailDetectors.from_dict(detectors_data)
316
+ )
317
+
318
+ def to_dict(self) -> Dict[str, Any]:
319
+ return {
320
+ "texts": self.texts,
321
+ "detectors": self.detectors.to_dict()
322
+ }
323
+
324
+
304
325
  @dataclass
305
326
  class GuardrailsPolicyDetectRequest(BaseDTO):
306
327
  text: str
@@ -750,6 +771,180 @@ class GuardrailsDetectResponse(BaseDTO):
750
771
  return f"Response Status: {status}\n{violation_str}"
751
772
 
752
773
 
774
+ @dataclass
775
+ class BatchDetectResponseItem(BaseDTO):
776
+ text: str
777
+ summary: DetectResponseSummary
778
+ details: DetectResponseDetails
779
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
780
+
781
+ @classmethod
782
+ def from_dict(cls, data: Dict[str, Any]) -> "BatchDetectResponseItem":
783
+ return cls(
784
+ text=data.get("text", ""),
785
+ summary=DetectResponseSummary.from_dict(data.get("summary", {})),
786
+ details=DetectResponseDetails.from_dict(data.get("details", {}))
787
+ )
788
+
789
+ def to_dict(self) -> Dict[str, Any]:
790
+ result = {
791
+ "text": self.text,
792
+ "summary": self.summary.to_dict(),
793
+ "details": self.details.to_dict()
794
+ }
795
+ result.update(self._extra_fields)
796
+ return result
797
+
798
+ def has_violations(self) -> bool:
799
+ """
800
+ Check if any detectors found violations in the content.
801
+
802
+ Returns:
803
+ bool: True if any detector reported a violation (score > 0), False otherwise
804
+ """
805
+ summary = self.summary.to_dict()
806
+ for key, value in summary.items():
807
+ if key == "toxicity" and isinstance(value, list) and len(value) > 0:
808
+ return True
809
+ elif isinstance(value, (int, float)) and value > 0:
810
+ return True
811
+ return False
812
+
813
+ def get_violations(self) -> list[str]:
814
+ """
815
+ Get a list of detector names that found violations.
816
+
817
+ Returns:
818
+ list[str]: Names of detectors that reported violations
819
+ """
820
+ summary = self.summary.to_dict()
821
+ violations = []
822
+ for detector, value in summary.items():
823
+ if detector == "toxicity" and isinstance(value, list) and len(value) > 0:
824
+ violations.append(detector)
825
+ elif isinstance(value, (int, float)) and value > 0:
826
+ violations.append(detector)
827
+ return violations
828
+
829
+ def is_safe(self) -> bool:
830
+ """
831
+ Check if the content is safe (no violations detected).
832
+
833
+ Returns:
834
+ bool: True if no violations were detected, False otherwise
835
+ """
836
+ return not self.has_violations()
837
+
838
+ def is_attack(self) -> bool:
839
+ """
840
+ Check if the content is attacked (violations detected).
841
+
842
+ Returns:
843
+ bool: True if violations were detected, False otherwise
844
+ """
845
+ return self.has_violations()
846
+
847
+ def __str__(self) -> str:
848
+ """
849
+ String representation of the response.
850
+
851
+ Returns:
852
+ str: A formatted string showing summary and violation status
853
+ """
854
+ violations = self.get_violations()
855
+ status = "UNSAFE" if violations else "SAFE"
856
+
857
+ if violations:
858
+ violation_str = f"Violations detected: {', '.join(violations)}"
859
+ else:
860
+ violation_str = "No violations detected"
861
+
862
+ return f"Response Status: {status}\n{violation_str}"
863
+
864
+
865
+ @dataclass
866
+ class GuardrailsBatchDetectResponse(BaseDTO):
867
+ batch_detections: List[BatchDetectResponseItem] = field(default_factory=list)
868
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
869
+
870
+ @classmethod
871
+ def from_dict(cls, data: List[Dict[str, Any]]) -> "GuardrailsBatchDetectResponse":
872
+ return cls(
873
+ batch_detections=[BatchDetectResponseItem.from_dict(item) for item in data]
874
+ )
875
+
876
+ def to_dict(self) -> List[Dict[str, Any]]:
877
+ return [response.to_dict() for response in self.batch_detections]
878
+
879
+ def has_violations(self) -> bool:
880
+ """
881
+ Check if any detectors found violations in any of the batch_detections.
882
+
883
+ Returns:
884
+ bool: True if any detector reported a violation, False otherwise
885
+ """
886
+ for detection in self.batch_detections:
887
+ summary = detection.summary.to_dict()
888
+ for key, value in summary.items():
889
+ if key == "toxicity" and isinstance(value, list) and len(value) > 0:
890
+ return True
891
+ elif isinstance(value, (int, float)) and value > 0:
892
+ return True
893
+ return False
894
+
895
+ def get_violations(self) -> List[str]:
896
+ """
897
+ Get a list of texts that have violations.
898
+
899
+ Returns:
900
+ List[str]: List of texts that have violations
901
+ """
902
+ violations = set()
903
+ for detection in self.batch_detections:
904
+ summary = detection.summary.to_dict()
905
+ for detector, value in summary.items():
906
+ if detector == "toxicity" and isinstance(value, list) and len(value) > 0:
907
+ violations.add(detector)
908
+ elif isinstance(value, (int, float)) and value > 0:
909
+ violations.add(detector)
910
+ return list(violations)
911
+
912
+ def is_safe(self) -> bool:
913
+ """
914
+ Check if all content is safe (no violations detected).
915
+
916
+ Returns:
917
+ bool: True if no violations were detected in any response, False otherwise
918
+ """
919
+ return not self.has_violations()
920
+
921
+ def is_attack(self) -> bool:
922
+ """
923
+ Check if any content is attacked (violations detected).
924
+
925
+ Returns:
926
+ bool: True if violations were detected in any response, False otherwise
927
+ """
928
+ return self.has_violations()
929
+
930
+ def __str__(self) -> str:
931
+ """
932
+ String representation of the batch response.
933
+
934
+ Returns:
935
+ str: A formatted string showing violation status for all batch_detections
936
+ """
937
+ violations = self.get_violations()
938
+ status = "UNSAFE" if violations else "SAFE"
939
+
940
+ if violations:
941
+ violation_str = f"Violations detected in texts:\n" + "\n".join(f"- {text}" for text in violations)
942
+ else:
943
+ violation_str = "No violations detected in any text"
944
+
945
+ return f"Batch Response Status: {status}\n{violation_str}"
946
+
947
+
753
948
  # -------------------------------------
754
949
  # Guardrails PII
755
950
  # -------------------------------------
@@ -7,14 +7,14 @@ from dataclasses import dataclass, field, asdict
7
7
  from typing import Optional, List, Set, Dict, Any
8
8
 
9
9
 
10
- class Modality(Enum):
11
- TEXT = "text"
12
- IMAGE = "image"
13
- AUDIO = "audio"
14
- VIDEO = "video"
10
+ # class Modality(Enum):
11
+ # TEXT = "text"
12
+ # IMAGE = "image"
13
+ # AUDIO = "audio"
14
+ # VIDEO = "video"
15
15
 
16
- def to_dict(self):
17
- return self.value
16
+ # def to_dict(self):
17
+ # return self.value
18
18
 
19
19
 
20
20
  class ModelProviders(str, Enum):
@@ -85,7 +85,6 @@ class AuthData(BaseDTO):
85
85
 
86
86
  @dataclass
87
87
  class ModelDetailConfig:
88
- model_version: Optional[str] = None
89
88
  model_source: str = ""
90
89
  # model_provider: str = "openai"
91
90
  model_provider: ModelProviders = ModelProviders.OPENAI
@@ -99,18 +98,52 @@ class ModelDetailConfig:
99
98
  @dataclass
100
99
  class DetailModelConfig:
101
100
  model_saved_name: str = "Model Name"
102
- testing_for: str = "LLM"
103
- model_name: str = "gpt-4o-mini"
104
- modality: Modality = Modality.TEXT
101
+ model_version: str = "v1"
102
+ testing_for: str = "foundationModels"
103
+ model_name: Optional[str] = "gpt-4o-mini"
104
+ # modality: Modality = Modality.TEXT
105
105
  model_config: ModelDetailConfig = field(default_factory=ModelDetailConfig)
106
106
 
107
107
 
108
+ class InputModality(str, Enum):
109
+ text = "text"
110
+ image = "image"
111
+ audio = "audio"
112
+ # video = "video"
113
+ # code = "code"
114
+
115
+
116
+ class OutputModality(str, Enum):
117
+ text = "text"
118
+ # image = "image"
119
+ # audio = "audio"
120
+ # video = "video"
121
+ # code = "code"
122
+
123
+
124
+ @dataclass
125
+ class CustomHeader(BaseDTO):
126
+ key: str
127
+ value: str
128
+
129
+ @classmethod
130
+ def from_dict(cls, data: Dict[str, Any]) -> "CustomHeader":
131
+ return cls(
132
+ key=data.get("key", ""),
133
+ value=data.get("value", "")
134
+ )
135
+
136
+ def to_dict(self) -> Dict[str, Any]:
137
+ return {
138
+ "key": self.key,
139
+ "value": self.value
140
+ }
141
+
142
+
108
143
  @dataclass
109
144
  class ModelConfigDetails(BaseDTO):
110
145
  model_id: str = ""
111
- model_version: Optional[str] = None
112
146
  model_source: str = ""
113
- model_name: str = ""
114
147
  # model_provider: str = "openai"
115
148
  model_provider: ModelProviders = ModelProviders.OPENAI
116
149
  model_api_value: str = ""
@@ -119,15 +152,23 @@ class ModelConfigDetails(BaseDTO):
119
152
  model_api_key: str = ""
120
153
  model_endpoint_url: str = ""
121
154
  rate_per_min: int = 20
122
- testing_for: str = "LLM"
155
+ testing_for: str = "foundationModels"
123
156
  headers: str = ""
124
157
  system_prompt: str = ""
125
- conversation_template: str = ""
126
158
  hosting_type: str = "External"
127
159
  endpoint_url: str = "https://api.openai.com/v1/chat/completions"
160
+ model_name: Optional[str] = ""
161
+ apikey: Optional[str] = None
128
162
  paths: Optional[PathsConfig] = None
163
+ tools: List[Dict[str, str]] = field(default_factory=list)
129
164
  auth_data: AuthData = field(default_factory=AuthData)
130
- apikey: Optional[str] = None
165
+ input_modalities: List[InputModality] = field(default_factory=list)
166
+ output_modalities: List[OutputModality] = field(default_factory=list)
167
+ custom_headers: List[CustomHeader] = field(default_factory=list)
168
+ custom_payload: Dict[str, Any] = field(default_factory=dict)
169
+ custom_response_content_type: Optional[str] = ""
170
+ custom_response_format: Optional[str] = ""
171
+ metadata: Dict[str, Any] = field(default_factory=dict)
131
172
  default_request_options: Dict[str, Any] = field(default_factory=dict)
132
173
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
133
174
 
@@ -136,6 +177,17 @@ class ModelConfigDetails(BaseDTO):
136
177
  # Create a copy of the data to avoid modifying the original
137
178
  data = data.copy()
138
179
 
180
+ if "custom_headers" in data:
181
+ data["custom_headers"] = [CustomHeader.from_dict(h) for h in data["custom_headers"]]
182
+
183
+ # Convert input_modalities strings to enum values
184
+ if "input_modalities" in data:
185
+ data["input_modalities"] = [InputModality(m) for m in data["input_modalities"]]
186
+
187
+ # Convert output_modalities strings to enum values
188
+ if "output_modalities" in data:
189
+ data["output_modalities"] = [OutputModality(m) for m in data["output_modalities"]]
190
+
139
191
  # Validate model_provider if present
140
192
  if "model_provider" in data:
141
193
  provider = data["model_provider"]
@@ -151,10 +203,10 @@ class ModelConfigDetails(BaseDTO):
151
203
  valid_providers = [p.value for p in ModelProviders]
152
204
  raise ValueError(f"Invalid model_provider type. Valid values: {valid_providers}")
153
205
 
154
- # Remove known fields that we don't want in our model
155
- unwanted_fields = ["queryParams"]
156
- for field in unwanted_fields:
157
- data.pop(field, None)
206
+ # # Remove known fields that we don't want in our model
207
+ # unwanted_fields = ["queryParams"]
208
+ # for field in unwanted_fields:
209
+ # data.pop(field, None)
158
210
 
159
211
  # Handle apikeys to apikey conversion
160
212
  if "apikeys" in data:
@@ -211,14 +263,14 @@ class ModelConfig(BaseDTO):
211
263
  updated_at: str = ""
212
264
  model_id: str = ""
213
265
  model_saved_name: str = "Model Name"
214
- testing_for: str = "LLM"
215
- model_name: str = "gpt-4o-mini"
216
- model_type: str = "text_2_text"
217
- modality: Modality = Modality.TEXT
266
+ model_version: str = "v1"
267
+ testing_for: str = "foundationModels"
268
+ # modality: Modality = Modality.TEXT
218
269
  project_name: str = ""
270
+ model_name: Optional[str] = "gpt-4o-mini"
219
271
  certifications: List[str] = field(default_factory=list)
220
272
  model_config: ModelConfigDetails = field(default_factory=ModelConfigDetails)
221
- is_sample: bool = False
273
+ is_sample: Optional[bool] = False
222
274
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
223
275
 
224
276
  @classmethod
@@ -231,16 +283,16 @@ class ModelConfig(BaseDTO):
231
283
  except ValueError as e:
232
284
  raise ValueError(f"Error in model_config: {str(e)}")
233
285
 
234
- # Handle Modality enum
235
- modality_value = data.pop("modality", "text")
236
- modality = Modality(modality_value)
286
+ # # Handle Modality enum
287
+ # modality_value = data.pop("modality", "text")
288
+ # modality = Modality(modality_value)
237
289
 
238
- return cls(**data, modality=modality, model_config=model_config)
290
+ return cls(**data, model_config=model_config)
239
291
 
240
292
  @classmethod
241
293
  def __str__(self):
242
294
  """String representation of the ModelConfig."""
243
- return f"ModelConfig(name={self.model_saved_name}, model={self.model_name})"
295
+ return f"ModelConfig(name={self.model_saved_name}, version={self.model_version}, model={self.model_name})"
244
296
 
245
297
  def __repr__(self):
246
298
  """Detailed string representation of the ModelConfig."""
@@ -341,7 +393,7 @@ class Task(BaseDTO):
341
393
 
342
394
  task_id: str
343
395
  status: str
344
- model_name: str
396
+ model_name: Optional[str] = None
345
397
  test_name: Optional[str] = None
346
398
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
347
399
 
@@ -1,4 +1,5 @@
1
1
  import pandas as pd
2
+ from enum import Enum
2
3
  from .base import BaseDTO
3
4
  from typing import Dict, List, Optional, Any
4
5
  from dataclasses import dataclass, field, asdict
@@ -100,11 +101,11 @@ class ResultSummary(BaseDTO):
100
101
  test_date: str
101
102
  test_name: str
102
103
  dataset_name: str
103
- model_name: str
104
104
  model_endpoint_url: str
105
105
  model_source: str
106
106
  model_provider: str
107
107
  risk_score: float
108
+ model_name: Optional[str]
108
109
  test_type: Dict[str, StatisticItem]
109
110
  nist_category: Dict[str, StatisticItem]
110
111
  scenario: Dict[str, StatisticItem]
@@ -257,9 +258,10 @@ class RedTeamResultDetails(BaseDTO):
257
258
  @dataclass
258
259
  class AttackMethods(BaseDTO):
259
260
  basic: List[str] = field(default_factory=lambda: ["basic"])
260
- advanced: Dict[str, List[str]] = field(
261
- default_factory=lambda: {"static": ["single_shot"], "dynamic": ["iterative"]}
262
- )
261
+ # advanced: Dict[str, List[str]] = field(
262
+ # default_factory=lambda: {"static": ["single_shot"], "dynamic": ["iterative"]}
263
+ # )
264
+ advanced: Dict[str, List[str]] = field(default_factory=dict)
263
265
 
264
266
  def to_dict(self) -> dict:
265
267
  return asdict(self)
@@ -271,7 +273,7 @@ class AttackMethods(BaseDTO):
271
273
 
272
274
  @dataclass
273
275
  class TestConfig(BaseDTO):
274
- sample_percentage: int = 100
276
+ sample_percentage: int = 5
275
277
  attack_methods: AttackMethods = field(default_factory=AttackMethods)
276
278
 
277
279
  def to_dict(self) -> dict:
@@ -298,6 +300,7 @@ class RedTeamTestConfigurations(BaseDTO):
298
300
  adv_info_test: TestConfig = field(default=None)
299
301
  adv_bias_test: TestConfig = field(default=None)
300
302
  adv_command_test: TestConfig = field(default=None)
303
+ # custom_test: TestConfig = field(default=None)
301
304
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
302
305
 
303
306
  @classmethod
@@ -305,19 +308,58 @@ class RedTeamTestConfigurations(BaseDTO):
305
308
  return cls(**{k: TestConfig.from_dict(v) for k, v in data.items()})
306
309
 
307
310
 
311
+ class InputModality(str, Enum):
312
+ text = "text"
313
+ image = "image"
314
+ audio = "audio"
315
+ # video = "video"
316
+ # code = "code"
317
+
318
+
319
+ class OutputModality(str, Enum):
320
+ text = "text"
321
+ # image = "image"
322
+ # audio = "audio"
323
+ # video = "video"
324
+ # code = "code"
325
+
326
+
327
+ @dataclass
328
+ class CustomHeader(BaseDTO):
329
+ key: str
330
+ value: str
331
+
332
+ @classmethod
333
+ def from_dict(cls, data: Dict[str, Any]) -> "CustomHeader":
334
+ return cls(
335
+ key=data.get("key", ""),
336
+ value=data.get("value", "")
337
+ )
338
+
339
+ def to_dict(self) -> Dict[str, Any]:
340
+ return {
341
+ "key": self.key,
342
+ "value": self.value
343
+ }
344
+
345
+
308
346
  @dataclass
309
347
  class TargetModelConfiguration(BaseDTO):
310
- testing_for: str = "LLM"
311
- model_name: str = "gpt-4o-mini"
312
- model_type: str = "text_2_text"
313
- model_version: Optional[str] = None
348
+ testing_for: str = "foundationModels"
314
349
  system_prompt: str = ""
315
- conversation_template: str = ""
316
350
  model_source: str = ""
317
351
  model_provider: str = "openai"
318
352
  model_endpoint_url: str = "https://api.openai.com/v1/chat/completions"
319
- model_api_key: Optional[str] = None
320
353
  rate_per_min: int = 20
354
+ model_name: Optional[str] = "gpt-4o-mini"
355
+ model_version: Optional[str] = None
356
+ model_api_key: Optional[str] = None
357
+ input_modalities: List[InputModality] = field(default_factory=list)
358
+ output_modalities: List[OutputModality] = field(default_factory=list)
359
+ custom_headers: List[CustomHeader] = field(default_factory=list)
360
+ custom_payload: Dict[str, Any] = field(default_factory=dict)
361
+ custom_response_content_type: Optional[str] = ""
362
+ custom_response_format: Optional[str] = ""
321
363
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
322
364
 
323
365
  @classmethod
@@ -351,19 +393,25 @@ class RedTeamModelHealthConfig(BaseDTO):
351
393
  @dataclass
352
394
  class RedteamModelHealthResponse(BaseDTO):
353
395
  status: str
396
+ message: str
354
397
  error: str
398
+ data: Optional[Dict[str, Any]] = field(default_factory=dict)
355
399
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
356
400
 
357
401
  @classmethod
358
402
  def from_dict(cls, data: Dict[str, Any]) -> "RedteamModelHealthResponse":
359
403
  return cls(
360
404
  status=data.get("status", ""),
405
+ message=data.get("message", ""),
406
+ data=data.get("data", {}),
361
407
  error=data.get("error", "")
362
408
  )
363
409
 
364
410
  def to_dict(self) -> Dict[str, Any]:
365
411
  return {
366
412
  "status": self.status,
413
+ "message": self.message,
414
+ "data": self.data,
367
415
  "error": self.error
368
416
  }
369
417
 
@@ -409,6 +457,7 @@ class RedTeamConfigWithSavedModel(BaseDTO):
409
457
  test_name: str = "Test Name"
410
458
  dataset_name: str = "standard"
411
459
  model_saved_name: str = "gpt-4o-mini"
460
+ model_version: str = "v1"
412
461
 
413
462
  redteam_test_configurations: RedTeamTestConfigurations = field(
414
463
  default_factory=RedTeamTestConfigurations
@@ -7,10 +7,13 @@ from .dto import (
7
7
  GuardrailsModelsResponse,
8
8
  # GuardrailDetectors,
9
9
  # GuardrailsDetectRequest,
10
+ # GuardrailsBatchDetectRequest,
10
11
  # GuardrailsPolicyDetectRequest,
11
12
  # DetectResponseSummary,
12
13
  # DetectResponseDetails,
13
14
  GuardrailsDetectResponse,
15
+ # BatchDetectResponseItem,
16
+ GuardrailsBatchDetectResponse,
14
17
  # GuardrailsPIIRequest,
15
18
  GuardrailsPIIResponse,
16
19
  # GuardrailsHallucinationRequest,
@@ -154,6 +157,42 @@ class GuardrailsClient(BaseClient):
154
157
  except Exception as e:
155
158
  raise GuardrailsClientError(str(e))
156
159
 
160
+ def batch_detect(self, texts, config=None):
161
+ """
162
+ Detects prompt injection, toxicity, NSFW content, PII, hallucination, and more in batch.
163
+
164
+ Parameters:
165
+ - texts (list): A list of texts to analyze.
166
+ - guardrails_config (dict or GuardrailsConfig, optional): A configuration for detectors.
167
+ If a GuardrailsConfig instance is provided, its underlying dictionary will be used.
168
+ If not provided, defaults to injection attack detection only.
169
+
170
+ Returns:
171
+ - Response from the API.
172
+ """
173
+ # Use injection attack config by default if none provided
174
+ if config is None:
175
+ config = GuardrailsConfig.injection_attack()
176
+
177
+ # Allow passing in either a dict or a GuardrailsConfig or GuardrailDetectors instance
178
+ if hasattr(config, "as_dict"):
179
+ config = config.as_dict()
180
+ if hasattr(config, "to_dict"):
181
+ config = config.to_dict()
182
+
183
+ payload = {
184
+ "texts": texts,
185
+ "detectors": config
186
+ }
187
+
188
+ try:
189
+ response = self._request("POST", "/guardrails/batch/detect", json=payload)
190
+ if isinstance(response, dict) and response.get("error"):
191
+ raise GuardrailsClientError(response["error"])
192
+ return GuardrailsBatchDetectResponse.from_dict(response)
193
+ except Exception as e:
194
+ raise GuardrailsClientError(str(e))
195
+
157
196
  def pii(self, text, mode="request", key="null", entities=None):
158
197
  """
159
198
  Detects Personally Identifiable Information (PII) and can de-anonymize it.
enkryptai_sdk/models.py CHANGED
@@ -51,17 +51,15 @@ class ModelClient(BaseClient):
51
51
 
52
52
  payload = {
53
53
  "model_saved_name": config.model_saved_name,
54
+ "model_version": config.model_version,
54
55
  "testing_for": config.testing_for,
55
56
  "model_name": config.model_name,
56
- "model_type": config.model_type,
57
57
  "certifications": config.certifications,
58
58
  "model_config": {
59
59
  "model_provider": config.model_config.model_provider,
60
- "model_version": config.model_config.model_version,
61
60
  "hosting_type": config.model_config.hosting_type,
62
61
  "model_source": config.model_config.model_source,
63
62
  "system_prompt": config.model_config.system_prompt,
64
- "conversation_template": config.model_config.conversation_template,
65
63
  "endpoint": {
66
64
  "scheme": parsed_url.scheme,
67
65
  "host": parsed_url.hostname,
@@ -78,6 +76,14 @@ class ModelClient(BaseClient):
78
76
  "apikeys": (
79
77
  [config.model_config.apikey] if config.model_config.apikey else []
80
78
  ),
79
+ "tools": config.model_config.tools,
80
+ "input_modalities": config.model_config.input_modalities,
81
+ "output_modalities": config.model_config.output_modalities,
82
+ "custom_headers": config.model_config.custom_headers,
83
+ "custom_payload": config.model_config.custom_payload,
84
+ "custom_response_content_type": config.model_config.custom_response_content_type,
85
+ "custom_response_format": config.model_config.custom_response_format,
86
+ "metadata": config.model_config.metadata,
81
87
  "default_request_options": config.model_config.default_request_options,
82
88
  },
83
89
  }
@@ -91,17 +97,18 @@ class ModelClient(BaseClient):
91
97
  except Exception as e:
92
98
  raise ModelClientError(str(e))
93
99
 
94
- def get_model(self, model_saved_name: str) -> ModelConfig:
100
+ def get_model(self, model_saved_name: str, model_version: str) -> ModelConfig:
95
101
  """
96
102
  Get model configuration by model saved name.
97
103
 
98
104
  Args:
99
105
  model_saved_name (str): Saved name of the model to retrieve
106
+ model_version (str): Version of the model to retrieve
100
107
 
101
108
  Returns:
102
109
  ModelConfig: Configuration object containing model details
103
110
  """
104
- headers = {"X-Enkrypt-Model": model_saved_name}
111
+ headers = {"X-Enkrypt-Model": model_saved_name, "X-Enkrypt-Model-Version": model_version}
105
112
  response = self._request("GET", "/models/get-model", headers=headers)
106
113
  # print(response)
107
114
  if response.get("error"):
@@ -123,12 +130,13 @@ class ModelClient(BaseClient):
123
130
  except Exception as e:
124
131
  return {"error": str(e)}
125
132
 
126
- def modify_model(self, config: ModelConfig, old_model_saved_name=None) -> ModelResponse:
133
+ def modify_model(self, config: ModelConfig, old_model_saved_name=None, old_model_version=None) -> ModelResponse:
127
134
  """
128
135
  Modify an existing model in the system.
129
136
 
130
137
  Args:
131
- model_saved_name (str): The saved name of the model to modify
138
+ old_model_saved_name (str): The old saved name of the model to modify
139
+ old_model_version (str): The old version of the model to modify
132
140
  config (ModelConfig): Configuration object containing model details
133
141
 
134
142
  Returns:
@@ -137,7 +145,10 @@ class ModelClient(BaseClient):
137
145
  if old_model_saved_name is None:
138
146
  old_model_saved_name = config["model_saved_name"]
139
147
 
140
- headers = {"Content-Type": "application/json", "X-Enkrypt-Model": old_model_saved_name}
148
+ if old_model_version is None:
149
+ old_model_version = config["model_version"]
150
+
151
+ headers = {"Content-Type": "application/json", "X-Enkrypt-Model": old_model_saved_name, "X-Enkrypt-Model-Version": old_model_version}
141
152
  # print(config)
142
153
  config = ModelConfig.from_dict(config)
143
154
  # Parse endpoint_url into components
@@ -165,17 +176,15 @@ class ModelClient(BaseClient):
165
176
 
166
177
  payload = {
167
178
  "model_saved_name": config.model_saved_name,
179
+ "model_version": config.model_version,
168
180
  "testing_for": config.testing_for,
169
181
  "model_name": config.model_name,
170
- "model_type": config.model_type,
171
182
  "certifications": config.certifications,
172
183
  "model_config": {
173
184
  "model_provider": config.model_config.model_provider,
174
- "model_version": config.model_config.model_version,
175
185
  "hosting_type": config.model_config.hosting_type,
176
186
  "model_source": config.model_config.model_source,
177
187
  "system_prompt": config.model_config.system_prompt,
178
- "conversation_template": config.model_config.conversation_template,
179
188
  "endpoint": {
180
189
  "scheme": parsed_url.scheme,
181
190
  "host": parsed_url.hostname,
@@ -192,6 +201,14 @@ class ModelClient(BaseClient):
192
201
  "apikeys": (
193
202
  [config.model_config.apikey] if config.model_config.apikey else []
194
203
  ),
204
+ "tools": config.model_config.tools,
205
+ "input_modalities": config.model_config.input_modalities,
206
+ "output_modalities": config.model_config.output_modalities,
207
+ "custom_headers": config.model_config.custom_headers,
208
+ "custom_payload": config.model_config.custom_payload,
209
+ "custom_response_content_type": config.model_config.custom_response_content_type,
210
+ "custom_response_format": config.model_config.custom_response_format,
211
+ "metadata": config.model_config.metadata,
195
212
  "default_request_options": config.model_config.default_request_options,
196
213
  },
197
214
  }
@@ -205,17 +222,18 @@ class ModelClient(BaseClient):
205
222
  except Exception as e:
206
223
  raise ModelClientError(str(e))
207
224
 
208
- def delete_model(self, model_saved_name: str) -> ModelResponse:
225
+ def delete_model(self, model_saved_name: str, model_version: str) -> ModelResponse:
209
226
  """
210
227
  Delete a specific model from the system.
211
228
 
212
229
  Args:
213
230
  model_saved_name (str): The saved name of the model to delete
231
+ model_version (str): The version of the model to delete
214
232
 
215
233
  Returns:
216
234
  dict: Response from the API containing the deletion status
217
235
  """
218
- headers = {"X-Enkrypt-Model": model_saved_name}
236
+ headers = {"X-Enkrypt-Model": model_saved_name, "X-Enkrypt-Model-Version": model_version}
219
237
 
220
238
  try:
221
239
  response = self._request("DELETE", "/models/delete-model", headers=headers)
enkryptai_sdk/red_team.py CHANGED
@@ -66,13 +66,14 @@ class RedTeamClient(BaseClient):
66
66
  except Exception as e:
67
67
  raise RedTeamClientError(str(e))
68
68
 
69
- def check_saved_model_health(self, model_saved_name: str):
69
+ def check_saved_model_health(self, model_saved_name: str, model_version: str):
70
70
  """
71
71
  Get the health status of a saved model.
72
72
  """
73
73
  try:
74
74
  headers = {
75
75
  "X-Enkrypt-Model": model_saved_name,
76
+ "X-Enkrypt-Model-Version": model_version,
76
77
  }
77
78
  response = self._request("POST", "/redteam/model/model-health", headers=headers)
78
79
  # if response.get("error"):
@@ -125,12 +126,16 @@ class RedTeamClient(BaseClient):
125
126
  self,
126
127
  config: RedTeamConfigWithSavedModel,
127
128
  model_saved_name: str,
129
+ model_version: str,
128
130
  ):
129
131
  """
130
132
  Add a new red teaming task using a saved model.
131
133
  """
132
134
  if not model_saved_name:
133
135
  raise RedTeamClientError("Please provide a model_saved_name")
136
+
137
+ if not model_version:
138
+ raise RedTeamClientError("Please provide a model_version. Default is 'v1'")
134
139
 
135
140
  config = RedTeamConfigWithSavedModel.from_dict(config)
136
141
  test_configs = config.redteam_test_configurations.to_dict()
@@ -148,6 +153,7 @@ class RedTeamClient(BaseClient):
148
153
 
149
154
  headers = {
150
155
  "X-Enkrypt-Model": model_saved_name,
156
+ "X-Enkrypt-Model-Version": model_version,
151
157
  "Content-Type": "application/json",
152
158
  }
153
159
  response = self._request(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: enkryptai-sdk
3
- Version: 1.0.4
3
+ Version: 1.0.6
4
4
  Summary: A Python SDK with guardrails and red teaming functionality for API interactions
5
5
  Home-page: https://github.com/enkryptai/enkryptai-sdk
6
6
  Author: Enkrypt AI Team
@@ -21,9 +21,10 @@ Dynamic: license-file
21
21
  Dynamic: requires-python
22
22
  Dynamic: summary
23
23
 
24
- ![Python SDK test](https://github.com/enkryptai/enkryptai-sdk/actions/workflows/test.yaml/badge.svg)
25
24
  # Enkrypt AI Python SDK
26
25
 
26
+ ![Python SDK test](https://github.com/enkryptai/enkryptai-sdk/actions/workflows/test.yaml/badge.svg)
27
+
27
28
  A Python SDK with Guardrails, Models, Deployments, AI Proxy, Datasets and Red Team functionality for API interactions.
28
29
 
29
30
  See [https://pypi.org/project/enkryptai-sdk](https://pypi.org/project/enkryptai-sdk)
@@ -55,6 +56,7 @@ Also see the API documentation at [https://docs.enkryptai.com](https://docs.enkr
55
56
  - [Guardrails Quickstart](#guardrails-quickstart)
56
57
  - [Guardrails Response Objects](#guardrails-response-objects)
57
58
  - [GuardrailsDetectResponse](#guardrailsdetectresponse)
59
+ - [GuardrailsBatchDetectResponse](#guardrailsbatchdetectresponse)
58
60
  - [Available Guardrails Detectors](#available-guardrails-detectors)
59
61
  - [Guardrails Configs](#guardrails-configs)
60
62
  - [Injection Attack](#injection-attack)
@@ -181,6 +183,7 @@ redteam_client = RedTeamClient(api_key=ENKRYPT_API_KEY, base_url=ENKRYPT_BASE_UR
181
183
  ```python Python
182
184
  test_policy_name = "Test Policy"
183
185
  test_model_saved_name = "Test Model"
186
+ test_model_version = "v1"
184
187
  test_deployment_name = "test-deployment"
185
188
 
186
189
  pii_original_text = "My email is example@example.com. My phone number is 123-456-7890."
@@ -257,14 +260,15 @@ sample_detectors = {
257
260
  ```python Python
258
261
  sample_model_config = {
259
262
  "model_saved_name": test_model_saved_name,
260
- "testing_for": "LLM",
263
+ "model_version": test_model_version,
264
+ "testing_for": "foundationModels",
261
265
  "model_name": model_name,
262
- "modality": "text",
263
266
  "model_config": {
264
- "model_version": "",
265
267
  "model_provider": model_provider,
266
268
  "endpoint_url": model_endpoint_url,
267
269
  "apikey": OPENAI_API_KEY,
270
+ "input_modalities": ["text"],
271
+ "output_modalities": ["text"],
268
272
  },
269
273
  }
270
274
  ```
@@ -275,6 +279,7 @@ sample_model_config = {
275
279
  sample_deployment_config = {
276
280
  "name": test_deployment_name,
277
281
  "model_saved_name": test_model_saved_name,
282
+ "model_version": test_model_version,
278
283
  "input_guardrails_policy": {
279
284
  "policy_name": test_policy_name,
280
285
  "enabled": True,
@@ -308,8 +313,16 @@ sample_dataset_config = {
308
313
  "dataset_name": dataset_name,
309
314
  "system_description": "- **Voter Eligibility**: To vote in U.S. elections, individuals must be U.S. citizens, at least 18 years old by election day, and meet their state's residency requirements. - **Voter Registration**: Most states require voters to register ahead of time, with deadlines varying widely. North Dakota is an exception, as it does not require voter registration. - **Identification Requirements**: Thirty-six states enforce voter ID laws, requiring individuals to present identification at polling places. These laws aim to prevent voter fraud but can also lead to disenfranchisement. - **Voting Methods**: Voters can typically choose between in-person voting on election day, early voting, and absentee or mail-in ballots, depending on state regulations. - **Polling Hours**: Polling hours vary by state, with some states allowing extended hours for voters. Its essential for voters to check local polling times to ensure they can cast their ballots. - **Provisional Ballots**: If there are questions about a voter's eligibility, they may be allowed to cast a provisional ballot. This ballot is counted once eligibility is confirmed. - **Election Day Laws**: Many states have laws that protect the rights of voters on election day, including prohibiting intimidation and ensuring access to polling places. - **Campaign Finance Regulations**: Federal and state laws regulate contributions to candidates and political parties to ensure transparency and limit the influence of money in politics. - **Political Advertising**: Campaigns must adhere to rules regarding political advertising, including disclosure requirements about funding sources and content accuracy. - **Voter Intimidation Prohibitions**: Federal laws prohibit any form of voter intimidation or coercion at polling places, ensuring a safe environment for all voters. 
- **Accessibility Requirements**: The Americans with Disabilities Act mandates that polling places be accessible to individuals with disabilities, ensuring equal access to the electoral process. - **Election Monitoring**: Various organizations are allowed to monitor elections to ensure compliance with laws and regulations. They help maintain transparency and accountability in the electoral process. - **Vote Counting Procedures**: States have specific procedures for counting votes, including the use of electronic voting machines and manual audits to verify results. - **Ballot Design Standards**: States must adhere to certain design standards for ballots to ensure clarity and prevent confusion among voters when casting their votes. - **Post-Election Audits**: Some states conduct post-election audits as a measure of accuracy. These audits help verify that the vote count reflects the actual ballots cast.",
310
315
  "policy_description": "",
311
- "tools": [],
316
+ "tools": [
317
+ {
318
+ "name": "web_search",
319
+ "description": "The tool web search is used to search the web for information related to finance."
320
+ }
321
+ ],
312
322
  "info_pdf_url": "",
323
+ "scenarios": 1,
324
+ "categories": 1,
325
+ "depth": 1,
313
326
  "max_prompts": 100,
314
327
  }
315
328
  ```
@@ -320,16 +333,16 @@ sample_dataset_config = {
320
333
  sample_redteam_model_health_config = {
321
334
  "target_model_configuration": {
322
335
  "model_name": model_name,
323
- "testing_for": "LLM",
324
- "model_type": "text_2_text",
325
- "model_version": "v1",
336
+ "testing_for": "foundationModels",
337
+ "model_version": test_model_version,
326
338
  "model_source": "https://openai.com",
327
339
  "model_provider": model_provider,
328
340
  "model_endpoint_url": model_endpoint_url,
329
341
  "model_api_key": OPENAI_API_KEY,
330
342
  "system_prompt": "",
331
- "conversation_template": "",
332
- "rate_per_min": 20
343
+ "rate_per_min": 20,
344
+ "input_modalities": ["text"],
345
+ "output_modalities": ["text"]
333
346
  },
334
347
  }
335
348
  ```
@@ -364,16 +377,16 @@ sample_redteam_target_config = {
364
377
  },
365
378
  "target_model_configuration": {
366
379
  "model_name": model_name,
367
- "testing_for": "LLM",
368
- "model_type": "text_2_text",
369
- "model_version": "v1",
380
+ "testing_for": "foundationModels",
381
+ "model_version": test_model_version,
370
382
  "model_source": "https://openai.com",
371
383
  "model_provider": model_provider,
372
384
  "model_endpoint_url": model_endpoint_url,
373
385
  "model_api_key": OPENAI_API_KEY,
374
386
  "system_prompt": "",
375
- "conversation_template": "",
376
- "rate_per_min": 20
387
+ "rate_per_min": 20,
388
+ "input_modalities": ["text"],
389
+ "output_modalities": ["text"]
377
390
  },
378
391
  }
379
392
  ```
@@ -539,6 +552,43 @@ print(detect_response)
539
552
  print(detect_response.to_dict())
540
553
  ```
541
554
 
555
+ ### GuardrailsBatchDetectResponse
556
+
557
+ The `GuardrailsBatchDetectResponse` class wraps the response returned by `batch_detect`:
558
+
559
+ ```python Python
560
+ # Example usage of batch_detect with multiple texts
561
+ batch_detect_response = guardrails_client.batch_detect(
562
+ texts=[safe_prompt, bomb_prompt],
563
+ config=copy.deepcopy(sample_detectors)
564
+ )
565
+
566
+ # Batch checks
567
+ print(f"Batch Response Is Safe: {batch_detect_response.is_safe()}")
568
+ print(f"Batch Response Is Attack: {batch_detect_response.is_attack()}")
569
+ print(f"Batch Response Has Violations: {batch_detect_response.has_violations()}")
570
+ print(f"Batch Response All Violations: {batch_detect_response.get_violations()}")
571
+
572
+ # Access results for individual texts
573
+ for idx, detection in enumerate(batch_detect_response.batch_detections):
574
+ print(f"\nResults for text #{idx + 1}:")
575
+
576
+ # Access specific detector results
577
+ if detection.details.injection_attack:
578
+ print(f"Injection Attack Safe: {detection.details.injection_attack.safe}")
579
+ print(f"Injection Attack Score: {detection.details.injection_attack.attack}")
580
+
581
+ # Check safety status for this text
582
+ print(f"Is Safe: {detection.is_safe()}")
583
+ print(f"Is Attack: {detection.is_attack()}")
584
+ print(f"Has Violations: {detection.has_violations()}")
585
+ print(f"Violations: {detection.get_violations()}")
586
+
587
+ # Convert entire batch response to dictionary
588
+ print("\nComplete Batch Response Dictionary:")
589
+ print(batch_detect_response.to_dict())
590
+ ```
591
+
542
592
  ## Available Guardrails Detectors
543
593
 
544
594
  - `injection_attack`: Detect prompt injection attempts
@@ -890,7 +940,7 @@ print(add_model_response.to_dict())
890
940
 
891
941
  ```python Python
892
942
  # Check Model Health
893
- check_saved_model_health = redteam_client.check_saved_model_health(model_saved_name=test_model_saved_name)
943
+ check_saved_model_health = redteam_client.check_saved_model_health(model_saved_name=test_model_saved_name, model_version=test_model_version)
894
944
 
895
945
  print(check_saved_model_health)
896
946
 
@@ -901,11 +951,13 @@ assert check_saved_model_health.status == "healthy"
901
951
 
902
952
  ```python Python
903
953
  # Retrieve model details
904
- model_details = model_client.get_model(model_saved_name=test_model_saved_name)
954
+ model_details = model_client.get_model(model_saved_name=test_model_saved_name, model_version=test_model_version)
905
955
 
906
956
  print(model_details)
907
957
 
908
958
  # Get other fields
959
+ print(model_details.model_saved_name)
960
+ print(model_details.model_version)
909
961
  print(model_details.model_name)
910
962
  print(model_details.model_config)
911
963
  print(model_details.model_config.model_provider)
@@ -940,13 +992,23 @@ new_model_config = copy.deepcopy(sample_model_config)
940
992
  new_model_config["model_name"] = "gpt-4o-mini"
941
993
 
942
994
  # Update the model_saved_name if needed
995
+ # ---------------------------------------------------
996
+ # NOTE:
997
+ # To avoid breaking existing integrations, consider creating a new model instead of modifying the existing one.
998
+ # Once your code has been updated to use the new model, you can delete the old one.
999
+ # ---------------------------------------------------
943
1000
  # new_model_config["model_saved_name"] = "New Model Name"
1001
+ # new_model_config["model_version"] = "v2"
944
1002
 
945
1003
  old_model_saved_name = None
946
1004
  if new_model_config["model_saved_name"] != test_model_saved_name:
947
1005
  old_model_saved_name = test_model_saved_name
948
1006
 
949
- modify_response = model_client.modify_model(old_model_saved_name=old_model_saved_name, config=new_model_config)
1007
+ old_model_version = None
1008
+ if new_model_config["model_version"] != test_model_version:
1009
+ old_model_version = test_model_version
1010
+
1011
+ modify_response = model_client.modify_model(old_model_saved_name=old_model_saved_name, old_model_version=old_model_version, config=new_model_config)
950
1012
 
951
1013
  print(modify_response)
952
1014
 
@@ -960,7 +1022,7 @@ print(modify_response.to_dict())
960
1022
 
961
1023
  ```python Python
962
1024
  # Remove a model
963
- delete_response = model_client.delete_model(model_saved_name=test_model_saved_name)
1025
+ delete_response = model_client.delete_model(model_saved_name=test_model_saved_name, model_version=test_model_version)
964
1026
 
965
1027
  print(delete_response)
966
1028
 
@@ -996,6 +1058,7 @@ print(deployment_details)
996
1058
 
997
1059
  # Get other fields
998
1060
  print(deployment_details.model_saved_name)
1061
+ print(deployment_details.model_version)
999
1062
  print(deployment_details.input_guardrails_policy)
1000
1063
  print(deployment_details.input_guardrails_policy.policy_name)
1001
1064
 
@@ -1219,7 +1282,7 @@ print(add_redteam_target_response.to_dict())
1219
1282
 
1220
1283
  ```python Python
1221
1284
  # Use a dictionary to configure a redteam task
1222
- add_redteam_model_response = redteam_client.add_task_with_saved_model(config=copy.deepcopy(sample_redteam_model_config),model_saved_name=test_model_saved_name)
1285
+ add_redteam_model_response = redteam_client.add_task_with_saved_model(config=copy.deepcopy(sample_redteam_model_config),model_saved_name=test_model_saved_name, model_version=test_model_version)
1223
1286
 
1224
1287
  print(add_redteam_model_response)
1225
1288
 
@@ -0,0 +1,25 @@
1
+ enkryptai_sdk/__init__.py,sha256=rP6PtntJogJauj1lKWK8DkiBr3uYjireIUamr6aflu0,763
2
+ enkryptai_sdk/ai_proxy.py,sha256=pD6kPmD9H0gttN28cezgV7_IVelLXAHNR5cPeXCM8Ew,3799
3
+ enkryptai_sdk/base.py,sha256=MlEDcEIjXo35kat9XkGUu7VB2fIvJk38C94wAeO9bEw,1304
4
+ enkryptai_sdk/config.py,sha256=IpB8_aO4zXdvv061v24oh83oyJ5Tp1QBQTzeuW4h9QY,8828
5
+ enkryptai_sdk/datasets.py,sha256=2rAU1qeqBZz6rXxG4n9XM_8_ZQYNmcNFW-jgyw7nLPI,5364
6
+ enkryptai_sdk/deployments.py,sha256=qXoUQtLYRSqZsAAqPgxJAad1a3mgMAbPx09NiOyAdsw,4155
7
+ enkryptai_sdk/evals.py,sha256=BywyEgIT7xdJ58svO_sDNOMVowdB0RTGoAZPEbCnDVo,2595
8
+ enkryptai_sdk/guardrails.py,sha256=meR4iZuGV5Js9BRU0edJ82CGOrZMroaTcoaEoUn7k-s,13642
9
+ enkryptai_sdk/guardrails_old.py,sha256=SgzPZkTzbAPD9XfmYNG6M1-TrzbhDHpAkI3FjnVWS_s,6434
10
+ enkryptai_sdk/models.py,sha256=Bj07UAhy2yWAgDxOeInkCTqxDmJK5Ae74v7rNlkA6V4,10361
11
+ enkryptai_sdk/red_team.py,sha256=Puo6BvYDrbERXTuSCFXLtXKhXcFiHuJrcRU5akvXps4,14479
12
+ enkryptai_sdk/response.py,sha256=43JRubzgGCpoVxYNzBZY0AlUgLbfcXD_AwD7wU3qY9o,4086
13
+ enkryptai_sdk/dto/__init__.py,sha256=vG5PmG-L99v89AGiKF1QAF63YILk-e9w0dd665M3U60,2414
14
+ enkryptai_sdk/dto/ai_proxy.py,sha256=clwMN4xdH8Zr55dnhilHbs-qaHRlCOrLPrij0Zd1Av0,11283
15
+ enkryptai_sdk/dto/base.py,sha256=6VWTkoNZ7uILqn_iYsPS21cVa2xLYpw5bjDIsRCS5tk,2389
16
+ enkryptai_sdk/dto/datasets.py,sha256=nI0VjgkViaAi49x01dTDfkzoynx5__zCfJLUsZG8kyw,4846
17
+ enkryptai_sdk/dto/deployments.py,sha256=Aw4b8tDA3FYIomqDvCjblCXTagL4bT8Fx91X0SFXs40,11216
18
+ enkryptai_sdk/dto/guardrails.py,sha256=uolGPPF4v0l76MI5G0ofTtc9-r1l0_sqQqQkLEhAsf0,46305
19
+ enkryptai_sdk/dto/models.py,sha256=6BYF4XwzUwY8O_YgE9S7LPAvsQaRYtUV6ylkTYP7w10,13438
20
+ enkryptai_sdk/dto/red_team.py,sha256=4VGxDGCsZsHi6yJsUZ64IXj6u84xZPcSQuME1PFSnSc,15243
21
+ enkryptai_sdk-1.0.6.dist-info/licenses/LICENSE,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ enkryptai_sdk-1.0.6.dist-info/METADATA,sha256=8htBFgQ36iHitRe08rcvEgB9nnu1pXWBoC3G_i0hdss,45959
23
+ enkryptai_sdk-1.0.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
24
+ enkryptai_sdk-1.0.6.dist-info/top_level.txt,sha256=s2X9UJJwvJamNmr6ZXWyyQe60sXtQGWFuaBYfhgHI_4,14
25
+ enkryptai_sdk-1.0.6.dist-info/RECORD,,
@@ -1,25 +0,0 @@
1
- enkryptai_sdk/__init__.py,sha256=rP6PtntJogJauj1lKWK8DkiBr3uYjireIUamr6aflu0,763
2
- enkryptai_sdk/ai_proxy.py,sha256=pD6kPmD9H0gttN28cezgV7_IVelLXAHNR5cPeXCM8Ew,3799
3
- enkryptai_sdk/base.py,sha256=MlEDcEIjXo35kat9XkGUu7VB2fIvJk38C94wAeO9bEw,1304
4
- enkryptai_sdk/config.py,sha256=IpB8_aO4zXdvv061v24oh83oyJ5Tp1QBQTzeuW4h9QY,8828
5
- enkryptai_sdk/datasets.py,sha256=xekcdY9wIniw2TnylaK6o1RNZ5DRoluNOGawBkVgaM0,4881
6
- enkryptai_sdk/deployments.py,sha256=qXoUQtLYRSqZsAAqPgxJAad1a3mgMAbPx09NiOyAdsw,4155
7
- enkryptai_sdk/evals.py,sha256=BywyEgIT7xdJ58svO_sDNOMVowdB0RTGoAZPEbCnDVo,2595
8
- enkryptai_sdk/guardrails.py,sha256=CtQwPBmpFFy71P22kG0gJbefYJflUT9aareYksPhNlQ,12096
9
- enkryptai_sdk/guardrails_old.py,sha256=SgzPZkTzbAPD9XfmYNG6M1-TrzbhDHpAkI3FjnVWS_s,6434
10
- enkryptai_sdk/models.py,sha256=FSdpfz-e1a0D7nuYCnQC859AdrOUVr60QbLR1mlZjp0,8978
11
- enkryptai_sdk/red_team.py,sha256=jlSf7oNPM7jJHXyv37ROQgkluuigxbSpeeAfvKqUiGI,14192
12
- enkryptai_sdk/response.py,sha256=43JRubzgGCpoVxYNzBZY0AlUgLbfcXD_AwD7wU3qY9o,4086
13
- enkryptai_sdk/dto/__init__.py,sha256=kKBw4rkfqMBuK8nXRDtD6Sd0_uqLKgbcHrqzuSGJpr0,2310
14
- enkryptai_sdk/dto/ai_proxy.py,sha256=clwMN4xdH8Zr55dnhilHbs-qaHRlCOrLPrij0Zd1Av0,11283
15
- enkryptai_sdk/dto/base.py,sha256=6VWTkoNZ7uILqn_iYsPS21cVa2xLYpw5bjDIsRCS5tk,2389
16
- enkryptai_sdk/dto/datasets.py,sha256=E3hvHvGZ94iMvCslTcYM3VCKszVQq_xtu93nlm4dZhI,4444
17
- enkryptai_sdk/dto/deployments.py,sha256=lsKdG09C-rceIjGvEyYOBf5zBjrk7ma8NpPfgrAgdfM,10829
18
- enkryptai_sdk/dto/guardrails.py,sha256=XMFco-KlEqI4TYoJAyxxTrk-OFixtEcftBpG1SXFH4k,39583
19
- enkryptai_sdk/dto/models.py,sha256=O4gVhVTenlsytNJIvk2gO5530KZWMye6FCVCtF5IW-A,11700
20
- enkryptai_sdk/dto/red_team.py,sha256=BoOPYFjIONIC0XPuyJtkx_qLVYi2kLGEA4CzlySgbJA,13829
21
- enkryptai_sdk-1.0.4.dist-info/licenses/LICENSE,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
- enkryptai_sdk-1.0.4.dist-info/METADATA,sha256=2lqT-652SgRDCyjpMhmRfyUiF6dK1uL-kl0i1RbKhsc,43207
23
- enkryptai_sdk-1.0.4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
24
- enkryptai_sdk-1.0.4.dist-info/top_level.txt,sha256=s2X9UJJwvJamNmr6ZXWyyQe60sXtQGWFuaBYfhgHI_4,14
25
- enkryptai_sdk-1.0.4.dist-info/RECORD,,