enkryptai-sdk 0.1.4__tar.gz → 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {enkryptai_sdk-0.1.4/src/enkryptai_sdk.egg-info → enkryptai_sdk-0.1.6}/PKG-INFO +1 -1
  2. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/setup.py +1 -1
  3. enkryptai_sdk-0.1.6/src/enkryptai_sdk/__init__.py +13 -0
  4. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/src/enkryptai_sdk/config.py +107 -46
  5. enkryptai_sdk-0.1.6/src/enkryptai_sdk/dto/__init__.py +18 -0
  6. enkryptai_sdk-0.1.6/src/enkryptai_sdk/dto/models.py +215 -0
  7. enkryptai_sdk-0.1.6/src/enkryptai_sdk/dto/red_team.py +196 -0
  8. enkryptai_sdk-0.1.6/src/enkryptai_sdk/models.py +160 -0
  9. enkryptai_sdk-0.1.6/src/enkryptai_sdk/red_team.py +195 -0
  10. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6/src/enkryptai_sdk.egg-info}/PKG-INFO +1 -1
  11. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/src/enkryptai_sdk.egg-info/SOURCES.txt +3 -0
  12. enkryptai_sdk-0.1.4/src/enkryptai_sdk/__init__.py +0 -5
  13. enkryptai_sdk-0.1.4/src/enkryptai_sdk/models.py +0 -0
  14. enkryptai_sdk-0.1.4/src/enkryptai_sdk/red_team.py +0 -0
  15. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/LICENSE +0 -0
  16. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/README.md +0 -0
  17. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/setup.cfg +0 -0
  18. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/src/enkryptai_sdk/evals.py +0 -0
  19. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/src/enkryptai_sdk/guardrails.py +0 -0
  20. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/src/enkryptai_sdk/response.py +0 -0
  21. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/src/enkryptai_sdk.egg-info/dependency_links.txt +0 -0
  22. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/src/enkryptai_sdk.egg-info/top_level.txt +0 -0
  23. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/tests/test_all.py +0 -0
  24. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/tests/test_basic.py +0 -0
  25. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/tests/test_detect_policy.py +0 -0
  26. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/tests/test_injection_attack.py +0 -0
  27. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.6}/tests/test_policy_violation.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: enkryptai-sdk
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: A Python SDK with guardrails and red teaming functionality for API interactions
5
5
  Home-page: https://github.com/enkryptai/enkryptai-sdk
6
6
  Author: Enkrypt AI Team
@@ -8,7 +8,7 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as fh:
8
8
 
