enkryptai-sdk 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- enkryptai_sdk/__init__.py +11 -2
- enkryptai_sdk/{guardrails_config.py → config.py} +107 -46
- enkryptai_sdk/dto/__init__.py +18 -0
- enkryptai_sdk/dto/models.py +202 -0
- enkryptai_sdk/dto/red_team.py +196 -0
- enkryptai_sdk/evals.py +84 -0
- enkryptai_sdk/guardrails.py +11 -4
- enkryptai_sdk/models.py +144 -0
- enkryptai_sdk/red_team.py +185 -0
- enkryptai_sdk/response.py +135 -0
- {enkryptai_sdk-0.1.3.dist-info → enkryptai_sdk-0.1.5.dist-info}/METADATA +111 -1
- enkryptai_sdk-0.1.5.dist-info/RECORD +15 -0
- enkryptai_sdk-0.1.3.dist-info/RECORD +0 -9
- {enkryptai_sdk-0.1.3.dist-info → enkryptai_sdk-0.1.5.dist-info}/LICENSE +0 -0
- {enkryptai_sdk-0.1.3.dist-info → enkryptai_sdk-0.1.5.dist-info}/WHEEL +0 -0
- {enkryptai_sdk-0.1.3.dist-info → enkryptai_sdk-0.1.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
from dataclasses import dataclass, field, asdict
|
|
2
|
+
from typing import Dict, List, Optional
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class RedTeamResponse:
    """Result of submitting a red-teaming task.

    Attributes:
        task_id: Identifier of the submitted task; ``None`` when absent.
    """

    task_id: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize every field into a plain ``dict``."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: dict):
        """Construct an instance from a mapping keyed by field name.

        Keys that are not fields of this dataclass make the constructor
        raise ``TypeError``.
        """
        instance = cls(**data)
        return instance
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class RedTeamTaskStatus:
    """Status of a red-teaming task as reported by the API.

    Attributes:
        status: Status value returned by the API; ``None`` when absent.
    """

    status: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize every field into a plain ``dict``."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: dict):
        """Construct an instance from a mapping keyed by field name.

        Keys that are not fields of this dataclass make the constructor
        raise ``TypeError``.
        """
        instance = cls(**data)
        return instance
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class RedTeamTaskDetails:
    """Details of a single red-teaming task.

    Every attribute is optional and defaults to ``None``.

    Attributes:
        created_at: Creation time as sent by the API (kept as a string).
        model_name: Name of the model under test.
        status: Status value reported by the API.
        task_id: Identifier of the task.
    """

    created_at: Optional[str] = None
    model_name: Optional[str] = None
    status: Optional[str] = None
    task_id: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize every field into a plain ``dict`` (declaration order)."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: dict):
        """Construct an instance from a mapping keyed by field name.

        Keys that are not fields of this dataclass make the constructor
        raise ``TypeError``.
        """
        instance = cls(**data)
        return instance
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class RedTeamResultSummary:
    """Summary of a finished red-teaming run.

    Every attribute is optional and defaults to ``None``.
    """

    # Identification of the run.
    test_date: Optional[str] = None
    test_name: Optional[str] = None
    dataset_name: Optional[str] = None

    # Model that was tested.
    model_name: Optional[str] = None
    model_endpoint_url: Optional[str] = None
    model_source: Optional[str] = None
    model_provider: Optional[str] = None

    # Scores and breakdowns; the list element types are not fixed here.
    risk_score: Optional[float] = None
    test_type: Optional[List] = None
    nist_category: Optional[List] = None
    scenario: Optional[List] = None
    category: Optional[List] = None
    attack_method: Optional[List] = None

    def to_dict(self) -> dict:
        """Serialize every field into a plain ``dict`` (declaration order)."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: dict):
        """Construct an instance from a mapping keyed by field name.

        Keys that are not fields of this dataclass make the constructor
        raise ``TypeError``.
        """
        instance = cls(**data)
        return instance
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
class RedTeamResultDetails:
    """Detailed results of a red-teaming run.

    TODO: the field layout is still to be updated; for now the details are
    held as a single opaque mapping.
    """

    details: Optional[Dict] = None

    def to_dict(self) -> dict:
        """Serialize every field into a plain ``dict``."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: dict):
        """Construct an instance from a mapping keyed by field name.

        Keys that are not fields of this dataclass make the constructor
        raise ``TypeError``.
        """
        instance = cls(**data)
        return instance
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class AttackMethods:
    """Attack strategies applied by a single red-team test.

    Attributes:
        basic: Names of basic attack methods (default ``["basic"]``).
        advanced: Advanced attack methods grouped by kind (default
            ``{"static": ["single_shot"], "dynamic": ["iterative"]}``).
    """

    basic: List[str] = field(default_factory=lambda: ["basic"])
    advanced: Dict[str, List[str]] = field(
        default_factory=lambda: {"static": ["single_shot"], "dynamic": ["iterative"]}
    )

    def to_dict(self) -> dict:
        """Serialize into a plain ``dict``."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict):
        """Construct an instance from a mapping keyed by field name."""
        return cls(**data)


@dataclass
class TestConfig:
    """Configuration of one red-team test.

    Attributes:
        sample_percentage: Percentage of the dataset to sample (default 100).
        attack_methods: Attack strategies used by this test.
    """

    sample_percentage: int = 100
    attack_methods: AttackMethods = field(default_factory=AttackMethods)

    def to_dict(self) -> dict:
        """Serialize into a ``dict`` with a nested ``attack_methods`` dict."""
        return {
            "sample_percentage": self.sample_percentage,
            "attack_methods": self.attack_methods.to_dict(),
        }

    @classmethod
    def from_dict(cls, data: dict):
        """Construct an instance from a mapping such as ``to_dict()`` output.

        The caller's mapping is left unmodified. The ``attack_methods`` entry
        may be a dict or an ``AttackMethods`` instance. When it is missing or
        ``None``, the ``AttackMethods`` defaults are used.
        """
        # Work on a copy: popping from the caller's dict would strip
        # ``attack_methods`` and make a reused config silently use defaults.
        remaining = dict(data)
        raw_methods = remaining.pop("attack_methods", None)
        if isinstance(raw_methods, AttackMethods):
            attack_methods = raw_methods
        else:
            attack_methods = AttackMethods.from_dict(raw_methods or {})
        return cls(**remaining, attack_methods=attack_methods)


@dataclass
class RedTeamTestConfigurations:
    """Per-test configurations; a test left as ``None`` is not configured."""

    # Basic tests
    bias_test: Optional[TestConfig] = field(default=None)
    cbrn_test: Optional[TestConfig] = field(default=None)
    insecure_code_test: Optional[TestConfig] = field(default=None)
    toxicity_test: Optional[TestConfig] = field(default=None)
    harmful_test: Optional[TestConfig] = field(default=None)
    # Advanced tests
    adv_info_test: Optional[TestConfig] = field(default=None)
    adv_bias_test: Optional[TestConfig] = field(default=None)
    adv_command_test: Optional[TestConfig] = field(default=None)

    def to_dict(self) -> dict:
        """Serialize into a ``dict``; unconfigured tests map to ``None``."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict):
        """Construct an instance from a mapping of test name to config.

        Each value may be a dict, a ``TestConfig``, or ``None``. ``None`` is
        what ``to_dict`` emits for unconfigured tests, and it is kept as-is so
        that a ``to_dict`` / ``from_dict`` round trip works.
        """

        def _parse(value):
            if value is None or isinstance(value, TestConfig):
                return value
            return TestConfig.from_dict(value)

        return cls(**{name: _parse(value) for name, value in data.items()})
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@dataclass
class TargetModelConfiguration:
    """Model that a red-teaming task is run against.

    The defaults describe ``gpt-4o-mini`` served by OpenAI's chat
    completions endpoint. Note that ``to_dict`` includes ``model_api_key``,
    so serialized output may contain a secret.
    """

    testing_for: str = "LLM"
    model_name: str = "gpt-4o-mini"
    model_version: Optional[str] = None
    system_prompt: str = ""
    conversation_template: str = ""
    model_source: str = ""
    model_provider: str = "openai"
    model_endpoint_url: str = "https://api.openai.com/v1/chat/completions"
    model_api_key: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize every field (including the API key) into a ``dict``."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: dict):
        """Construct an instance from a mapping keyed by field name.

        Keys that are not fields of this dataclass make the constructor
        raise ``TypeError``.
        """
        instance = cls(**data)
        return instance
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@dataclass
class RedTeamConfig:
    """Top-level configuration for submitting a red-teaming task.

    Attributes:
        test_name: Name given to the test run.
        dataset_name: Dataset to draw test prompts from.
        model_name: Name of the model; ``RedTeamClient.add_task`` looks it up
            among the saved models.
        redteam_test_configurations: Which tests to run and how.
        target_model_configuration: Model details used when the model is
            not a saved one.
    """

    test_name: str = "Test Name"
    dataset_name: str = "standard"
    model_name: str = "gpt-4o-mini"
    redteam_test_configurations: RedTeamTestConfigurations = field(
        default_factory=RedTeamTestConfigurations
    )
    target_model_configuration: TargetModelConfiguration = field(
        default_factory=TargetModelConfiguration
    )

    def to_dict(self) -> dict:
        """Serialize into a ``dict``, delegating nested configs to their own ``to_dict``."""
        result = asdict(self)
        # Overwrite the nested entries with each component's own serializer
        # so any custom ``to_dict`` logic they define is honoured.
        result.update(
            redteam_test_configurations=self.redteam_test_configurations.to_dict(),
            target_model_configuration=self.target_model_configuration.to_dict(),
        )
        return result

    def to_json(self) -> str:
        """Serialize into a JSON string (see ``to_dict``)."""
        as_dict = self.to_dict()
        return json.dumps(as_dict)

    @classmethod
    def from_dict(cls, data: dict):
        """Construct an instance from a mapping such as ``to_dict()`` output.

        Missing nested sections fall back to their defaults. The top level
        of ``data`` is shallow-copied before entries are popped.
        """
        flat_fields = data.copy()
        nested = {
            "redteam_test_configurations": RedTeamTestConfigurations.from_dict(
                flat_fields.pop("redteam_test_configurations", {})
            ),
            "target_model_configuration": TargetModelConfiguration.from_dict(
                flat_fields.pop("target_model_configuration", {})
            ),
        }
        return cls(**flat_fields, **nested)

    @classmethod
    def from_json(cls, json_str: str):
        """Construct an instance from a JSON string (see ``from_dict``)."""
        decoded = json.loads(json_str)
        return cls.from_dict(decoded)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# Default configurations: a RedTeamConfig built purely from the dataclass
# defaults. NOTE(review): this is one shared, mutable instance; modifying it
# affects every other user of the constant.
DEFAULT_REDTEAM_CONFIG: RedTeamConfig = RedTeamConfig()
|
enkryptai_sdk/evals.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
|
|
3
|
+
class EvalsClient:
    """
    A client for interacting with Enkrypt AI Evals API endpoints.
    """

    def __init__(self, api_key, base_url="https://api.enkryptai.com"):
        """
        Initializes the client.

        Parameters:
        - api_key (str): Your API key for authenticating with the service.
        - base_url (str): Base URL of the API (default: "https://api.enkryptai.com").
          A trailing "/" is stripped because endpoint paths start with "/".
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.session = requests.Session()

    def _request(self, method, endpoint, headers=None, **kwargs):
        """
        Internal helper to send an HTTP request.

        Adds the API key as the 'apikey' header unless the caller already
        supplied one. The caller's headers dict is not modified.

        Returns:
        - The decoded JSON body on success.
        - {"error": "<message>"} if sending the request, the HTTP status check,
          or JSON decoding raises (the exception is also printed).
        """
        url = self.base_url + endpoint
        # Copy so a headers dict shared across calls is not mutated with the key.
        request_headers = dict(headers) if headers else {}
        request_headers.setdefault('apikey', self.api_key)

        try:
            response = self.session.request(method, url, headers=request_headers, **kwargs)
            response.raise_for_status()
            return response.json()

        except Exception as e:
            print(e)
            return {"error": str(e)}

    # ----------------------------
    # Basic Evals Endpoints
    # ----------------------------

    def check_adherence(self, llm_answer, context):
        """
        Checks if the LLM's answer adheres to the provided context.

        Parameters:
        - llm_answer (str): The response generated by the LLM
        - context (str): The context against which to check the answer

        Returns:
        - JSON response from the API containing adherence analysis,
          or {"error": "<message>"} on failure.
        """
        payload = {
            "llm_answer": llm_answer,
            "context": context
        }
        # _request already converts any failure into {"error": ...}.
        return self._request("POST", "/guardrails/adherence", json=payload)

    def check_relevancy(self, question, llm_answer):
        """
        Checks if the LLM's answer is relevant to the asked question.

        Parameters:
        - question (str): The original question asked
        - llm_answer (str): The response generated by the LLM

        Returns:
        - JSON response from the API containing relevancy analysis,
          or {"error": "<message>"} on failure.
        """
        payload = {
            "question": question,
            "llm_answer": llm_answer
        }
        # _request already converts any failure into {"error": ...}.
        return self._request("POST", "/guardrails/relevancy", json=payload)
|
enkryptai_sdk/guardrails.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import requests
|
|
2
|
+
from .config import GuardrailsConfig
|
|
3
|
+
from .response import GuardrailsResponse, PIIResponse
|
|
2
4
|
|
|
3
5
|
class GuardrailsClient:
|
|
4
6
|
"""
|
|
@@ -74,7 +76,6 @@ class GuardrailsClient:
|
|
|
74
76
|
"""
|
|
75
77
|
# Use injection attack config by default if none provided
|
|
76
78
|
if config is None:
|
|
77
|
-
from .guardrails_config import GuardrailsConfig
|
|
78
79
|
config = GuardrailsConfig.injection_attack()
|
|
79
80
|
|
|
80
81
|
# Allow passing in either a dict or a GuardrailsConfig instance.
|
|
@@ -85,7 +86,8 @@ class GuardrailsClient:
|
|
|
85
86
|
"text": text,
|
|
86
87
|
"detectors": config
|
|
87
88
|
}
|
|
88
|
-
|
|
89
|
+
response_body = self._request("POST", "/guardrails/detect", json=payload)
|
|
90
|
+
return GuardrailsResponse(response_body)
|
|
89
91
|
|
|
90
92
|
def pii(self, text, mode, key="null", entities=None):
|
|
91
93
|
"""
|
|
@@ -97,7 +99,8 @@ class GuardrailsClient:
|
|
|
97
99
|
"key": key,
|
|
98
100
|
"entities": entities
|
|
99
101
|
}
|
|
100
|
-
|
|
102
|
+
response_body = self._request("POST", "/guardrails/pii", json=payload)
|
|
103
|
+
return PIIResponse(response_body)
|
|
101
104
|
|
|
102
105
|
# ----------------------------
|
|
103
106
|
# Guardrails Policy Endpoints
|
|
@@ -181,8 +184,12 @@ class GuardrailsClient:
|
|
|
181
184
|
"""
|
|
182
185
|
headers = {"X-Enkrypt-Policy": policy_name}
|
|
183
186
|
payload = {"text": text}
|
|
187
|
+
|
|
184
188
|
try:
|
|
185
|
-
|
|
189
|
+
|
|
190
|
+
response_body = self._request("POST", "/guardrails/policy/detect", headers=headers, json=payload)
|
|
191
|
+
return GuardrailsResponse(response_body)
|
|
192
|
+
|
|
186
193
|
except Exception as e:
|
|
187
194
|
print(e)
|
|
188
195
|
return {"error": str(e)}
|
enkryptai_sdk/models.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
import urllib3
|
|
2
|
+
from .dto import ModelConfig, ModelDetailConfig
|
|
3
|
+
from urllib.parse import urlparse, urlsplit
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ModelClient:
    """Client for the Enkrypt AI model management endpoints (``/models/...``)."""

    def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com:443"):
        """
        Args:
            api_key: Enkrypt AI API key, sent as the ``apikey`` header.
            base_url: API root. A trailing ``/`` is stripped because endpoint
                paths already start with ``/``.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.http = urllib3.PoolManager()
        self.headers = {"apikey": self.api_key}

    @staticmethod
    def _error_detail(response):
        """Best-effort description of an error response body."""
        if not response.data:
            return {"message": f"HTTP {response.status}"}
        try:
            return response.json()
        except ValueError:
            # A non-JSON error body would otherwise escape as a decode error
            # instead of the usual {"error": ...} result.
            return response.data.decode("utf-8", errors="replace")

    def _request(self, method, endpoint, payload=None, headers=None, **kwargs):
        """Send a request to ``base_url + endpoint`` and decode the JSON reply.

        Args:
            method: HTTP verb.
            endpoint: Path appended to ``base_url``.
            payload: Unused; kept for backward compatibility. Pass a request
                body as ``json=...`` instead.
            headers: Extra headers merged over the default ones.
            **kwargs: Forwarded to ``urllib3.PoolManager.request``.

        Returns:
            The decoded JSON body, or ``{"error": "<message>"}`` when the
            status is >= 400 or urllib3 raises an ``HTTPError``.
        """
        url = self.base_url + endpoint
        request_headers = {
            "Accept-Encoding": "gzip",  # Add required gzip encoding
            **self.headers,
        }
        if headers:
            request_headers.update(headers)

        try:
            response = self.http.request(method, url, headers=request_headers, **kwargs)

            if response.status >= 400:
                raise urllib3.exceptions.HTTPError(
                    f"HTTP {response.status}: {self._error_detail(response)}"
                )
            return response.json()
        except urllib3.exceptions.HTTPError as e:
            return {"error": str(e)}

    def health(self):
        """Call ``GET /models/health`` and return the decoded response."""
        return self._request("GET", "/models/health")

    def add_model(self, config: ModelConfig):
        """
        Add a new model configuration to the system.

        Args:
            config (ModelConfig): Configuration object containing model details,
                or its dict form (as accepted by ``ModelConfig.from_dict``).

        Returns:
            dict: Response from the API containing the added model details
        """
        headers = {"Content-Type": "application/json"}
        # The annotation promises a ModelConfig; also accept the dict form.
        if not isinstance(config, ModelConfig):
            config = ModelConfig.from_dict(config)
        model_cfg = config.model_config

        # Split endpoint_url, e.g. "https://api.openai.com/v1/chat/completions"
        # gives base_path "v1" and remaining_path "chat/completions".
        parsed_url = urlparse(model_cfg.endpoint_url)
        path_parts = parsed_url.path.strip("/").split("/")
        # str.split always returns at least one element ("" for an empty path).
        base_path = path_parts[0]
        remaining_path = "/".join(path_parts[1:])

        # Determine paths based on the endpoint
        paths = {
            "completions": f"/{remaining_path}" if remaining_path else "",
            "chat": "",
        }

        payload = {
            "model_saved_name": config.model_saved_name,
            "testing_for": config.testing_for,
            "model_name": config.model_name,
            "model_type": config.model_type,
            "certifications": config.certifications,
            "model_config": {
                "is_compatible_with": model_cfg.is_compatible_with,
                "model_version": model_cfg.model_version,
                "hosting_type": model_cfg.hosting_type,
                "model_source": model_cfg.model_source,
                "model_provider": model_cfg.model_provider,
                "system_prompt": model_cfg.system_prompt,
                "conversation_template": model_cfg.conversation_template,
                "endpoint": {
                    "scheme": parsed_url.scheme,
                    "host": parsed_url.hostname,
                    "port": parsed_url.port
                    or (443 if parsed_url.scheme == "https" else 80),
                    # Full endpoint path. Joined directly, because
                    # paths["completions"] already starts with "/" and
                    # concatenating it produced "/v1//chat/completions".
                    "base_path": f"/{base_path}/{remaining_path}",
                },
                "paths": paths,
                "auth_data": {
                    "header_name": model_cfg.auth_data.header_name,
                    "header_prefix": model_cfg.auth_data.header_prefix,
                    "space_after_prefix": model_cfg.auth_data.space_after_prefix,
                },
                "apikeys": (
                    [model_cfg.apikey] if model_cfg.apikey else []
                ),
                "default_request_options": model_cfg.default_request_options,
            },
        }
        # The payload is deliberately not printed: it contains the model API key.
        return self._request("POST", "/models/add-model", headers=headers, json=payload)

    def get_model(self, model_id: str) -> ModelConfig:
        """
        Get model configuration by model ID.

        Args:
            model_id (str): ID of the model to retrieve

        Returns:
            ModelConfig: Configuration object containing model details
        """
        headers = {"X-Enkrypt-Model": model_id}
        response = self._request("GET", "/models/get-model", headers=headers)
        return ModelConfig.from_dict(response)

    def get_model_list(self):
        """
        Get a list of all available models.

        Returns:
            dict: Response from the API containing the list of models, or
            {"error": "<message>"} if the request fails in any way.
        """
        try:
            return self._request("GET", "/models/list-models")
        except Exception as e:
            return {"error": str(e)}

    def delete_model(self, model_id: str):
        """
        Delete a specific model from the system.

        Args:
            model_id (str): The identifier or name of the model to delete

        Returns:
            dict: Response from the API containing the deletion status
        """
        headers = {"X-Enkrypt-Model": model_id}
        return self._request("DELETE", "/models/delete-model", headers=headers)
