enkryptai-sdk 0.1.4__tar.gz → 0.1.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. {enkryptai_sdk-0.1.4/src/enkryptai_sdk.egg-info → enkryptai_sdk-0.1.5}/PKG-INFO +1 -1
  2. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/setup.py +1 -1
  3. enkryptai_sdk-0.1.5/src/enkryptai_sdk/__init__.py +13 -0
  4. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk/config.py +107 -46
  5. enkryptai_sdk-0.1.5/src/enkryptai_sdk/dto/__init__.py +18 -0
  6. enkryptai_sdk-0.1.5/src/enkryptai_sdk/dto/models.py +202 -0
  7. enkryptai_sdk-0.1.5/src/enkryptai_sdk/dto/red_team.py +196 -0
  8. enkryptai_sdk-0.1.5/src/enkryptai_sdk/models.py +144 -0
  9. enkryptai_sdk-0.1.5/src/enkryptai_sdk/red_team.py +185 -0
  10. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5/src/enkryptai_sdk.egg-info}/PKG-INFO +1 -1
  11. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk.egg-info/SOURCES.txt +3 -0
  12. enkryptai_sdk-0.1.4/src/enkryptai_sdk/__init__.py +0 -5
  13. enkryptai_sdk-0.1.4/src/enkryptai_sdk/models.py +0 -0
  14. enkryptai_sdk-0.1.4/src/enkryptai_sdk/red_team.py +0 -0
  15. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/LICENSE +0 -0
  16. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/README.md +0 -0
  17. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/setup.cfg +0 -0
  18. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk/evals.py +0 -0
  19. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk/guardrails.py +0 -0
  20. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk/response.py +0 -0
  21. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk.egg-info/dependency_links.txt +0 -0
  22. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/src/enkryptai_sdk.egg-info/top_level.txt +0 -0
  23. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/tests/test_all.py +0 -0
  24. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/tests/test_basic.py +0 -0
  25. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/tests/test_detect_policy.py +0 -0
  26. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/tests/test_injection_attack.py +0 -0
  27. {enkryptai_sdk-0.1.4 → enkryptai_sdk-0.1.5}/tests/test_policy_violation.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: enkryptai-sdk
3
- Version: 0.1.4
3
+ Version: 0.1.5
4
4
  Summary: A Python SDK with guardrails and red teaming functionality for API interactions
5
5
  Home-page: https://github.com/enkryptai/enkryptai-sdk
6
6
  Author: Enkrypt AI Team
@@ -8,7 +8,7 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as fh:
8
8
 
9
9
  setup(
10
10
  name="enkryptai-sdk", # This is the name of your package on PyPI
11
- version="0.1.4",
11
+ version="0.1.5",
12
12
  description="A Python SDK with guardrails and red teaming functionality for API interactions",
13
13
  long_description=long_description,
14
14
  long_description_content_type="text/markdown",
@@ -0,0 +1,13 @@
1
+ from .guardrails import GuardrailsClient
2
+ from .config import GuardrailsConfig
3
+ from .evals import EvalsClient
4
+ from .models import ModelClient
5
+ from .red_team import RedTeamClient
6
+
7
# Public API re-exported by ``from enkryptai_sdk import *``.
__all__ = [
    "GuardrailsClient", "GuardrailsConfig",
    "EvalsClient", "ModelClient", "RedTeamClient",
]
@@ -1,17 +1,21 @@
1
1
  import copy
2
2
 
3
3
  # Base default configuration for all detectors.
4
DEFAULT_GUARDRAILS_CONFIG = {
    "topic_detector": {"enabled": False, "topic": []},
    "nsfw": {"enabled": False},
    "toxicity": {"enabled": False},
    "pii": {"enabled": False, "entities": []},
    "injection_attack": {"enabled": False},
    "keyword_detector": {"enabled": False, "banned_keywords": []},
    "policy_violation": {
        "enabled": False,
        "policy_text": "",
        "need_explanation": False,
    },
    "bias": {"enabled": False},
    "copyright_ip": {"enabled": False},
    "system_prompt": {"enabled": False, "index": "system"},
}

# Backward-compatible alias: up to 0.1.4 this mapping was published as
# ``DEFAULT_CONFIG``. Keep the old name importable so existing callers that do
# ``from enkryptai_sdk.config import DEFAULT_CONFIG`` keep working.
DEFAULT_CONFIG = DEFAULT_GUARDRAILS_CONFIG
16
20
 
17
21
 
@@ -24,29 +28,30 @@ class GuardrailsConfig:
24
28
 
25
29
  def __init__(self, config=None):
26
30
  # Use a deep copy of the default to avoid accidental mutation.
27
- self.config = copy.deepcopy(DEFAULT_CONFIG) if config is None else config
31
+ self.config = (
32
+ copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG) if config is None else config
33
+ )
28
34
 
29
35
  @classmethod
30
36
  def injection_attack(cls):
31
37
  """
32
38
  Returns a configuration instance pre-configured for injection attack detection.
33
39
  """
34
- config = copy.deepcopy(DEFAULT_CONFIG)
40
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
35
41
  config["injection_attack"] = {"enabled": True}
36
42
  return cls(config)
37
43
 
38
44
  @classmethod
39
- def policy_violation(cls,
40
- policy_text: str,
41
- need_explanation: bool = False):
45
+ def policy_violation(cls, policy_text: str, need_explanation: bool = False):
42
46
  """
43
47
  Returns a configuration instance pre-configured for policy violation detection.
44
48
  """
45
- config = copy.deepcopy(DEFAULT_CONFIG)
46
- config["policy_violation"] = {"enabled": True,
47
- "policy_text": policy_text,
48
- "need_explanation": need_explanation
49
- }
49
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
50
+ config["policy_violation"] = {
51
+ "enabled": True,
52
+ "policy_text": policy_text,
53
+ "need_explanation": need_explanation,
54
+ }
50
55
  return cls(config)
51
56
 
52
57
  @classmethod
@@ -54,7 +59,7 @@ class GuardrailsConfig:
54
59
  """
55
60
  Returns a configuration instance pre-configured for toxicity detection.
56
61
  """
57
- config = copy.deepcopy(DEFAULT_CONFIG)
62
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
58
63
  config["toxicity"] = {"enabled": True}