9
9
  setup(
10
10
  name="enkryptai-sdk", # This is the name of your package on PyPI
11
- version="0.1.4",
11
+ version="0.1.6",
12
12
  description="A Python SDK with guardrails and red teaming functionality for API interactions",
13
13
  long_description=long_description,
14
14
  long_description_content_type="text/markdown",
@@ -0,0 +1,13 @@
1
+ from .guardrails import GuardrailsClient
2
+ from .config import GuardrailsConfig
3
+ from .evals import EvalsClient
4
+ from .models import ModelClient
5
+ from .red_team import RedTeamClient
6
+
7
+ __all__ = [
8
+ "GuardrailsClient",
9
+ "GuardrailsConfig",
10
+ "EvalsClient",
11
+ "ModelClient",
12
+ "RedTeamClient",
13
+ ]
@@ -1,17 +1,21 @@
1
1
  import copy
2
2
 
3
3
  # Base default configuration for all detectors.
4
- DEFAULT_CONFIG = {
4
+ DEFAULT_GUARDRAILS_CONFIG = {
5
5
  "topic_detector": {"enabled": False, "topic": []},
6
6
  "nsfw": {"enabled": False},
7
7
  "toxicity": {"enabled": False},
8
8
  "pii": {"enabled": False, "entities": []},
9
9
  "injection_attack": {"enabled": False},
10
10
  "keyword_detector": {"enabled": False, "banned_keywords": []},
11
- "policy_violation": {"enabled": False, "policy_text": "", "need_explanation": False},
11
+ "policy_violation": {
12
+ "enabled": False,
13
+ "policy_text": "",
14
+ "need_explanation": False,
15
+ },
12
16
  "bias": {"enabled": False},
13
17
  "copyright_ip": {"enabled": False},
14
- "system_prompt": {"enabled": False, "index": "system"}
18
+ "system_prompt": {"enabled": False, "index": "system"},
15
19
  }
16
20
 
17
21
 
@@ -24,29 +28,30 @@ class GuardrailsConfig:
24
28
 
25
29
  def __init__(self, config=None):
26
30
  # Use a deep copy of the default to avoid accidental mutation.
27
- self.config = copy.deepcopy(DEFAULT_CONFIG) if config is None else config
31
+ self.config = (
32
+ copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG) if config is None else config
33
+ )
28
34
 
29
35
  @classmethod
30
36
  def injection_attack(cls):
31
37
  """
32
38
  Returns a configuration instance pre-configured for injection attack detection.
33
39
  """
34
- config = copy.deepcopy(DEFAULT_CONFIG)
40
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
35
41
  config["injection_attack"] = {"enabled": True}
36
42
  return cls(config)
37
43
 
38
44
  @classmethod
39
- def policy_violation(cls,
40
- policy_text: str,
41
- need_explanation: bool = False):
45
+ def policy_violation(cls, policy_text: str, need_explanation: bool = False):
42
46
  """
43
47
  Returns a configuration instance pre-configured for policy violation detection.
44
48
  """
45
- config = copy.deepcopy(DEFAULT_CONFIG)
46
- config["policy_violation"] = {"enabled": True,
47
- "policy_text": policy_text,
48
- "need_explanation": need_explanation
49
- }
49
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
50
+ config["policy_violation"] = {
51
+ "enabled": True,
52
+ "policy_text": policy_text,
53
+ "need_explanation": need_explanation,
54
+ }
50
55
  return cls(config)
51
56
 
52
57
  @classmethod
@@ -54,7 +59,7 @@ class GuardrailsConfig:
54
59
  """
55
60
  Returns a configuration instance pre-configured for toxicity detection.
56
61
  """
57
- config = copy.deepcopy(DEFAULT_CONFIG)
62
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
58
63
  config["toxicity"] = {"enabled": True}
59
64
  return cls(config)
60
65
 
@@ -63,7 +68,7 @@ class GuardrailsConfig:
63
68
  """
64
69
  Returns a configuration instance pre-configured for NSFW content detection.
65
70
  """
66
- config = copy.deepcopy(DEFAULT_CONFIG)
71
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
67
72
  config["nsfw"] = {"enabled": True}
68
73
  return cls(config)
69
74
 
@@ -72,7 +77,7 @@ class GuardrailsConfig:
72
77
  """
73
78
  Returns a configuration instance pre-configured for bias detection.
74
79
  """
75
- config = copy.deepcopy(DEFAULT_CONFIG)
80
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
76
81
  config["bias"] = {"enabled": True}
77
82
  return cls(config)
78
83
 
@@ -80,14 +85,14 @@ class GuardrailsConfig:
80
85
  def pii(cls, entities=None):
81
86
  """
82
87
  Returns a configuration instance pre-configured for PII detection.
83
-
88
+
84
89
  Args:
85
90
  entities (list, optional): List of PII entity types to detect.
86
91
  """
87
- config = copy.deepcopy(DEFAULT_CONFIG)
92
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
88
93
  config["pii"] = {
89
94
  "enabled": True,
90
- "entities": entities if entities is not None else []
95
+ "entities": entities if entities is not None else [],
91
96
  }
92
97
  return cls(config)
93
98
 
@@ -95,14 +100,14 @@ class GuardrailsConfig:
95
100
  def topic(cls, topics=None):
96
101
  """
97
102
  Returns a configuration instance pre-configured for topic detection.
98
-
103
+
99
104
  Args:
100
105
  topics (list, optional): List of topics to detect.
101
106
  """
102
- config = copy.deepcopy(DEFAULT_CONFIG)
107
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
103
108
  config["topic_detector"] = {
104
109
  "enabled": True,
105
- "topic": topics if topics is not None else []
110
+ "topic": topics if topics is not None else [],
106
111
  }
107
112
  return cls(config)
108
113
 
@@ -110,14 +115,14 @@ class GuardrailsConfig:
110
115
  def keyword(cls, keywords=None):
111
116
  """
112
117
  Returns a configuration instance pre-configured for keyword detection.
113
-
118
+
114
119
  Args:
115
120
  keywords (list, optional): List of banned keywords to detect.
116
121
  """
117
- config = copy.deepcopy(DEFAULT_CONFIG)
122
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
118
123
  config["keyword_detector"] = {
119
124
  "enabled": True,
120
- "banned_keywords": keywords if keywords is not None else []
125
+ "banned_keywords": keywords if keywords is not None else [],
121
126
  }
122
127
  return cls(config)
123
128
 
@@ -126,7 +131,7 @@ class GuardrailsConfig:
126
131
  """
127
132
  Returns a configuration instance pre-configured for copyright/IP detection.
128
133
  """
129
- config = copy.deepcopy(DEFAULT_CONFIG)
134
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
130
135
  config["copyright_ip"] = {"enabled": True}
131
136
  return cls(config)
132
137
 
@@ -134,21 +139,18 @@ class GuardrailsConfig:
134
139
  def system_prompt(cls, index="system"):
135
140
  """
136
141
  Returns a configuration instance pre-configured for system prompt detection.
137
-
142
+
138
143
  Args:
139
144
  index (str, optional): Index name for system prompt detection. Defaults to "system".
140
145
  """
141
- config = copy.deepcopy(DEFAULT_CONFIG)
142
- config["system_prompt"] = {
143
- "enabled": True,
144
- "index": index
145
- }
146
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
147
+ config["system_prompt"] = {"enabled": True, "index": index}
146
148
  return cls(config)
147
149
 
148
150
  def update(self, **kwargs):
149
151
  """
150
152
  Update the configuration with custom values.
151
-
153
+
152
154
  Only keys that exist in the default configuration can be updated.
153
155
  For example:
154
156
  config.update(nsfw={"enabled": True}, toxicity={"enabled": True})
@@ -170,16 +172,16 @@ class GuardrailsConfig:
170
172
  def from_custom_config(cls, config_dict: dict):
171
173
  """
172
174
  Configure guardrails from a dictionary input.
173
-
175
+
174
176
  Validates that the input dictionary matches the expected schema structure.
175
177
  Each key must exist in the default configuration, and its value must be a dictionary.
176
-
178
+
177
179
  Args:
178
180
  config_dict (dict): Dictionary containing guardrails configuration
179
-
181
+
180
182
  Returns:
181
183
  GuardrailsConfig: Returns a new GuardrailsConfig instance
182
-
184
+
183
185
  Raises:
184
186
  ValueError: If the input dictionary contains invalid keys or malformed values
185
187
  """
@@ -189,33 +191,92 @@ class GuardrailsConfig:
189
191
  raise ValueError(f"Unknown detector config: {key}")
190
192
  if not isinstance(value, dict):
191
193
  raise ValueError(f"Config value for {key} must be a dictionary")
192
-
194
+
193
195
  # Validate that all required fields exist in the default config
194
- default_fields = set(DEFAULT_CONFIG[key].keys())
196
+ default_fields = set(DEFAULT_GUARDRAILS_CONFIG[key].keys())
195
197
  provided_fields = set(value.keys())
196
-
198
+
197
199
  if not provided_fields.issubset(default_fields):
198
200
  invalid_fields = provided_fields - default_fields
199
201
  raise ValueError(f"Invalid fields for {key}: {invalid_fields}")
200
-
202
+
201
203
  instance.config[key] = value
202
-
204
+
203
205
  return instance
204
206
 
205
207
  def get_config(self, detector_name: str) -> dict:
206
208
  """
207
209
  Get the configuration for a specific detector.
208
-
210
+
209
211
  Args:
210
212
  detector_name (str): Name of the detector to get configuration for
211
-
213
+
212
214
  Returns:
213
215
  dict: Configuration dictionary for the specified detector
214
-
216
+
215
217
  Raises:
216
218
  ValueError: If the detector name doesn't exist in the configuration
217
219
  """
218
220
  if detector_name not in self.config:
219
221
  raise ValueError(f"Unknown detector: {detector_name}")
220
-
222
+
221
223
  return copy.deepcopy(self.config[detector_name])
224
+
225
+
226
+ class RedTeamConfig:
227
+ """
228
+ A helper class to manage RedTeam configuration.
229
+ """
230
+
231
+ def __init__(self, config=None):
232
+ if config is None:
233
+ config = copy.deepcopy(DEFAULT_REDTEAM_CONFIG)
234
+ # Only include advanced tests if dataset is not standard
235
+ if config.get("dataset_name") != "standard":
236
+ config["redteam_test_configurations"].update(
237
+ copy.deepcopy(ADVANCED_REDTEAM_TESTS)
238
+ )
239
+ self.config = config
240
+
241
+ def as_dict(self):
242
+ """
243
+ Return the underlying configuration dictionary.
244
+ """
245
+ return self.config
246
+
247
+
248
+ class ModelConfig:
249
+ def __init__(self, config=None):
250
+ if config is None:
251
+ config = copy.deepcopy(DETAIL_MODEL_CONFIG)
252
+ self.config = config
253
+
254
+ @classmethod
255
+ def model_name(self, model_name: str):
256
+ """
257
+ Set the model name.
258
+ """
259
+ self.config["model_name"] = model_name
260
+ return self
261
+
262
+ @classmethod
263
+ def testing_for(self, testing_for: str):
264
+ """
265
+ Set the testing for.
266
+ """
267
+ self.config["testing_for"] = testing_for
268
+ return self
269
+
270
+ @classmethod
271
+ def model_config(self, model_config: dict):
272
+ """
273
+ Set the model config.
274
+ """
275
+ self.config["model_config"] = model_config
276
+ return self
277
+
278
+ def as_dict(self):
279
+ """
280
+ Return the underlying configuration dictionary.
281
+ """
282
+ return self.config
@@ -0,0 +1,18 @@
1
+ from .models import *
2
+ from .red_team import *
3
+
4
+ __all__ = [
5
+ "DetailModelConfig",
6
+ "ModelConfig",
7
+ "RedTeamConfig",
8
+ "AdvancedRedTeamTests",
9
+ "TestConfig",
10
+ "AttackMethods",
11
+ "RedTeamTestConfigurations",
12
+ "TargetModelConfiguration",
13
+ "Location",
14
+ "Metadata",
15
+ "DEFAULT_REDTEAM_CONFIG",
16
+ "ADVANCED_REDTEAM_TESTS",
17
+ "DETAIL_MODEL_CONFIG",
18
+ ]
@@ -0,0 +1,215 @@
1
+ from dataclasses import dataclass, field, asdict
2
+ from typing import Optional, List, Set, Dict, Any
3
+ from enum import Enum
4
+ import json
5
+
6
+
7
+ class Modality(Enum):
8
+ TEXT = "text"
9
+ IMAGE = "image"
10
+ AUDIO = "audio"
11
+ VIDEO = "video"
12
+
13
+ def to_dict(self):
14
+ return self.value
15
+
16
+
17
+ @dataclass
18
+ class ModelResponse:
19
+ message: Optional[str] = None
20
+ data: Optional[Dict] = None
21
+
22
+ def to_dict(self):
23
+ return asdict(self)
24
+
25
+ @classmethod
26
+ def from_dict(cls, data: dict):
27
+ return cls(**data)
28
+
29
+
30
+ @dataclass
31
+ class EndpointConfig:
32
+ scheme: str = "https"
33
+ host: str = "api.openai.com"
34
+ port: int = 443
35
+ base_path: str = "v1"
36
+
37
+
38
+ @dataclass
39
+ class PathsConfig:
40
+ completions: str = "/chat/completions"
41
+ chat: str = "chat/completions"
42
+
43
+
44
+ @dataclass
45
+ class AuthData:
46
+ header_name: str = "Authorization"
47
+ header_prefix: str = "Bearer"
48
+ space_after_prefix: bool = True
49
+
50
+ def to_dict(self):
51
+ return asdict(self)
52
+
53
+ @classmethod
54
+ def from_dict(cls, data: dict):
55
+ return cls(**data)
56
+
57
+
58
+ @dataclass
59
+ class ModelDetailConfig:
60
+ model_version: Optional[str] = None
61
+ model_source: str = ""
62
+ model_provider: str = "openai"
63
+ system_prompt: str = ""
64
+
65
+ endpoint_url: str = "https://api.openai.com/v1/chat/completions"
66
+ auth_data: AuthData = field(default_factory=AuthData)
67
+ api_keys: Set[Optional[str]] = field(default_factory=lambda: {None})
68
+
69
+
70
+ @dataclass
71
+ class DetailModelConfig:
72
+ model_saved_name: str = "Model Name"
73
+ testing_for: str = "LLM"
74
+ model_name: str = "gpt-4o-mini"
75
+ modality: Modality = Modality.TEXT
76
+ model_config: ModelDetailConfig = field(default_factory=ModelDetailConfig)
77
+
78
+
79
+ @dataclass
80
+ class ModelConfigDetails:
81
+ model_version: Optional[str] = None
82
+ model_source: str = ""
83
+ model_provider: str = "openai"
84
+ system_prompt: str = ""
85
+ conversation_template: str = ""
86
+ is_compatible_with: str = "openai"
87
+ hosting_type: str = "External"
88
+ endpoint_url: str = "https://api.openai.com/v1/chat/completions"
89
+ auth_data: AuthData = field(default_factory=AuthData)
90
+ apikey: Optional[str] = None
91
+ default_request_options: Dict[str, Any] = field(default_factory=dict)
92
+
93
+ @classmethod
94
+ def from_dict(cls, data: dict):
95
+ # Create a copy of the data to avoid modifying the original
96
+ data = data.copy()
97
+
98
+ # Remove known fields that we don't want in our model
99
+ unwanted_fields = ["queryParams", "paths"]
100
+ for field in unwanted_fields:
101
+ data.pop(field, None)
102
+
103
+ # Handle apikeys to apikey conversion
104
+ if "apikeys" in data:
105
+ apikeys = data.pop("apikeys")
106
+ if apikeys and not data.get("apikey"):
107
+ data["apikey"] = apikeys[0]
108
+
109
+ # Convert endpoint dict to endpoint_url if present
110
+ if "endpoint" in data:
111
+ endpoint = data.pop("endpoint")
112
+ scheme = endpoint.get("scheme", "https")
113
+ host = endpoint.get("host", "")
114
+ port = endpoint.get("port", "")
115
+ base_path = endpoint.get("base_path", "")
116
+
117
+ endpoint_url = f"{scheme}://{host}"
118
+ if port and port not in [80, 443]:
119
+ endpoint_url += f":{port}"
120
+ if base_path:
121
+ base_path = "/" + base_path.strip("/")
122
+ endpoint_url += base_path
123
+
124
+ data["endpoint_url"] = endpoint_url
125
+
126
+ # Handle nested AuthData
127
+ auth_data = data.pop("auth_data", {})
128
+ auth_data_obj = AuthData.from_dict(auth_data)
129
+
130
+ # Only keep fields that are defined in the dataclass
131
+ valid_fields = cls.__dataclass_fields__.keys()
132
+ filtered_data = {k: v for k, v in data.items() if k in valid_fields}
133
+
134
+ return cls(**filtered_data, auth_data=auth_data_obj)
135
+
136
+ def to_dict(self):
137
+ d = asdict(self)
138
+ # Handle AuthData specifically
139
+ d["auth_data"] = self.auth_data.to_dict()
140
+ return d
141
+
142
+ def to_json(self):
143
+ return json.dumps(self.to_dict())
144
+
145
+ @classmethod
146
+ def from_json(cls, json_str: str):
147
+ return cls.from_dict(json.loads(json_str))
148
+
149
+
150
+ @dataclass
151
+ class ModelConfig:
152
+ created_at: str = ""
153
+ updated_at: str = ""
154
+ model_id: str = ""
155
+ model_saved_name: str = "Model Name"
156
+ testing_for: str = "LLM"
157
+ model_name: str = "gpt-4o-mini"
158
+ model_type: str = "text_2_text"
159
+ modality: Modality = Modality.TEXT
160
+ certifications: List[str] = field(default_factory=list)
161
+ model_config: ModelConfigDetails = field(default_factory=ModelConfigDetails)
162
+
163
+ def to_dict(self) -> dict:
164
+ """Convert the ModelConfig instance to a dictionary."""
165
+ # First create a shallow copy of self as dict
166
+ d = {}
167
+ for field in self.__dataclass_fields__:
168
+ value = getattr(self, field)
169
+ if field == "modality":
170
+ d[field] = value.value
171
+ elif field == "model_config":
172
+ if isinstance(value, ModelConfigDetails):
173
+ d[field] = value.to_dict()
174
+ else:
175
+ d[field] = value
176
+ else:
177
+ d[field] = value
178
+ return d
179
+
180
+ def to_json(self) -> str:
181
+ """Convert the ModelConfig instance to a JSON string."""
182
+ return json.dumps(self.to_dict())
183
+
184
+ @classmethod
185
+ def from_dict(cls, data: dict):
186
+ """Create a ModelConfig instance from a dictionary."""
187
+ # Handle nested ModelConfigDetails
188
+ model_config_data = data.pop("model_config", {})
189
+ model_config = ModelConfigDetails.from_dict(model_config_data)
190
+
191
+ # Handle Modality enum
192
+ modality_value = data.pop("modality", "text")
193
+ modality = Modality(modality_value)
194
+
195
+ return cls(**data, modality=modality, model_config=model_config)
196
+
197
+ @classmethod
198
+ def from_json(cls, json_str: str):
199
+ """Create a ModelConfig instance from a JSON string."""
200
+ data = json.loads(json_str)
201
+ return cls.from_dict(data)
202
+
203
+ def __str__(self):
204
+ """String representation of the ModelConfig."""
205
+ return f"ModelConfig(name={self.model_saved_name}, model={self.model_name})"
206
+
207
+ def __repr__(self):
208
+ """Detailed string representation of the ModelConfig."""
209
+ return (
210
+ f"ModelConfig({', '.join(f'{k}={v!r}' for k, v in self.to_dict().items())})"
211
+ )
212
+
213
+
214
+ # Default configuration
215
+ DETAIL_MODEL_CONFIG = ModelConfig()
@@ -0,0 +1,196 @@
1
+ from dataclasses import dataclass, field, asdict
2
+ from typing import Dict, List, Optional
3
+ import json
4
+
5
+
6
+ @dataclass
7
+ class RedTeamResponse:
8
+ task_id: Optional[str] = None
9
+
10
+ def to_dict(self) -> dict:
11
+ return asdict(self)
12
+
13
+ @classmethod
14
+ def from_dict(cls, data: dict):
15
+ return cls(**data)
16
+
17
+
18
+ @dataclass
19
+ class RedTeamTaskStatus:
20
+ status: Optional[str] = None
21
+
22
+ def to_dict(self) -> dict:
23
+ return asdict(self)
24
+
25
+ @classmethod
26
+ def from_dict(cls, data: dict):
27
+ return cls(**data)
28
+
29
+
30
+ @dataclass
31
+ class RedTeamTaskDetails:
32
+ created_at: Optional[str] = None
33
+ model_name: Optional[str] = None
34
+ status: Optional[str] = None
35
+ task_id: Optional[str] = None
36
+
37
+ def to_dict(self) -> dict:
38
+ return asdict(self)
39
+
40
+ @classmethod
41
+ def from_dict(cls, data: dict):
42
+ return cls(**data)
43
+
44
+
45
+ @dataclass
46
+ class RedTeamResultSummary:
47
+ test_date: Optional[str] = None
48
+ test_name: Optional[str] = None
49
+ dataset_name: Optional[str] = None
50
+ model_name: Optional[str] = None
51
+ model_endpoint_url: Optional[str] = None
52
+ model_source: Optional[str] = None
53
+ model_provider: Optional[str] = None
54
+ risk_score: Optional[float] = None
55
+ test_type: Optional[List] = None
56
+ nist_category: Optional[List] = None
57
+ scenario: Optional[List] = None
58
+ category: Optional[List] = None
59
+ attack_method: Optional[List] = None
60
+
61
+ def to_dict(self) -> dict:
62
+ return asdict(self)
63
+
64
+ @classmethod
65
+ def from_dict(cls, data: dict):
66
+ return cls(**data)
67
+
68
+
69
+ @dataclass
70
+ class RedTeamResultDetails: # To Be Updated
71
+ details: Optional[Dict] = None
72
+
73
+ def to_dict(self) -> dict:
74
+ return asdict(self)
75
+
76
+ @classmethod
77
+ def from_dict(cls, data: dict):
78
+ return cls(**data)
79
+
80
+
81
+ @dataclass
82
+ class AttackMethods:
83
+ basic: List[str] = field(default_factory=lambda: ["basic"])
84
+ advanced: Dict[str, List[str]] = field(
85
+ default_factory=lambda: {"static": ["single_shot"], "dynamic": ["iterative"]}
86
+ )
87
+
88
+ def to_dict(self) -> dict:
89
+ return asdict(self)
90
+
91
+ @classmethod
92
+ def from_dict(cls, data: dict):
93
+ return cls(**data)
94
+
95
+
96
+ @dataclass
97
+ class TestConfig:
98
+ sample_percentage: int = 100
99
+ attack_methods: AttackMethods = field(default_factory=AttackMethods)
100
+
101
+ def to_dict(self) -> dict:
102
+ return {
103
+ "sample_percentage": self.sample_percentage,
104
+ "attack_methods": self.attack_methods.to_dict(),
105
+ }
106
+
107
+ @classmethod
108
+ def from_dict(cls, data: dict):
109
+ attack_methods = AttackMethods.from_dict(data.pop("attack_methods", {}))
110
+ return cls(**data, attack_methods=attack_methods)
111
+
112
+
113
+ @dataclass
114
+ class RedTeamTestConfigurations:
115
+ # Basic tests
116
+ bias_test: TestConfig = field(default=None)
117
+ cbrn_test: TestConfig = field(default=None)
118
+ insecure_code_test: TestConfig = field(default=None)
119
+ toxicity_test: TestConfig = field(default=None)
120
+ harmful_test: TestConfig = field(default=None)
121
+ # Advanced tests
122
+ adv_info_test: TestConfig = field(default=None)
123
+ adv_bias_test: TestConfig = field(default=None)
124
+ adv_command_test: TestConfig = field(default=None)
125
+
126
+ def to_dict(self) -> dict:
127
+ return asdict(self)
128
+
129
+ @classmethod
130
+ def from_dict(cls, data: dict):
131
+ return cls(**{k: TestConfig.from_dict(v) for k, v in data.items()})
132
+
133
+
134
+ @dataclass
135
+ class TargetModelConfiguration:
136
+ testing_for: str = "LLM"
137
+ model_name: str = "gpt-4o-mini"
138
+ model_version: Optional[str] = None
139
+ system_prompt: str = ""
140
+ conversation_template: str = ""
141
+ model_source: str = ""
142
+ model_provider: str = "openai"
143
+ model_endpoint_url: str = "https://api.openai.com/v1/chat/completions"
144
+ model_api_key: Optional[str] = None
145
+
146
+ def to_dict(self) -> dict:
147
+ return asdict(self)
148
+
149
+ @classmethod
150
+ def from_dict(cls, data: dict):
151
+ return cls(**data)
152
+
153
+
154
+ @dataclass
155
+ class RedTeamConfig:
156
+ test_name: str = "Test Name"
157
+ dataset_name: str = "standard"
158
+ model_name: str = "gpt-4o-mini"
159
+ redteam_test_configurations: RedTeamTestConfigurations = field(
160
+ default_factory=RedTeamTestConfigurations
161
+ )
162
+ target_model_configuration: TargetModelConfiguration = field(
163
+ default_factory=TargetModelConfiguration
164
+ )
165
+
166
+ def to_dict(self) -> dict:
167
+ d = asdict(self)
168
+ d["redteam_test_configurations"] = self.redteam_test_configurations.to_dict()
169
+ d["target_model_configuration"] = self.target_model_configuration.to_dict()
170
+ return d
171
+
172
+ def to_json(self) -> str:
173
+ return json.dumps(self.to_dict())
174
+
175
+ @classmethod
176
+ def from_dict(cls, data: dict):
177
+ data = data.copy()
178
+ test_configs = RedTeamTestConfigurations.from_dict(
179
+ data.pop("redteam_test_configurations", {})
180
+ )
181
+ target_config = TargetModelConfiguration.from_dict(
182
+ data.pop("target_model_configuration", {})
183
+ )
184
+ return cls(
185
+ **data,
186
+ redteam_test_configurations=test_configs,
187
+ target_model_configuration=target_config,
188
+ )
189
+
190
+ @classmethod
191
+ def from_json(cls, json_str: str):
192
+ return cls.from_dict(json.loads(json_str))
193
+
194
+
195
+ # Default configurations
196
+ DEFAULT_REDTEAM_CONFIG = RedTeamConfig()
@@ -0,0 +1,160 @@
1
+ import urllib3
2
+ from .dto import ModelConfig, ModelResponse
3
+ from urllib.parse import urlparse, urlsplit
4
+
5
+
6
+ class ModelClientError(Exception):
7
+ """
8
+ A custom exception for ModelClient errors.
9
+ """
10
+
11
+ pass
12
+
13
+
14
+ class ModelClient:
15
+ def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com:443"):
16
+ self.api_key = api_key
17
+ self.base_url = base_url
18
+ self.http = urllib3.PoolManager()
19
+ self.headers = {"apikey": self.api_key}
20
+
21
+ def _request(self, method, endpoint, payload=None, headers=None, **kwargs):
22
+ url = self.base_url + endpoint
23
+ request_headers = {
24
+ "Accept-Encoding": "gzip", # Add required gzip encoding
25
+ **self.headers,
26
+ }
27
+ if headers:
28
+ request_headers.update(headers)
29
+
30
+ try:
31
+ response = self.http.request(method, url, headers=request_headers, **kwargs)
32
+
33
+ if response.status >= 400:
34
+ error_data = (
35
+ response.json()
36
+ if response.data
37
+ else {"message": f"HTTP {response.status}"}
38
+ )
39
+ error_message = error_data.get("message", str(error_data))
40
+ raise urllib3.exceptions.HTTPError(error_message)
41
+ return response.json()
42
+ except urllib3.exceptions.HTTPError as e:
43
+ return {"error": str(e)}
44
+
45
+ def health(self):
46
+ return self._request("GET", "/models/health")
47
+
48
+ def add_model(self, config: ModelConfig):
49
+ """
50
+ Add a new model configuration to the system.
51
+
52
+ Args:
53
+ config (ModelConfig): Configuration object containing model details
54
+
55
+ Returns:
56
+ dict: Response from the API containing the added model details
57
+ """
58
+ headers = {"Content-Type": "application/json"}
59
+ config = ModelConfig.from_dict(config)
60
+ # Parse endpoint_url into components
61
+ parsed_url = urlparse(config.model_config.endpoint_url)
62
+ path_parts = parsed_url.path.strip("/").split("/")
63
+
64
+ # Extract base_path and endpoint path
65
+ if len(path_parts) >= 1:
66
+ base_path = path_parts[0] # Usually 'v1'
67
+ remaining_path = "/".join(path_parts[1:]) # The rest of the path
68
+ else:
69
+ base_path = ""
70
+ remaining_path = ""
71
+
72
+ # Determine paths based on the endpoint
73
+ paths = {
74
+ "completions": f"/{remaining_path}" if remaining_path else "",
75
+ "chat": "",
76
+ }
77
+
78
+ payload = {
79
+ "model_saved_name": config.model_saved_name,
80
+ "testing_for": config.testing_for,
81
+ "model_name": config.model_name,
82
+ "model_type": config.model_type,
83
+ "certifications": config.certifications,
84
+ "model_config": {
85
+ "is_compatible_with": config.model_config.is_compatible_with,
86
+ "model_version": config.model_config.model_version,
87
+ "hosting_type": config.model_config.hosting_type,
88
+ "model_source": config.model_config.model_source,
89
+ "model_provider": config.model_config.model_provider,
90
+ "system_prompt": config.model_config.system_prompt,
91
+ "conversation_template": config.model_config.conversation_template,
92
+ "endpoint": {
93
+ "scheme": parsed_url.scheme,
94
+ "host": parsed_url.hostname,
95
+ "port": parsed_url.port
96
+ or (443 if parsed_url.scheme == "https" else 80),
97
+ "base_path": f"/{base_path}/{paths['completions']}", # Just v1
98
+ },
99
+ "paths": paths,
100
+ "auth_data": {
101
+ "header_name": config.model_config.auth_data.header_name,
102
+ "header_prefix": config.model_config.auth_data.header_prefix,
103
+ "space_after_prefix": config.model_config.auth_data.space_after_prefix,
104
+ },
105
+ "apikeys": (
106
+ [config.model_config.apikey] if config.model_config.apikey else []
107
+ ),
108
+ "default_request_options": config.model_config.default_request_options,
109
+ },
110
+ }
111
+ try:
112
+ response = self._request(
113
+ "POST", "/models/add-model", headers=headers, json=payload
114
+ )
115
+ if response.get("error"):
116
+ raise ModelClientError(response["error"])
117
+ return ModelResponse.from_dict(response)
118
+ except Exception as e:
119
+ raise ModelClientError(str(e))
120
+
121
+ def get_model(self, model_id: str) -> ModelConfig:
122
+ """
123
+ Get model configuration by model ID.
124
+
125
+ Args:
126
+ model_id (str): ID of the model to retrieve
127
+
128
+ Returns:
129
+ ModelConfig: Configuration object containing model details
130
+ """
131
+ headers = {"X-Enkrypt-Model": model_id}
132
+ response = self._request("GET", "/models/get-model", headers=headers)
133
+ if response.get("error"):
134
+ raise ModelClientError(response["error"])
135
+ return ModelConfig.from_dict(response)
136
+
137
+ def get_model_list(self):
138
+ """
139
+ Get a list of all available models.
140
+
141
+ Returns:
142
+ dict: Response from the API containing the list of models
143
+ """
144
+ try:
145
+ return self._request("GET", "/models/list-models")
146
+ except Exception as e:
147
+ return {"error": str(e)}
148
+
149
+ def delete_model(self, model_id: str):
150
+ """
151
+ Delete a specific model from the system.
152
+
153
+ Args:
154
+ model_id (str): The identifier or name of the model to delete
155
+
156
+ Returns:
157
+ dict: Response from the API containing the deletion status
158
+ """
159
+ headers = {"X-Enkrypt-Model": model_id}
160
+ return self._request("DELETE", "/models/delete-model", headers=headers)
@@ -0,0 +1,195 @@
1
+ import urllib3
2
+ from .dto import (
3
+ RedTeamConfig,
4
+ RedTeamResponse,
5
+ RedTeamResultSummary,
6
+ RedTeamResultDetails,
7
+ RedTeamTaskStatus,
8
+ RedTeamTaskDetails,
9
+ )
10
+
11
+
12
class RedTeamClientError(Exception):
    """Error raised by ``RedTeamClient`` when a Red Team API call fails or is rejected."""
18
+
19
+
20
class RedTeamClient:
    """
    A client for interacting with the Red Team API.

    All HTTP traffic goes through ``_request``, which turns urllib3 HTTP
    errors (including status >= 400) into an ``{"error": <message>}`` dict.
    The public methods convert that dict into a ``RedTeamClientError``.
    """

    def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com"):
        self.api_key = api_key
        self.base_url = base_url
        self.http = urllib3.PoolManager()
        # Default headers sent with every request; per-call headers override them.
        self.headers = {"apikey": self.api_key}

    @staticmethod
    def _error_message(response):
        """
        Build a human-readable message from an HTTP error response.

        Prefers the ``message`` field of a JSON object body. If the body is
        JSON but not an object, its ``str()`` is used. If it is not JSON at
        all (e.g. an HTML error page from a proxy), the raw text is used.
        Falls back to ``"HTTP <status>"`` when the body is empty.
        """
        fallback = f"HTTP {response.status}"
        if not response.data:
            return fallback
        try:
            error_data = response.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError
            # subclasses; don't let them escape _request's error contract.
            text = response.data.decode("utf-8", errors="replace").strip()
            return text or fallback
        if isinstance(error_data, dict):
            return error_data.get("message", str(error_data))
        return str(error_data)

    def _request(self, method, endpoint, headers=None, **kwargs):
        """
        Send an HTTP request to ``base_url + endpoint`` and decode the JSON reply.

        Args:
            method (str): HTTP method, e.g. "GET" or "POST".
            endpoint (str): Path appended to ``base_url``.
            headers (dict, optional): Extra headers merged over the client defaults.
            **kwargs: Forwarded to ``urllib3.PoolManager.request`` (e.g. ``json=``).

        Returns:
            The decoded JSON body on success, or ``{"error": <message>}`` when
            urllib3 raises ``HTTPError`` or the server responds with status >= 400.
        """
        url = self.base_url + endpoint
        request_headers = {
            "Accept-Encoding": "gzip",  # Add required gzip encoding
            **self.headers,
        }
        if headers:
            request_headers.update(headers)

        try:
            response = self.http.request(method, url, headers=request_headers, **kwargs)

            if response.status >= 400:
                raise urllib3.exceptions.HTTPError(self._error_message(response))
            return response.json()
        except urllib3.exceptions.HTTPError as e:
            return {"error": str(e)}

    def get_model(self, model):
        """
        Return ``model`` if it appears in the saved-model listing, else ``None``.

        Args:
            model (str): Name of the model to look up.

        Returns:
            str | None: ``model`` itself when found, otherwise ``None``. Also
            ``None`` when the listing request itself fails.
        """
        models = self._request("GET", "/models/list-models")
        # A failed listing comes back as {"error": ...}; without this guard a
        # model literally named "error" would be reported as saved.
        if isinstance(models, dict) and models.get("error"):
            return None
        # NOTE(review): membership is tested with ``in`` against the raw
        # list-models response; confirm the response shape exposes model names
        # directly (list items or dict keys).
        if model in models:
            return model
        return None

    def add_task(
        self,
        config: RedTeamConfig,
    ):
        """
        Add a new red teaming task.

        If ``config.model_name`` names a saved model, the task is submitted
        against that model. Otherwise ``config.target_model_configuration`` is
        sent along with the task.

        Args:
            config: Task configuration. NOTE(review): despite the annotation,
                it is passed through ``RedTeamConfig.from_dict`` — presumably
                callers pass a dict; confirm.

        Returns:
            The raw API response dict when a saved model is used, or a
            ``RedTeamResponse`` when a target model configuration is supplied.

        Raises:
            RedTeamClientError: If the API reports an error, or if neither a
                saved model nor a target model configuration is available.
        """
        config = RedTeamConfig.from_dict(config)
        test_configs = config.redteam_test_configurations.to_dict()
        # Drop test configurations that are None (empty dicts are kept).
        test_configs = {k: v for k, v in test_configs.items() if v is not None}

        payload = {
            "dataset_name": config.dataset_name,
            "test_name": config.test_name,
            "redteam_test_configurations": test_configs,
        }

        saved_model = self.get_model(config.model_name)

        if saved_model:
            headers = {
                "X-Enkrypt-Model": saved_model,
                "Content-Type": "application/json",
            }
            payload["location"] = {"storage": "supabase", "container_name": "supabase"}
            response = self._request(
                "POST",
                "/redteam/v2/model/add-task",
                headers=headers,
                json=payload,
            )
            # Surface API errors the same way as the target-config branch
            # instead of handing the caller an {"error": ...} dict.
            if response.get("error"):
                raise RedTeamClientError(response["error"])
            return response

        if config.target_model_configuration:
            payload["target_model_configuration"] = (
                config.target_model_configuration.to_dict()
            )
            response = self._request(
                "POST",
                "/redteam/v2/add-task",
                json=payload,
            )
            if response.get("error"):
                raise RedTeamClientError(response["error"])
            return RedTeamResponse.from_dict(response)

        raise RedTeamClientError(
            "Please use a saved model or provide a target model configuration"
        )

    def status(self, task_id: str):
        """
        Get the status of a specific red teaming task.

        Args:
            task_id (str): The ID of the task to check status

        Returns:
            RedTeamTaskStatus: The task status information

        Raises:
            RedTeamClientError: If the API reports an error.
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/task-status", headers=headers)
        if response.get("error"):
            raise RedTeamClientError(response["error"])
        return RedTeamTaskStatus.from_dict(response)

    def cancel_task(self, task_id: str):
        """
        Cancel a specific red teaming task (not yet supported).

        Args:
            task_id (str): The ID of the task to cancel

        Raises:
            RedTeamClientError: Always; the feature is under development.
        """
        raise RedTeamClientError(
            "This feature is currently under development. Please check our documentation "
            "at https://docs.enkrypt.ai for updates or contact support@enkrypt.ai for assistance."
        )

    def get_task(self, task_id: str):
        """
        Get the status and details of a specific red teaming task.

        Args:
            task_id (str): The ID of the task to retrieve

        Returns:
            RedTeamTaskDetails: The task details and status, built from the
            response's ``data`` field.

        Raises:
            RedTeamClientError: If the API reports an error.
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/get-task", headers=headers)
        if response.get("error"):
            raise RedTeamClientError(response["error"])
        data = response["data"]
        # Rename "job_id" to "task_id" when present. The previous lookup used
        # the key "job_id " (trailing space), so the rename never happened.
        if "job_id" in data:
            data["task_id"] = data.pop("job_id")
        return RedTeamTaskDetails.from_dict(data)

    def get_result_summary(self, task_id: str):
        """
        Get the summary of results for a specific red teaming task.

        Args:
            task_id (str): The ID of the task to get results for

        Returns:
            RedTeamResultSummary: Built from the response's ``summary`` field.

        Raises:
            RedTeamClientError: If the API reports an error.
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/results/summary", headers=headers)
        if response.get("error"):
            raise RedTeamClientError(response["error"])
        return RedTeamResultSummary.from_dict(response["summary"])

    def get_result_details(self, task_id: str):
        """
        Get the detailed results for a specific red teaming task.

        Args:
            task_id (str): The ID of the task to get detailed results for

        Returns:
            RedTeamResultDetails: Built from the response's ``details`` field.

        Raises:
            RedTeamClientError: If the API reports an error.
        """
        # TODO: Update the response to be updated
        headers = {"X-Enkrypt-Task-ID": task_id}
        response = self._request("GET", "/redteam/results/details", headers=headers)
        if response.get("error"):
            raise RedTeamClientError(response["error"])
        return RedTeamResultDetails.from_dict(response["details"])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: enkryptai-sdk
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: A Python SDK with guardrails and red teaming functionality for API interactions
5
5
  Home-page: https://github.com/enkryptai/enkryptai-sdk
6
6
  Author: Enkrypt AI Team
@@ -12,6 +12,9 @@ src/enkryptai_sdk.egg-info/PKG-INFO
12
12
  src/enkryptai_sdk.egg-info/SOURCES.txt
13
13
  src/enkryptai_sdk.egg-info/dependency_links.txt
14
14
  src/enkryptai_sdk.egg-info/top_level.txt
15
+ src/enkryptai_sdk/dto/__init__.py
16
+ src/enkryptai_sdk/dto/models.py
17
+ src/enkryptai_sdk/dto/red_team.py
15
18
  tests/test_all.py
16
19
  tests/test_basic.py
17
20
  tests/test_detect_policy.py
@@ -1,5 +0,0 @@
1
- from .guardrails import GuardrailsClient
2
- from .config import GuardrailsConfig
3
- from .evals import EvalsClient
4
-
5
- __all__ = ["GuardrailsClient", "GuardrailsConfig", "EvalsClient"]
File without changes
File without changes
File without changes
File without changes
File without changes