|
enkryptai_sdk/red_team.py
CHANGED
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
import urllib3
|
|
2
|
+
from .dto import (
|
|
3
|
+
RedTeamConfig,
|
|
4
|
+
RedTeamResponse,
|
|
5
|
+
RedTeamResultSummary,
|
|
6
|
+
RedTeamResultDetails,
|
|
7
|
+
RedTeamTaskStatus,
|
|
8
|
+
RedTeamTaskDetails,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RedTeamError(Exception):
    """Raised by ``RedTeamClient`` for red-team operations it cannot perform."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RedTeamClient:
    """
    A client for interacting with the Red Team API.
    """

    def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com"):
        """
        Args:
            api_key: Enkrypt AI API key, sent as the ``apikey`` header.
            base_url: API root. A trailing ``/`` is stripped because endpoint
                paths already start with ``/``.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.http = urllib3.PoolManager()
        self.headers = {"apikey": self.api_key}

    @staticmethod
    def _error_detail(response):
        """Best-effort description of an error response body."""
        if not response.data:
            return {"message": f"HTTP {response.status}"}
        try:
            return response.json()
        except ValueError:
            # A non-JSON error body would otherwise escape as a decode error
            # instead of the usual {"error": ...} result.
            return response.data.decode("utf-8", errors="replace")

    def _request(self, method, endpoint, headers=None, **kwargs):
        """Send a request to ``base_url + endpoint`` and decode the JSON reply.

        Returns:
            The decoded JSON body, or ``{"error": "<message>"}`` (after
            printing the failure) when the status is >= 400 or urllib3 raises
            an ``HTTPError``.
        """
        url = self.base_url + endpoint
        request_headers = {
            "Accept-Encoding": "gzip",  # Add required gzip encoding
            **self.headers,
        }
        if headers:
            request_headers.update(headers)

        try:
            response = self.http.request(method, url, headers=request_headers, **kwargs)
            if response.status >= 400:
                raise urllib3.exceptions.HTTPError(
                    f"HTTP {response.status}: {self._error_detail(response)}"
                )
            return response.json()
        except urllib3.exceptions.HTTPError as e:
            print(f"Request failed: {e}")
            return {"error": str(e)}

    def get_model(self, model):
        """Return ``model`` if it appears in ``/models/list-models``, else ``None``.

        NOTE(review): this is a plain ``in`` test on the decoded response, so
        its meaning depends on the response shape (list vs dict) - confirm.
        """
        models = self._request("GET", "/models/list-models")
        if model in models:
            return model
        return None

    def add_task(
        self,
        config: RedTeamConfig,
    ):
        """
        Add a new red teaming task.

        Args:
            config: A ``RedTeamConfig`` or its dict form (as accepted by
                ``RedTeamConfig.from_dict``).

        Returns:
            For a saved model (``config.model_name`` found by ``get_model``):
            the raw API response dict. Otherwise, a ``RedTeamResponse`` built
            from the ``/redteam/v2/add-task`` response.

        Raises:
            RedTeamError: if the model is not saved and no target model
                configuration is set.
        """
        # The annotation promises a RedTeamConfig; also accept the dict form.
        if not isinstance(config, RedTeamConfig):
            config = RedTeamConfig.from_dict(config)

        # Remove tests left unconfigured (None).
        test_configs = {
            name: test_cfg
            for name, test_cfg in config.redteam_test_configurations.to_dict().items()
            if test_cfg is not None
        }

        payload = {
            "dataset_name": config.dataset_name,
            "test_name": config.test_name,
            "redteam_test_configurations": test_configs,
        }

        saved_model = self.get_model(config.model_name)

        if saved_model:
            headers = {
                "X-Enkrypt-Model": saved_model,
                "Content-Type": "application/json",
            }
            payload["location"] = {"storage": "supabase", "container_name": "supabase"}
            # NOTE(review): this branch returns the raw response dict, while
            # the branch below returns a RedTeamResponse; kept for callers.
            return self._request(
                "POST",
                "/redteam/v2/model/add-task",
                headers=headers,
                json=payload,
            )

        if config.target_model_configuration:
            payload["target_model_configuration"] = (
                config.target_model_configuration.to_dict()
            )
            response = self._request(
                "POST",
                "/redteam/v2/add-task",
                json=payload,
            )
            return RedTeamResponse.from_dict(response)

        raise RedTeamError(
            "Please use a saved model or provide a target model configuration"
        )

    def status(self, task_id: str):
        """
        Get the status of a specific red teaming task.

        Args:
            task_id (str): The ID of the task to check status

        Returns:
            RedTeamTaskStatus: The task status information
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/task-status", headers=headers)
        return RedTeamTaskStatus.from_dict(response)

    def cancel_task(self, task_id: str):
        """
        Cancel a specific red teaming task.

        Args:
            task_id (str): The ID of the task to cancel

        Raises:
            RedTeamError: always; cancellation is not implemented yet.
        """
        raise RedTeamError(
            "This feature is currently under development. Please check our documentation "
            "at https://docs.enkrypt.ai for updates or contact support@enkrypt.ai for assistance."
        )

    def get_task(self, task_id: str):
        """
        Get the status and details of a specific red teaming task.

        Args:
            task_id (str): The ID of the task to retrieve

        Returns:
            RedTeamTaskDetails: The task details and status
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/get-task", headers=headers)
        task_data = response["data"]
        # The API reports the identifier as "job_id"; expose it as task_id.
        # Tolerate its absence instead of failing with a KeyError.
        if "job_id" in task_data:
            task_data["task_id"] = task_data.pop("job_id")
        return RedTeamTaskDetails.from_dict(task_data)

    def get_result_summary(self, task_id: str):
        """
        Get the summary of results for a specific red teaming task.

        Args:
            task_id (str): The ID of the task to get results for

        Returns:
            RedTeamResultSummary: The summary of the task results
        """
        headers = {"X-Enkrypt-Task-ID": task_id}

        response = self._request("GET", "/redteam/results/summary", headers=headers)
        return RedTeamResultSummary.from_dict(response["summary"])

    def get_result_details(self, task_id: str):
        """
        Get the detailed results for a specific red teaming task.

        Args:
            task_id (str): The ID of the task to get detailed results for

        Returns:
            RedTeamResultDetails: The detailed task results, wrapped unchanged
            in its ``details`` field.
        """
        # TODO: Update the response to be updated
        headers = {"X-Enkrypt-Task-ID": task_id}
        response = self._request("GET", "/redteam/results/details", headers=headers)
        # Store the payload in the DTO's single ``details`` field. Splatting it
        # with from_dict would pass its inner keys as constructor kwargs.
        return RedTeamResultDetails(details=response["details"])
|