59
64
  return cls(config)
60
65
 
@@ -63,7 +68,7 @@ class GuardrailsConfig:
63
68
  """
64
69
  Returns a configuration instance pre-configured for NSFW content detection.
65
70
  """
66
- config = copy.deepcopy(DEFAULT_CONFIG)
71
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
67
72
  config["nsfw"] = {"enabled": True}
68
73
  return cls(config)
69
74
 
@@ -72,7 +77,7 @@ class GuardrailsConfig:
72
77
  """
73
78
  Returns a configuration instance pre-configured for bias detection.
74
79
  """
75
- config = copy.deepcopy(DEFAULT_CONFIG)
80
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
76
81
  config["bias"] = {"enabled": True}
77
82
  return cls(config)
78
83
 
@@ -80,14 +85,14 @@ class GuardrailsConfig:
80
85
  def pii(cls, entities=None):
81
86
  """
82
87
  Returns a configuration instance pre-configured for PII detection.
83
-
88
+
84
89
  Args:
85
90
  entities (list, optional): List of PII entity types to detect.
86
91
  """
87
- config = copy.deepcopy(DEFAULT_CONFIG)
92
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
88
93
  config["pii"] = {
89
94
  "enabled": True,
90
- "entities": entities if entities is not None else []
95
+ "entities": entities if entities is not None else [],
91
96
  }
92
97
  return cls(config)
93
98
 
@@ -95,14 +100,14 @@ class GuardrailsConfig:
95
100
  def topic(cls, topics=None):
96
101
  """
97
102
  Returns a configuration instance pre-configured for topic detection.
98
-
103
+
99
104
  Args:
100
105
  topics (list, optional): List of topics to detect.
101
106
  """
102
- config = copy.deepcopy(DEFAULT_CONFIG)
107
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
103
108
  config["topic_detector"] = {
104
109
  "enabled": True,
105
- "topic": topics if topics is not None else []
110
+ "topic": topics if topics is not None else [],
106
111
  }
107
112
  return cls(config)
108
113
 
@@ -110,14 +115,14 @@ class GuardrailsConfig:
110
115
  def keyword(cls, keywords=None):
111
116
  """
112
117
  Returns a configuration instance pre-configured for keyword detection.
113
-
118
+
114
119
  Args:
115
120
  keywords (list, optional): List of banned keywords to detect.
116
121
  """
117
- config = copy.deepcopy(DEFAULT_CONFIG)
122
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
118
123
  config["keyword_detector"] = {
119
124
  "enabled": True,
120
- "banned_keywords": keywords if keywords is not None else []
125
+ "banned_keywords": keywords if keywords is not None else [],
121
126
  }
122
127
  return cls(config)
123
128
 
@@ -126,7 +131,7 @@ class GuardrailsConfig:
126
131
  """
127
132
  Returns a configuration instance pre-configured for copyright/IP detection.
128
133
  """
129
- config = copy.deepcopy(DEFAULT_CONFIG)
134
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
130
135
  config["copyright_ip"] = {"enabled": True}
131
136
  return cls(config)
132
137
 
@@ -134,21 +139,18 @@ class GuardrailsConfig:
134
139
  def system_prompt(cls, index="system"):
135
140
  """
136
141
  Returns a configuration instance pre-configured for system prompt detection.
137
-
142
+
138
143
  Args:
139
144
  index (str, optional): Index name for system prompt detection. Defaults to "system".
140
145
  """
141
- config = copy.deepcopy(DEFAULT_CONFIG)
142
- config["system_prompt"] = {
143
- "enabled": True,
144
- "index": index
145
- }
146
+ config = copy.deepcopy(DEFAULT_GUARDRAILS_CONFIG)
147
+ config["system_prompt"] = {"enabled": True, "index": index}
146
148
  return cls(config)
147
149
 
148
150
  def update(self, **kwargs):
149
151
  """
150
152
  Update the configuration with custom values.
151
-
153
+
152
154
  Only keys that exist in the default configuration can be updated.
153
155
  For example:
154
156
  config.update(nsfw={"enabled": True}, toxicity={"enabled": True})
@@ -170,16 +172,16 @@ class GuardrailsConfig:
170
172
  def from_custom_config(cls, config_dict: dict):
171
173
  """
172
174
  Configure guardrails from a dictionary input.
173
-
175
+
174
176
  Validates that the input dictionary matches the expected schema structure.
175
177
  Each key must exist in the default configuration, and its value must be a dictionary.
176
-
178
+
177
179
  Args:
178
180
  config_dict (dict): Dictionary containing guardrails configuration
179
-
181
+
180
182
  Returns:
181
183
  GuardrailsConfig: Returns a new GuardrailsConfig instance
182
-
184
+
183
185
  Raises:
184
186
  ValueError: If the input dictionary contains invalid keys or malformed values
185
187
  """
@@ -189,33 +191,92 @@ class GuardrailsConfig:
189
191
  raise ValueError(f"Unknown detector config: {key}")
190
192
  if not isinstance(value, dict):
191
193
  raise ValueError(f"Config value for {key} must be a dictionary")
192
-
194
+
193
195
  # Validate that all required fields exist in the default config
194
- default_fields = set(DEFAULT_CONFIG[key].keys())
196
+ default_fields = set(DEFAULT_GUARDRAILS_CONFIG[key].keys())
195
197
  provided_fields = set(value.keys())
196
-
198
+
197
199
  if not provided_fields.issubset(default_fields):
198
200
  invalid_fields = provided_fields - default_fields
199
201
  raise ValueError(f"Invalid fields for {key}: {invalid_fields}")
200
-
202
+
201
203
  instance.config[key] = value
202
-
204
+
203
205
  return instance
204
206
 
205
207
  def get_config(self, detector_name: str) -> dict:
206
208
  """
207
209
  Get the configuration for a specific detector.
208
-
210
+
209
211
  Args:
210
212
  detector_name (str): Name of the detector to get configuration for
211
-
213
+
212
214
  Returns:
213
215
  dict: Configuration dictionary for the specified detector
214
-
216
+
215
217
  Raises:
216
218
  ValueError: If the detector name doesn't exist in the configuration
217
219
  """
218
220
  if detector_name not in self.config:
219
221
  raise ValueError(f"Unknown detector: {detector_name}")
220
-
222
+
221
223
  return copy.deepcopy(self.config[detector_name])
224
+
225
+
226
class RedTeamConfig:
    """
    A helper class to manage RedTeam configuration.

    Wraps a plain configuration ``dict``.
    """

    def __init__(self, config=None):
        """
        Args:
            config (dict, optional): Configuration dictionary to wrap. When
                ``None``, a deep copy of ``dto.DEFAULT_REDTEAM_CONFIG``
                converted to a dict is used, so it can be mutated freely.
        """
        if config is None:
            # Imported lazily. This module never imported the DTO defaults, so
            # the previous module-level references raised NameError.
            from .dto import DEFAULT_REDTEAM_CONFIG

            # DEFAULT_REDTEAM_CONFIG is a dataclass instance, not a dict.
            # Convert it so dict-style access (and as_dict) works.
            config = copy.deepcopy(DEFAULT_REDTEAM_CONFIG.to_dict())
            # NOTE(review): the previous version merged ADVANCED_REDTEAM_TESTS
            # here when dataset_name != "standard". That constant is not
            # defined in .dto, and the default dataset_name is "standard", so
            # the branch could never run. It was removed; re-add it once
            # advanced defaults exist.
        self.config = config

    def as_dict(self):
        """
        Return the underlying configuration dictionary (not a copy).
        """
        return self.config
246
+
247
+
248
class ModelConfig:
    """
    A helper class to build a model configuration dictionary.

    Setters return ``self`` so calls can be chained, e.g.
    ``ModelConfig().model_name("gpt-4o").testing_for("LLM").as_dict()``.
    """

    def __init__(self, config=None):
        """
        Args:
            config (dict, optional): Configuration dictionary to wrap. When
                ``None``, a deep copy of ``dto.DETAIL_MODEL_CONFIG`` converted
                to a dict is used.
        """
        if config is None:
            # Imported lazily: this module never imported DETAIL_MODEL_CONFIG,
            # so the previous reference raised NameError.
            from .dto import DETAIL_MODEL_CONFIG

            # DETAIL_MODEL_CONFIG is a dataclass instance. Convert it to a dict
            # so the item assignments in the setters below work.
            config = copy.deepcopy(DETAIL_MODEL_CONFIG.to_dict())
        self.config = config

    # The setters were previously @classmethods that wrote to ``self.config``
    # on the class object, which has no ``config`` attribute (AttributeError).
    # They are now ordinary fluent instance methods.
    def model_name(self, model_name: str):
        """
        Set the model name and return ``self``.
        """
        self.config["model_name"] = model_name
        return self

    def testing_for(self, testing_for: str):
        """
        Set the ``testing_for`` value (e.g. ``"LLM"``) and return ``self``.
        """
        self.config["testing_for"] = testing_for
        return self

    def model_config(self, model_config: dict):
        """
        Set the nested model config dictionary and return ``self``.
        """
        self.config["model_config"] = model_config
        return self

    def as_dict(self):
        """
        Return the underlying configuration dictionary (not a copy).
        """
        return self.config
@@ -0,0 +1,18 @@
1
+ from .models import *
2
+ from .red_team import *
3
+
4
# Public names re-exported from the ``models`` and ``red_team`` DTO modules.
# Every entry must be defined: ``from enkryptai_sdk.dto import *`` raises
# AttributeError for any listed name that is missing.
# NOTE: "AdvancedRedTeamTests", "Location", "Metadata" and
# "ADVANCED_REDTEAM_TESTS" were listed previously but are defined in neither
# module, so they have been removed.
__all__ = [
    # models
    "Modality",
    "EndpointConfig",
    "PathsConfig",
    "AuthData",
    "ModelDetailConfig",
    "DetailModelConfig",
    "ModelConfigDetails",
    "ModelConfig",
    "DETAIL_MODEL_CONFIG",
    # red_team
    "RedTeamResponse",
    "RedTeamTaskStatus",
    "RedTeamTaskDetails",
    "RedTeamResultSummary",
    "RedTeamResultDetails",
    "AttackMethods",
    "TestConfig",
    "RedTeamTestConfigurations",
    "TargetModelConfiguration",
    "RedTeamConfig",
    "DEFAULT_REDTEAM_CONFIG",
]
@@ -0,0 +1,202 @@
1
+ from dataclasses import dataclass, field, asdict
2
+ from typing import Optional, List, Set, Dict, Any
3
+ from enum import Enum
4
+ import json
5
+
6
+
7
class Modality(Enum):
    """Kinds of data a model consumes or produces."""

    TEXT = "text"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"

    def to_dict(self):
        """Return the JSON-friendly form of this member: its string value."""
        raw = self.value
        return raw
15
+
16
+
17
@dataclass
class EndpointConfig:
    """Components of a model endpoint URL: scheme, host, port and base path."""

    scheme: str = "https"
    host: str = "api.openai.com"
    port: int = 443
    # NOTE(review): stored without a leading "/" while PathsConfig.completions
    # has one; confirm how callers join base_path with request paths.
    base_path: str = "v1"
23
+
24
+
25
@dataclass
class PathsConfig:
    """Request paths for a model endpoint (presumably joined onto a base path — confirm)."""

    completions: str = "/chat/completions"
    # NOTE(review): unlike ``completions`` this default has no leading "/";
    # confirm which form the API expects.
    chat: str = "chat/completions"
29
+
30
+
31
@dataclass
class AuthData:
    """Header settings used to send an API key to a target model endpoint.

    Presumably the key is sent as ``<header_name>: <header_prefix>[ ]<key>``,
    with the space controlled by ``space_after_prefix``. Confirm against the API.
    """

    header_name: str = "Authorization"
    header_prefix: str = "Bearer"
    space_after_prefix: bool = True

    def to_dict(self):
        """Return the fields as a plain dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict):
        """Build an AuthData from a dictionary.

        ``None`` is treated as an empty dict, so all defaults apply. Keys that
        are not fields of this dataclass are ignored, mirroring
        ``ModelConfigDetails.from_dict``. Extra metadata in an API payload
        therefore no longer raises ``TypeError``.
        """
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in (data or {}).items() if k in known})
43
+
44
+
45
@dataclass
class ModelDetailConfig:
    """Connection details for a target model: version/source, endpoint, auth and keys."""

    model_version: Optional[str] = None
    model_source: str = ""
    model_provider: str = "openai"
    system_prompt: str = ""

    endpoint_url: str = "https://api.openai.com/v1/chat/completions"
    auth_data: AuthData = field(default_factory=AuthData)
    # NOTE(review): the default is a set containing only ``None``, not an empty
    # set. Confirm this placeholder is intended.
    api_keys: Set[Optional[str]] = field(default_factory=lambda: {None})
55
+
56
+
57
@dataclass
class DetailModelConfig:
    """Model definition whose ``model_config`` is a ``ModelDetailConfig``.

    NOTE(review): overlaps heavily with ``ModelConfig`` below (which nests
    ``ModelConfigDetails``). Confirm which of the two is meant to be used.
    """

    model_saved_name: str = "Model Name"
    testing_for: str = "LLM"
    model_name: str = "gpt-4o-mini"
    modality: Modality = Modality.TEXT
    model_config: ModelDetailConfig = field(default_factory=ModelDetailConfig)
64
+
65
+
66
@dataclass
class ModelConfigDetails:
    """Connection and prompt settings nested under ``ModelConfig.model_config``."""

    model_version: Optional[str] = None
    model_source: str = ""
    model_provider: str = "openai"
    system_prompt: str = ""
    conversation_template: str = ""
    is_compatible_with: str = "openai"
    hosting_type: str = "External"
    endpoint_url: str = "https://api.openai.com/v1/chat/completions"
    auth_data: AuthData = field(default_factory=AuthData)
    apikey: Optional[str] = None
    default_request_options: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict):
        """Create an instance from an API-style dictionary.

        ``data`` itself is never modified, and ``None`` is treated as ``{}``.
        Conversions applied:

        * ``apikeys`` (a list) supplies ``apikey`` from its first entry unless
          ``apikey`` is already set.
        * An ``endpoint`` dict (``scheme``/``host``/``port``/``base_path``) is
          flattened into ``endpoint_url``. Ports 80 and 443 are omitted.
        * ``auth_data`` (a dict, an ``AuthData`` or ``None``) becomes an ``AuthData``.
        * Keys that are not dataclass fields (e.g. ``queryParams``, ``paths``)
          are dropped.
        """
        data = dict(data or {})

        # Handle apikeys -> apikey conversion
        apikeys = data.pop("apikeys", None)
        if apikeys and not data.get("apikey"):
            data["apikey"] = apikeys[0]

        endpoint = data.pop("endpoint", None)
        if endpoint is not None:
            data["endpoint_url"] = cls._endpoint_to_url(endpoint)

        raw_auth = data.pop("auth_data", None)
        if isinstance(raw_auth, AuthData):
            auth_data = raw_auth
        else:
            auth_data = AuthData.from_dict(raw_auth or {})

        # Only keep fields that are defined in the dataclass
        valid_fields = cls.__dataclass_fields__
        filtered = {k: v for k, v in data.items() if k in valid_fields}
        return cls(**filtered, auth_data=auth_data)

    @staticmethod
    def _endpoint_to_url(endpoint: dict) -> str:
        """Flatten an endpoint dict into ``scheme://host[:port][/base_path]``."""
        scheme = endpoint.get("scheme", "https")
        host = endpoint.get("host", "")
        port = endpoint.get("port", "")
        base_path = endpoint.get("base_path", "")

        url = f"{scheme}://{host}"
        if port:
            # Ports may arrive as strings (e.g. "443"). Normalise them so the
            # default-port check also matches; previously "443" became ":443".
            try:
                port = int(port)
            except (TypeError, ValueError):
                pass
            if port not in (80, 443):
                url += f":{port}"
        if base_path:
            url += "/" + base_path.strip("/")
        return url

    def to_dict(self):
        """Return a plain dict. ``asdict`` also converts ``auth_data`` to a dict."""
        return asdict(self)

    def to_json(self):
        """Serialise ``to_dict()`` to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str):
        """Parse a JSON string and build an instance via ``from_dict``."""
        return cls.from_dict(json.loads(json_str))
135
+
136
+
137
@dataclass
class ModelConfig:
    """A saved model definition as exchanged with the models API.

    ``modality`` is held as a ``Modality`` enum and ``model_config`` as a nested
    ``ModelConfigDetails``. ``to_dict``/``from_dict`` convert both to and from
    their plain-JSON forms.
    """

    created_at: str = ""
    updated_at: str = ""
    model_id: str = ""
    model_saved_name: str = "Model Name"
    testing_for: str = "LLM"
    model_name: str = "gpt-4o-mini"
    model_type: str = "text_2_text"
    modality: Modality = Modality.TEXT
    certifications: List[str] = field(default_factory=list)
    model_config: ModelConfigDetails = field(default_factory=ModelConfigDetails)

    def to_dict(self) -> dict:
        """Convert the ModelConfig instance to a plain dictionary.

        A ``Modality`` value becomes its string value; a raw string is passed
        through instead of raising. A ``ModelConfigDetails`` is serialised with
        its ``to_dict``. Lists are copied, so mutating the result does not
        change this instance.
        """
        result = {}
        for name in self.__dataclass_fields__:
            value = getattr(self, name)
            if isinstance(value, Modality):
                value = value.value
            elif isinstance(value, ModelConfigDetails):
                value = value.to_dict()
            elif isinstance(value, list):
                value = list(value)
            result[name] = value
        return result

    def to_json(self) -> str:
        """Convert the ModelConfig instance to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_dict(cls, data: dict):
        """Create a ModelConfig instance from a dictionary.

        ``data`` is not modified. The previous version popped ``model_config``
        and ``modality`` out of the caller's dict. A missing or ``None``
        ``model_config`` yields default details, and an existing
        ``ModelConfigDetails`` is kept as-is. A missing, ``None`` or empty
        ``modality`` defaults to ``"text"``.

        Raises:
            TypeError: if ``data`` contains keys that are not fields.
            ValueError: if ``modality`` is not a valid ``Modality`` value.
        """
        data = dict(data)

        # Handle nested ModelConfigDetails
        model_config = data.pop("model_config", None)
        if not isinstance(model_config, ModelConfigDetails):
            model_config = ModelConfigDetails.from_dict(model_config or {})

        # Handle Modality enum (Modality(member) returns the member unchanged)
        modality = Modality(data.pop("modality", None) or "text")

        return cls(**data, modality=modality, model_config=model_config)

    @classmethod
    def from_json(cls, json_str: str):
        """Create a ModelConfig instance from a JSON string."""
        data = json.loads(json_str)
        return cls.from_dict(data)

    def __str__(self):
        """String representation of the ModelConfig."""
        return f"ModelConfig(name={self.model_saved_name}, model={self.model_name})"

    def __repr__(self):
        """Detailed string representation of the ModelConfig."""
        return (
            f"ModelConfig({', '.join(f'{k}={v!r}' for k, v in self.to_dict().items())})"
        )
199
+
200
+
201
# Default configuration
# Shared module-level instance built from the dataclass defaults. Mutating it
# (or its nested model_config) affects every user of this default, so take a
# copy (e.g. copy.deepcopy) before modifying.
DETAIL_MODEL_CONFIG = ModelConfig()
@@ -0,0 +1,196 @@
1
+ from dataclasses import dataclass, field, asdict
2
+ from typing import Dict, List, Optional
3
+ import json
4
+
5
+
6
@dataclass
class RedTeamResponse:
    """Red-team API response carrying a task identifier."""

    task_id: Optional[str] = None

    def to_dict(self) -> dict:
        """Return the fields as a plain dictionary."""
        serialised = asdict(self)
        return serialised

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance; every key in ``data`` must be a field name."""
        instance = cls(**data)
        return instance
16
+
17
+
18
@dataclass
class RedTeamTaskStatus:
    """Status value reported for a red-team task."""

    status: Optional[str] = None

    def to_dict(self) -> dict:
        """Return the fields as a plain dictionary."""
        serialised = asdict(self)
        return serialised

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance; every key in ``data`` must be a field name."""
        instance = cls(**data)
        return instance
28
+
29
+
30
@dataclass
class RedTeamTaskDetails:
    """Summary metadata of a red-team task: creation time, model, status and id."""

    created_at: Optional[str] = None
    model_name: Optional[str] = None
    status: Optional[str] = None
    task_id: Optional[str] = None

    def to_dict(self) -> dict:
        """Return the fields as a plain dictionary."""
        serialised = asdict(self)
        return serialised

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance; every key in ``data`` must be a field name."""
        instance = cls(**data)
        return instance
43
+
44
+
45
@dataclass
class RedTeamResultSummary:
    """Summary of a red-team run: test/dataset/model identity, risk score and breakdowns."""

    test_date: Optional[str] = None
    test_name: Optional[str] = None
    dataset_name: Optional[str] = None
    model_name: Optional[str] = None
    model_endpoint_url: Optional[str] = None
    model_source: Optional[str] = None
    model_provider: Optional[str] = None
    risk_score: Optional[float] = None
    test_type: Optional[List] = None
    nist_category: Optional[List] = None
    scenario: Optional[List] = None
    category: Optional[List] = None
    attack_method: Optional[List] = None

    def to_dict(self) -> dict:
        """Return a deep, plain-dict copy of the summary."""
        serialised = asdict(self)
        return serialised

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance; every key in ``data`` must be a field name."""
        instance = cls(**data)
        return instance
67
+
68
+
69
@dataclass
class RedTeamResultDetails:  # To Be Updated
    """Detailed red-team results, currently held as one opaque ``details`` dict."""

    details: Optional[Dict] = None

    def to_dict(self) -> dict:
        """Return a deep, plain-dict copy of the fields."""
        serialised = asdict(self)
        return serialised

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance; every key in ``data`` must be a field name."""
        instance = cls(**data)
        return instance
79
+
80
+
81
@dataclass
class AttackMethods:
    """Attack methods for one red-team test.

    ``basic`` is a list of method names. ``advanced`` maps a category (by
    default ``static`` / ``dynamic``) to a list of method names.
    """

    basic: List[str] = field(default_factory=lambda: ["basic"])
    advanced: Dict[str, List[str]] = field(
        default_factory=lambda: {"static": ["single_shot"], "dynamic": ["iterative"]}
    )

    def to_dict(self) -> dict:
        """Return a deep, plain-dict copy of the attack methods."""
        serialised = asdict(self)
        return serialised

    @classmethod
    def from_dict(cls, data: dict):
        """Build from a dict whose keys are ``basic`` and/or ``advanced``."""
        instance = cls(**data)
        return instance
94
+
95
+
96
@dataclass
class TestConfig:
    """Settings for one red-team test: sample percentage and attack methods."""

    sample_percentage: int = 100
    attack_methods: AttackMethods = field(default_factory=AttackMethods)

    def to_dict(self) -> dict:
        """Return a plain dict with ``attack_methods`` serialised."""
        return {
            "sample_percentage": self.sample_percentage,
            "attack_methods": self.attack_methods.to_dict(),
        }

    @classmethod
    def from_dict(cls, data: dict):
        """Build a TestConfig from a dictionary without modifying it.

        The previous version popped ``attack_methods`` out of the caller's
        dict. Reusing the same configuration (e.g. calling ``add_task`` twice)
        then silently fell back to default attack methods. A missing or
        ``None`` ``attack_methods`` yields the defaults, and an existing
        ``AttackMethods`` instance is kept as-is.
        """
        data = dict(data)
        raw_methods = data.pop("attack_methods", None)
        if isinstance(raw_methods, AttackMethods):
            attack_methods = raw_methods
        else:
            attack_methods = AttackMethods.from_dict(raw_methods or {})
        return cls(**data, attack_methods=attack_methods)
111
+
112
+
113
@dataclass
class RedTeamTestConfigurations:
    """Per-test settings. A test left as ``None`` is not configured."""

    # Basic tests
    bias_test: Optional[TestConfig] = None
    cbrn_test: Optional[TestConfig] = None
    insecure_code_test: Optional[TestConfig] = None
    toxicity_test: Optional[TestConfig] = None
    harmful_test: Optional[TestConfig] = None
    # Advanced tests
    adv_info_test: Optional[TestConfig] = None
    adv_bias_test: Optional[TestConfig] = None
    adv_command_test: Optional[TestConfig] = None

    def to_dict(self) -> dict:
        """Return a plain dict. Unconfigured tests appear as ``None``."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict):
        """Build from a mapping of test name to test settings.

        ``None`` values (which ``to_dict`` emits for unconfigured tests) stay
        ``None`` instead of crashing in ``TestConfig.from_dict``, so
        ``from_dict(obj.to_dict())`` round-trips. ``TestConfig`` instances are
        kept as-is. Unknown test names still raise ``TypeError``.
        """
        configs = {}
        for name, value in (data or {}).items():
            if value is None or isinstance(value, TestConfig):
                configs[name] = value
            else:
                configs[name] = TestConfig.from_dict(value)
        return cls(**configs)
132
+
133
+
134
@dataclass
class TargetModelConfiguration:
    """Flat description of the model under test: identity, prompt, endpoint and key."""

    testing_for: str = "LLM"
    model_name: str = "gpt-4o-mini"
    model_version: Optional[str] = None
    system_prompt: str = ""
    conversation_template: str = ""
    model_source: str = ""
    model_provider: str = "openai"
    model_endpoint_url: str = "https://api.openai.com/v1/chat/completions"
    model_api_key: Optional[str] = None

    def to_dict(self) -> dict:
        """Return all fields as a plain dictionary."""
        serialised = asdict(self)
        return serialised

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance; every key in ``data`` must be a field name."""
        instance = cls(**data)
        return instance
152
+
153
+
154
@dataclass
class RedTeamConfig:
    """Red-team task configuration, as parsed by ``RedTeamClient.add_task``.

    It combines the test/dataset/model names, the per-test settings
    (``redteam_test_configurations``) and the model under test
    (``target_model_configuration``).
    """

    test_name: str = "Test Name"
    dataset_name: str = "standard"
    model_name: str = "gpt-4o-mini"
    redteam_test_configurations: RedTeamTestConfigurations = field(
        default_factory=RedTeamTestConfigurations
    )
    target_model_configuration: TargetModelConfiguration = field(
        default_factory=TargetModelConfiguration
    )

    def to_dict(self) -> dict:
        """Return a plain dict with both nested sections serialised by their own ``to_dict``."""
        result = asdict(self)
        result.update(
            redteam_test_configurations=self.redteam_test_configurations.to_dict(),
            target_model_configuration=self.target_model_configuration.to_dict(),
        )
        return result

    def to_json(self) -> str:
        """Serialise ``to_dict()`` to a JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_dict(cls, data: dict):
        """Build from a dict, parsing both nested sections with their own ``from_dict``.

        The top-level dict is copied, so the caller's keys are not removed.
        """
        remaining = data.copy()
        tests = RedTeamTestConfigurations.from_dict(
            remaining.pop("redteam_test_configurations", {})
        )
        target = TargetModelConfiguration.from_dict(
            remaining.pop("target_model_configuration", {})
        )
        return cls(
            redteam_test_configurations=tests,
            target_model_configuration=target,
            **remaining,
        )

    @classmethod
    def from_json(cls, json_str: str):
        """Parse a JSON string and build an instance via ``from_dict``."""
        return cls.from_dict(json.loads(json_str))
193
+
194
+
195
# Default configurations
# Shared module-level instance built from the dataclass defaults
# (dataset_name="standard", every test left as None). Mutating it affects every
# user of the default, so take a copy (e.g. copy.deepcopy or to_dict()) first.
DEFAULT_REDTEAM_CONFIG = RedTeamConfig()
@@ -0,0 +1,144 @@
1
+ import urllib3
2
+ from .dto import ModelConfig, ModelDetailConfig
3
+ from urllib.parse import urlparse, urlsplit
4
+
5
+
6
class ModelClient:
    """Client for the Enkrypt AI models API (``/models/...`` endpoints).

    Methods return the decoded JSON response. Request failures are reported as
    a ``{"error": "<message>"}`` dict rather than raised.
    """

    def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com:443"):
        """
        Args:
            api_key (str): Enkrypt AI API key, sent in the ``apikey`` header.
            base_url (str): API root. Endpoint paths are appended to it verbatim.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.http = urllib3.PoolManager()
        self.headers = {"apikey": self.api_key}

    def _request(self, method, endpoint, payload=None, headers=None, **kwargs):
        """
        Send an HTTP request and return the decoded JSON body.

        Args:
            method (str): HTTP method.
            endpoint (str): Path appended to ``base_url``.
            payload: Unused; kept for signature compatibility. Send JSON bodies
                with ``json=...``, which is forwarded to ``urllib3``.
            headers (dict, optional): Extra headers; these override the defaults.
            **kwargs: Forwarded to ``urllib3.PoolManager.request``.

        Returns:
            dict: The JSON response, or ``{"error": "<message>"}`` on an HTTP
            status >= 400 or a urllib3 transport error.
        """
        url = self.base_url + endpoint
        request_headers = {
            "Accept-Encoding": "gzip",  # Add required gzip encoding
            **self.headers,
        }
        if headers:
            request_headers.update(headers)

        try:
            response = self.http.request(method, url, headers=request_headers, **kwargs)
        except urllib3.exceptions.HTTPError as e:
            return {"error": str(e)}

        if response.status >= 400:
            return {"error": f"HTTP {response.status}: {self._error_details(response)}"}
        return response.json()

    @staticmethod
    def _error_details(response):
        """Decode an error response body for inclusion in the error message."""
        if not response.data:
            return {"message": f"HTTP {response.status}"}
        try:
            return response.json()
        except ValueError:
            # Non-JSON error bodies (e.g. an HTML page from a proxy) previously
            # escaped as an uncaught decode error. Fall back to the raw text.
            return response.data.decode("utf-8", errors="replace")

    def health(self):
        """Return the models service health-check response."""
        return self._request("GET", "/models/health")

    def add_model(self, config: ModelConfig):
        """
        Add a new model configuration to the system.

        Args:
            config (ModelConfig | dict): Model definition. A dict is parsed with
                ``ModelConfig.from_dict`` on a copy, so the caller's dict is not
                modified.

        Returns:
            dict: Response from the API containing the added model details, or
            ``{"error": ...}`` on failure.
        """
        headers = {"Content-Type": "application/json"}
        if not isinstance(config, ModelConfig):
            # Shallow copy: ModelConfig.from_dict pops top-level keys from its input.
            config = ModelConfig.from_dict(dict(config))
        details = config.model_config

        # Split e.g. https://api.openai.com/v1/chat/completions into
        # base_path="v1" and remaining_path="chat/completions".
        parsed_url = urlparse(details.endpoint_url)
        path_parts = parsed_url.path.strip("/").split("/")  # always >= 1 element
        base_path = path_parts[0]
        remaining_path = "/".join(path_parts[1:])

        paths = {
            "completions": f"/{remaining_path}" if remaining_path else "",
            "chat": "",
        }

        payload = {
            "model_saved_name": config.model_saved_name,
            "testing_for": config.testing_for,
            "model_name": config.model_name,
            "model_type": config.model_type,
            "certifications": config.certifications,
            "model_config": {
                "is_compatible_with": details.is_compatible_with,
                "model_version": details.model_version,
                "hosting_type": details.hosting_type,
                "model_source": details.model_source,
                "model_provider": details.model_provider,
                "system_prompt": details.system_prompt,
                "conversation_template": details.conversation_template,
                "endpoint": {
                    "scheme": parsed_url.scheme,
                    "host": parsed_url.hostname,
                    "port": parsed_url.port
                    or (443 if parsed_url.scheme == "https" else 80),
                    # Full request path, e.g. "/v1/chat/completions".
                    # paths["completions"] already starts with "/", so the former
                    # f"/{base_path}/{paths['completions']}" yielded
                    # "/v1//chat/completions".
                    # NOTE(review): an older comment said "Just v1", but
                    # ModelConfigDetails.from_dict rebuilds endpoint_url from
                    # base_path alone, so the full path is kept. Confirm with the API.
                    "base_path": f"/{base_path}{paths['completions']}",
                },
                "paths": paths,
                "auth_data": {
                    "header_name": details.auth_data.header_name,
                    "header_prefix": details.auth_data.header_prefix,
                    "space_after_prefix": details.auth_data.space_after_prefix,
                },
                "apikeys": [details.apikey] if details.apikey else [],
                "default_request_options": details.default_request_options,
            },
        }
        # The payload contains the target model's API key. It must not be
        # printed or logged (the previous version printed it to stdout).
        return self._request("POST", "/models/add-model", headers=headers, json=payload)

    def get_model(self, model_id: str) -> "ModelConfig | dict":
        """
        Get model configuration by model ID.

        Args:
            model_id (str): ID of the model to retrieve.

        Returns:
            ModelConfig: Configuration object containing model details, or the
            ``{"error": ...}`` dict if the request failed.
        """
        headers = {"X-Enkrypt-Model": model_id}
        response = self._request("GET", "/models/get-model", headers=headers)
        if isinstance(response, dict) and "error" in response:
            # ModelConfig has no "error" field, so parsing this would only raise
            # a confusing TypeError. Surface the error like the other methods do.
            return response
        return ModelConfig.from_dict(response)

    def get_model_list(self):
        """
        Get a list of all available models.

        Returns:
            dict: Response from the API containing the list of models, or
            ``{"error": ...}`` on failure.
        """
        try:
            return self._request("GET", "/models/list-models")
        except Exception as e:
            return {"error": str(e)}

    def delete_model(self, model_id: str):
        """
        Delete a specific model from the system.

        Args:
            model_id (str): The identifier or name of the model to delete.

        Returns:
            dict: Response from the API containing the deletion status.
        """
        headers = {"X-Enkrypt-Model": model_id}
        return self._request("DELETE", "/models/delete-model", headers=headers)
@@ -0,0 +1,185 @@
1
+ import urllib3
2
+ from .dto import (
3
+ RedTeamConfig,
4
+ RedTeamResponse,
5
+ RedTeamResultSummary,
6
+ RedTeamResultDetails,
7
+ RedTeamTaskStatus,
8
+ RedTeamTaskDetails,
9
+ )
10
+
11
+
12
class RedTeamError(Exception):
    """Raised by ``RedTeamClient`` when a red-teaming operation cannot proceed."""
18
+
19
+
20
class RedTeamClient:
    """
    A client for interacting with the Red Team API.

    Every call goes through ``_request``, which sends the ``apikey`` header given
    at construction time. ``_request`` does not raise on HTTP/transport failures:
    it prints the error and returns ``{"error": "<message>"}``. ``get_task``,
    ``get_result_summary`` and ``get_result_details`` convert such a response into
    a ``RedTeamError`` instead of failing with an opaque ``KeyError``.
    """

    def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com"):
        self.api_key = api_key
        self.base_url = base_url
        self.http = urllib3.PoolManager()
        # Default headers; per-call headers passed to _request are merged on top.
        self.headers = {"apikey": self.api_key}

    def _request(self, method, endpoint, headers=None, **kwargs):
        """
        Send an HTTP request to ``base_url + endpoint`` and return the decoded JSON.

        Args:
            method (str): HTTP method, e.g. "GET" or "POST".
            endpoint (str): Path appended to ``base_url``.
            headers (dict, optional): Extra headers; they override the defaults.
            **kwargs: Forwarded to ``urllib3.PoolManager.request`` (e.g. ``json=``).

        Returns:
            The parsed JSON body, or ``{"error": "<message>"}`` when the request
            fails at the transport level or the server answers with status >= 400.
        """
        url = self.base_url + endpoint
        request_headers = {
            "Accept-Encoding": "gzip",  # Add required gzip encoding
            **self.headers,
        }
        if headers:
            request_headers.update(headers)

        try:
            response = self.http.request(method, url, headers=request_headers, **kwargs)
            if response.status >= 400:
                raise urllib3.exceptions.HTTPError(
                    f"HTTP {response.status}: {self._error_body(response)}"
                )
            return response.json()
        except urllib3.exceptions.HTTPError as e:
            print(f"Request failed: {e}")
            return {"error": str(e)}

    @staticmethod
    def _error_body(response):
        """
        Describe an error response body for use in an error message.

        Error bodies are not guaranteed to be JSON (e.g. an HTML page from a proxy),
        so fall back to the raw text rather than letting a decode error escape
        ``_request``'s error handling.
        """
        if not response.data:
            return {"message": f"HTTP {response.status}"}
        try:
            return response.json()
        except ValueError:  # json.JSONDecodeError and UnicodeDecodeError
            return response.data.decode("utf-8", errors="replace")

    @staticmethod
    def _extract(response, key):
        """
        Return ``response[key]``, or raise RedTeamError if it is missing.

        ``_request`` reports failures as ``{"error": ...}``; surfacing that message
        is more useful to callers than a bare ``KeyError``.
        """
        if isinstance(response, dict) and key in response:
            return response[key]
        detail = response.get("error", response) if isinstance(response, dict) else response
        raise RedTeamError(f"Response is missing '{key}': {detail}")

    def get_model(self, model):
        """
        Return ``model`` if it appears in the ``/models/list-models`` response, else None.

        NOTE(review): this is a plain ``in`` test; if the response is a dict it
        checks the dict's keys — confirm against the API's response schema.
        """
        models = self._request("GET", "/models/list-models")
        if model in models:
            return model
        return None

    def add_task(
        self,
        config: RedTeamConfig,
    ):
        """
        Add a new red teaming task.

        Args:
            config (RedTeamConfig | dict): Task configuration. A plain dict is
                converted with ``RedTeamConfig.from_dict``; a ``RedTeamConfig`` is
                used as-is.

        Returns:
            The raw API response when ``config.model_name`` is a saved model;
            otherwise a ``RedTeamResponse`` built from the API response.

        Raises:
            RedTeamError: If there is neither a saved model nor a
                ``target_model_configuration``.
        """
        # Accept both the dict calling convention and the annotated DTO type.
        if not isinstance(config, RedTeamConfig):
            config = RedTeamConfig.from_dict(config)
        test_configs = config.redteam_test_configurations.to_dict()
        # Drop test configurations that are None so they are not sent.
        test_configs = {k: v for k, v in test_configs.items() if v is not None}

        payload = {
            # "async": config.async_enabled,
            "dataset_name": config.dataset_name,
            "test_name": config.test_name,
            "redteam_test_configurations": test_configs,
        }

        saved_model = self.get_model(config.model_name)

        if saved_model:
            headers = {
                "X-Enkrypt-Model": saved_model,
                "Content-Type": "application/json",
            }
            payload["location"] = {"storage": "supabase", "container_name": "supabase"}
            # NOTE(review): this branch returns the raw response dict, whereas the
            # target-configuration branch below returns a RedTeamResponse.
            return self._request(
                "POST",
                "/redteam/v2/model/add-task",
                headers=headers,
                json=payload,
            )
        if config.target_model_configuration:
            payload["target_model_configuration"] = (
                config.target_model_configuration.to_dict()
            )
            response = self._request(
                "POST",
                "/redteam/v2/add-task",
                json=payload,
            )
            return RedTeamResponse.from_dict(response)
        raise RedTeamError(
            "Please use a saved model or provide a target model configuration"
        )

    def status(self, task_id: str):
        """
        Get the status of a specific red teaming task.

        Args:
            task_id (str): The ID of the task to check status

        Returns:
            RedTeamTaskStatus: Built from the ``/redteam/task-status`` response
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/task-status", headers=headers)
        return RedTeamTaskStatus.from_dict(response)

    def cancel_task(self, task_id: str):
        """
        Cancel a specific red teaming task.

        Not implemented yet: always raises RedTeamError.

        Args:
            task_id (str): The ID of the task to cancel
        """
        raise RedTeamError(
            "This feature is currently under development. Please check our documentation "
            "at https://docs.enkrypt.ai for updates or contact support@enkrypt.ai for assistance."
        )

    def get_task(self, task_id: str):
        """
        Get the status and details of a specific red teaming task.

        Args:
            task_id (str): The ID of the task to retrieve

        Returns:
            RedTeamTaskDetails: Built from the response's ``data`` object, with the
            API's ``job_id`` field renamed to ``task_id``

        Raises:
            RedTeamError: If the response has no ``data`` (e.g. the request failed)
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/get-task", headers=headers)
        data = self._extract(response, "data")
        data["task_id"] = data.pop("job_id")
        return RedTeamTaskDetails.from_dict(data)

    def get_result_summary(self, task_id: str):
        """
        Get the summary of results for a specific red teaming task.

        Args:
            task_id (str): The ID of the task to get results for

        Returns:
            RedTeamResultSummary: Built from the response's ``summary`` object

        Raises:
            RedTeamError: If the response has no ``summary`` (e.g. the request failed)
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/results/summary", headers=headers)
        return RedTeamResultSummary.from_dict(self._extract(response, "summary"))

    def get_result_details(self, task_id: str):
        """
        Get the detailed results for a specific red teaming task.

        Args:
            task_id (str): The ID of the task to get detailed results for

        Returns:
            RedTeamResultDetails: Built from the response's ``details`` object

        Raises:
            RedTeamError: If the response has no ``details`` (e.g. the request failed)
        """
        # TODO: Update the response to be updated
        headers = {"X-Enkrypt-Task-ID": task_id}
        response = self._request("GET", "/redteam/results/details", headers=headers)
        return RedTeamResultDetails.from_dict(self._extract(response, "details"))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: enkryptai-sdk
3
- Version: 0.1.4
3
+ Version: 0.1.5
4
4
  Summary: A Python SDK with guardrails and red teaming functionality for API interactions
5
5
  Home-page: https://github.com/enkryptai/enkryptai-sdk
6
6
  Author: Enkrypt AI Team
@@ -12,6 +12,9 @@ src/enkryptai_sdk.egg-info/PKG-INFO
12
12
  src/enkryptai_sdk.egg-info/SOURCES.txt
13
13
  src/enkryptai_sdk.egg-info/dependency_links.txt
14
14
  src/enkryptai_sdk.egg-info/top_level.txt
15
+ src/enkryptai_sdk/dto/__init__.py
16
+ src/enkryptai_sdk/dto/models.py
17
+ src/enkryptai_sdk/dto/red_team.py
15
18
  tests/test_all.py
16
19
  tests/test_basic.py
17
20
  tests/test_detect_policy.py
@@ -1,5 +0,0 @@
1
- from .guardrails import GuardrailsClient
2
- from .config import GuardrailsConfig
3
- from .evals import EvalsClient
4
-
5
- __all__ = ["GuardrailsClient", "GuardrailsConfig", "EvalsClient"]
File without changes
File without changes
File without changes
File without changes
File without changes