enkryptai-sdk 1.0.15__tar.gz → 1.0.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {enkryptai_sdk-1.0.15/src/enkryptai_sdk.egg-info → enkryptai_sdk-1.0.17}/PKG-INFO +1 -4
  2. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/README.md +0 -3
  3. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/setup.py +1 -1
  4. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/dto/models.py +17 -22
  5. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/dto/red_team.py +130 -22
  6. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/models.py +39 -89
  7. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/red_team.py +1 -0
  8. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17/src/enkryptai_sdk.egg-info}/PKG-INFO +1 -4
  9. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/LICENSE +0 -0
  10. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/setup.cfg +0 -0
  11. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/__init__.py +0 -0
  12. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/ai_proxy.py +0 -0
  13. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/base.py +0 -0
  14. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/coc.py +0 -0
  15. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/config.py +0 -0
  16. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/datasets.py +0 -0
  17. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/deployments.py +0 -0
  18. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/dto/__init__.py +0 -0
  19. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/dto/ai_proxy.py +0 -0
  20. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/dto/base.py +0 -0
  21. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/dto/coc.py +0 -0
  22. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/dto/datasets.py +0 -0
  23. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/dto/deployments.py +0 -0
  24. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/dto/guardrails.py +0 -0
  25. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/evals.py +0 -0
  26. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/guardrails.py +0 -0
  27. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/guardrails_old.py +0 -0
  28. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk/response.py +0 -0
  29. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk.egg-info/SOURCES.txt +0 -0
  30. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk.egg-info/dependency_links.txt +0 -0
  31. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/src/enkryptai_sdk.egg-info/top_level.txt +0 -0
  32. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_ai_proxy.py +0 -0
  33. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_all.py +0 -0
  34. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_all_v2.py +0 -0
  35. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_basic.py +0 -0
  36. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_coc.py +0 -0
  37. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_datasets.py +0 -0
  38. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_deployments.py +0 -0
  39. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_detect_policy.py +0 -0
  40. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_guardrails.py +0 -0
  41. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_injection_attack.py +0 -0
  42. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_model.py +0 -0
  43. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_openai.py +0 -0
  44. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_policy_violation.py +0 -0
  45. {enkryptai_sdk-1.0.15 → enkryptai_sdk-1.0.17}/tests/test_redteam.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: enkryptai-sdk
3
- Version: 1.0.15
3
+ Version: 1.0.17
4
4
  Summary: A Python SDK with guardrails and red teaming functionality for API interactions
5
5
  Home-page: https://github.com/enkryptai/enkryptai-sdk
6
6
  Author: Enkrypt AI Team
@@ -526,9 +526,6 @@ sample_custom_redteam_model_config = {
526
526
 
527
527
  ```python Python
528
528
  sample_redteam_risk_mitigation_guardrails_policy_config = {
529
- "required_detectors": [
530
- "policy_violation"
531
- ],
532
529
  "redteam_summary": {
533
530
  "category": [
534
531
  {
@@ -503,9 +503,6 @@ sample_custom_redteam_model_config = {
503
503
 
504
504
  ```python Python
505
505
  sample_redteam_risk_mitigation_guardrails_policy_config = {
506
- "required_detectors": [
507
- "policy_violation"
508
- ],
509
506
  "redteam_summary": {
510
507
  "category": [
511
508
  {
@@ -9,7 +9,7 @@ with open(os.path.join(here, "README.md"), encoding="utf-8") as fh:
9
9
  setup(
10
10
  name="enkryptai-sdk", # This is the name of your package on PyPI
11
11
  # NOTE: Also change this in .github/workflows/test.yaml
12
- version="1.0.15", # Update this for new versions
12
+ version="1.0.17", # Update this for new versions
13
13
  description="A Python SDK with guardrails and red teaming functionality for API interactions",
14
14
  long_description=long_description,
15
15
  long_description_content_type="text/markdown",
@@ -5,6 +5,7 @@ from .base import BaseDTO
5
5
  from tabulate import tabulate
6
6
  from dataclasses import dataclass, field, asdict
7
7
  from typing import Optional, List, Set, Dict, Any
8
+ from .red_team import ModelAuthTypeEnum, CustomHeader, ModelJwtConfig
8
9
 
9
10
 
10
11
  # class Modality(Enum):
@@ -40,6 +41,7 @@ class ModelProviders(str, Enum):
40
41
  CUSTOM = "custom"
41
42
  HR = "hr"
42
43
  URL = "url"
44
+ ENKRYPTAI = "enkryptai"
43
45
 
44
46
 
45
47
  @dataclass
@@ -92,10 +94,11 @@ class ModelDetailConfig:
92
94
  # model_provider: str = "openai"
93
95
  model_provider: ModelProviders = ModelProviders.OPENAI
94
96
  system_prompt: str = ""
95
-
96
- endpoint_url: str = "https://api.openai.com/v1/chat/completions"
97
+ endpoint_url: str = ""
97
98
  auth_data: AuthData = field(default_factory=AuthData)
99
+ metadata: Dict[str, Any] = field(default_factory=dict)
98
100
  api_keys: Set[Optional[str]] = field(default_factory=lambda: {None})
101
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
99
102
 
100
103
 
101
104
  @dataclass
@@ -124,25 +127,6 @@ class OutputModality(str, Enum):
124
127
  # code = "code"
125
128
 
126
129
 
127
- @dataclass
128
- class CustomHeader(BaseDTO):
129
- key: str
130
- value: str
131
-
132
- @classmethod
133
- def from_dict(cls, data: Dict[str, Any]) -> "CustomHeader":
134
- return cls(
135
- key=data.get("key", ""),
136
- value=data.get("value", "")
137
- )
138
-
139
- def to_dict(self) -> Dict[str, Any]:
140
- return {
141
- "key": self.key,
142
- "value": self.value
143
- }
144
-
145
-
146
130
  @dataclass
147
131
  class ModelConfigDetails(BaseDTO):
148
132
  model_id: str = None
@@ -159,11 +143,13 @@ class ModelConfigDetails(BaseDTO):
159
143
  headers: str = ""
160
144
  system_prompt: str = ""
161
145
  hosting_type: str = "External"
162
- endpoint_url: str = "https://api.openai.com/v1/chat/completions"
146
+ endpoint_url: str = ""
163
147
  model_name: Optional[str] = ""
164
148
  apikey: Optional[str] = None
165
149
  paths: Optional[PathsConfig] = None
166
150
  tools: List[Dict[str, str]] = field(default_factory=list)
151
+ model_auth_type: Optional[ModelAuthTypeEnum] = ModelAuthTypeEnum.APIKEY
152
+ model_jwt_config: Optional[ModelJwtConfig] = None
167
153
  auth_data: AuthData = field(default_factory=AuthData)
168
154
  input_modalities: List[InputModality] = field(default_factory=list)
169
155
  output_modalities: List[OutputModality] = field(default_factory=list)
@@ -184,6 +170,12 @@ class ModelConfigDetails(BaseDTO):
184
170
  if "custom_headers" in data:
185
171
  data["custom_headers"] = [CustomHeader.from_dict(h) for h in data["custom_headers"]]
186
172
 
173
+ if "model_auth_type" in data:
174
+ data["model_auth_type"] = ModelAuthTypeEnum(data["model_auth_type"])
175
+
176
+ if "model_jwt_config" in data:
177
+ data["model_jwt_config"] = ModelJwtConfig.from_dict(data["model_jwt_config"])
178
+
187
179
  # Convert input_modalities strings to enum values
188
180
  if "input_modalities" in data:
189
181
  data["input_modalities"] = [InputModality(m) for m in data["input_modalities"]]
@@ -249,6 +241,9 @@ class ModelConfigDetails(BaseDTO):
249
241
 
250
242
  def to_dict(self):
251
243
  d = super().to_dict()
244
+ d["model_auth_type"] = self.model_auth_type.value
245
+ if self.model_jwt_config:
246
+ d["model_jwt_config"] = self.model_jwt_config.to_dict()
252
247
  # Handle AuthData specifically
253
248
  d["auth_data"] = self.auth_data.to_dict()
254
249
  # Handle CustomHeader list
@@ -20,6 +20,64 @@ class RiskGuardrailDetectorsEnum(str, Enum):
20
20
  # SYSTEM_PROMPT = "system_prompt"
21
21
 
22
22
 
23
+ class ModelAuthTypeEnum(str, Enum):
24
+ APIKEY = "apikey"
25
+ JWT = "jwt"
26
+
27
+
28
+ class ModelJwtMethodEnum(str, Enum):
29
+ POST = "POST"
30
+ GET = "GET"
31
+
32
+
33
+ @dataclass
34
+ class CustomHeader(BaseDTO):
35
+ key: str
36
+ value: str
37
+
38
+ @classmethod
39
+ def from_dict(cls, data: Dict[str, Any]) -> "CustomHeader":
40
+ return cls(
41
+ key=data.get("key", ""),
42
+ value=data.get("value", "")
43
+ )
44
+
45
+ def to_dict(self) -> Dict[str, Any]:
46
+ return {
47
+ "key": self.key,
48
+ "value": self.value
49
+ }
50
+
51
+
52
+ @dataclass
53
+ class ModelJwtConfig(BaseDTO):
54
+ jwt_method: ModelJwtMethodEnum = ModelJwtMethodEnum.POST
55
+ jwt_url: str = ""
56
+ jwt_headers: List[CustomHeader] = field(default_factory=list)
57
+ jwt_body: str = ""
58
+ jwt_response_key: str = ""
59
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
60
+
61
+ @classmethod
62
+ def from_dict(cls, data: Dict[str, Any]) -> "ModelJwtConfig":
63
+ return cls(
64
+ jwt_method=ModelJwtMethodEnum(data.get("jwt_method", ModelJwtMethodEnum.POST)),
65
+ jwt_url=data.get("jwt_url", ""),
66
+ jwt_headers=[CustomHeader.from_dict(header) for header in data.get("jwt_headers", [])],
67
+ jwt_body=data.get("jwt_body", ""),
68
+ jwt_response_key=data.get("jwt_response_key", ""),
69
+ )
70
+
71
+ def to_dict(self) -> Dict[str, Any]:
72
+ return {
73
+ "jwt_method": self.jwt_method.value,
74
+ "jwt_url": self.jwt_url,
75
+ "jwt_headers": [header.to_dict() for header in self.jwt_headers],
76
+ "jwt_body": self.jwt_body,
77
+ "jwt_response_key": self.jwt_response_key,
78
+ }
79
+
80
+
23
81
  @dataclass
24
82
  class RedteamHealthResponse(BaseDTO):
25
83
  status: str
@@ -62,13 +120,60 @@ class RedTeamTaskStatus(BaseDTO):
62
120
  status: Optional[str] = None
63
121
 
64
122
 
123
+ @dataclass
124
+ class RedTeamTaskDetailsModelConfig(BaseDTO):
125
+ system_prompt: Optional[str] = None
126
+ model_version: Optional[str] = None
127
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
128
+
129
+ @classmethod
130
+ def from_dict(cls, data: Dict) -> "RedTeamTaskDetailsModelConfig":
131
+ return cls(
132
+ system_prompt=data.get("system_prompt"),
133
+ model_version=data.get("model_version"),
134
+ )
135
+
136
+ def to_dict(self) -> Dict:
137
+ return {
138
+ "system_prompt": self.system_prompt,
139
+ "model_version": self.model_version,
140
+ }
141
+
142
+
65
143
  @dataclass
66
144
  class RedTeamTaskDetails(BaseDTO):
67
145
  created_at: Optional[str] = None
146
+ model_saved_name: Optional[str] = None
68
147
  model_name: Optional[str] = None
69
148
  status: Optional[str] = None
70
149
  test_name: Optional[str] = None
71
150
  task_id: Optional[str] = None
151
+ model_config: Optional[RedTeamTaskDetailsModelConfig] = None
152
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
153
+
154
+ @classmethod
155
+ def from_dict(cls, data: Dict) -> "RedTeamTaskDetails":
156
+ # print(f"RedTeamTaskDetails data: {data}")
157
+ return cls(
158
+ created_at=data.get("created_at"),
159
+ model_saved_name=data.get("model_saved_name"),
160
+ model_name=data.get("model_name"),
161
+ status=data.get("status"),
162
+ test_name=data.get("test_name"),
163
+ task_id=data.get("task_id"),
164
+ model_config=RedTeamTaskDetailsModelConfig.from_dict(data.get("model_config", {})),
165
+ )
166
+
167
+ def to_dict(self) -> Dict:
168
+ return {
169
+ "created_at": self.created_at,
170
+ "model_saved_name": self.model_saved_name,
171
+ "model_name": self.model_name,
172
+ "status": self.status,
173
+ "test_name": self.test_name,
174
+ "task_id": self.task_id,
175
+ "model_config": self.model_config.to_dict(),
176
+ }
72
177
 
73
178
 
74
179
  @dataclass
@@ -322,8 +427,21 @@ class RedTeamTestConfigurations(BaseDTO):
322
427
  # Advanced tests
323
428
  adv_info_test: TestConfig = field(default=None)
324
429
  adv_bias_test: TestConfig = field(default=None)
430
+ adv_tool_test: TestConfig = field(default=None)
325
431
  adv_command_test: TestConfig = field(default=None)
432
+ adv_pii_test: TestConfig = field(default=None)
433
+ adv_competitor_test: TestConfig = field(default=None)
434
+ # Custom tests
326
435
  custom_test: TestConfig = field(default=None)
436
+ # Agents tests
437
+ alignment_and_governance_test: TestConfig = field(default=None)
438
+ input_and_content_integrity_test: TestConfig = field(default=None)
439
+ infrastructure_and_integration_test: TestConfig = field(default=None)
440
+ security_and_privacy_test: TestConfig = field(default=None)
441
+ human_factors_and_societal_impact_test: TestConfig = field(default=None)
442
+ access_control_test: TestConfig = field(default=None)
443
+ physical_and_actuation_safety_test: TestConfig = field(default=None)
444
+ reliability_and_monitoring_test: TestConfig = field(default=None)
327
445
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
328
446
 
329
447
  @classmethod
@@ -345,25 +463,6 @@ class OutputModality(str, Enum):
345
463
  # audio = "audio"
346
464
  # video = "video"
347
465
  # code = "code"
348
-
349
-
350
- @dataclass
351
- class CustomHeader(BaseDTO):
352
- key: str
353
- value: str
354
-
355
- @classmethod
356
- def from_dict(cls, data: Dict[str, Any]) -> "CustomHeader":
357
- return cls(
358
- key=data.get("key", ""),
359
- value=data.get("value", "")
360
- )
361
-
362
- def to_dict(self) -> Dict[str, Any]:
363
- return {
364
- "key": self.key,
365
- "value": self.value
366
- }
367
466
 
368
467
 
369
468
  @dataclass
@@ -376,6 +475,8 @@ class TargetModelConfiguration(BaseDTO):
376
475
  rate_per_min: int = 20
377
476
  model_name: Optional[str] = "gpt-4o-mini"
378
477
  model_version: Optional[str] = None
478
+ model_auth_type: Optional[ModelAuthTypeEnum] = ModelAuthTypeEnum.APIKEY
479
+ model_jwt_config: Optional[ModelJwtConfig] = None
379
480
  model_api_key: Optional[str] = None
380
481
  input_modalities: List[InputModality] = field(default_factory=list)
381
482
  output_modalities: List[OutputModality] = field(default_factory=list)
@@ -391,10 +492,17 @@ class TargetModelConfiguration(BaseDTO):
391
492
  data = data.copy()
392
493
  if "custom_headers" in data:
393
494
  data["custom_headers"] = [CustomHeader.from_dict(header) for header in data["custom_headers"]]
495
+ if "model_auth_type" in data:
496
+ data["model_auth_type"] = ModelAuthTypeEnum(data["model_auth_type"])
497
+ if "model_jwt_config" in data:
498
+ data["model_jwt_config"] = ModelJwtConfig.from_dict(data["model_jwt_config"])
394
499
  return cls(**data)
395
500
 
396
501
  def to_dict(self) -> dict:
397
502
  d = asdict(self)
503
+ d["model_auth_type"] = self.model_auth_type.value
504
+ if self.model_jwt_config:
505
+ d["model_jwt_config"] = self.model_jwt_config.to_dict()
398
506
  d["custom_headers"] = [header.to_dict() for header in self.custom_headers]
399
507
  return d
400
508
 
@@ -602,7 +710,7 @@ class RedTeamTaskList(BaseDTO):
602
710
 
603
711
  @dataclass
604
712
  class RedTeamRiskMitigationGuardrailsPolicyConfig(BaseDTO):
605
- required_detectors: List[RiskGuardrailDetectorsEnum] = field(default_factory=list)
713
+ # required_detectors: List[RiskGuardrailDetectorsEnum] = field(default_factory=list)
606
714
  redteam_summary: ResultSummary = field(default_factory=ResultSummary)
607
715
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
608
716
 
@@ -611,14 +719,14 @@ class RedTeamRiskMitigationGuardrailsPolicyConfig(BaseDTO):
611
719
  data = data.copy()
612
720
  summary = ResultSummary.from_dict(data.pop("redteam_summary", {}))
613
721
  return cls(
614
- required_detectors=[RiskGuardrailDetectorsEnum(detector) for detector in data.get("required_detectors", [])],
722
+ # required_detectors=[RiskGuardrailDetectorsEnum(detector) for detector in data.get("required_detectors", [])],
615
723
  redteam_summary=summary,
616
724
  _extra_fields=data,
617
725
  )
618
726
 
619
727
  def to_dict(self) -> dict:
620
728
  return {
621
- "required_detectors": [detector.value for detector in self.required_detectors],
729
+ # "required_detectors": [detector.value for detector in self.required_detectors],
622
730
  "redteam_summary": self.redteam_summary.to_dict(),
623
731
  }
624
732
 
@@ -29,17 +29,29 @@ class ModelClient(BaseClient):
29
29
  if isinstance(config, dict):
30
30
  config = ModelConfig.from_dict(config)
31
31
 
32
- # Parse endpoint_url into components
33
- parsed_url = urlparse(config.model_config.endpoint_url)
34
- path_parts = parsed_url.path.strip("/").split("/")
35
-
36
- # Extract base_path and endpoint path
37
- if len(path_parts) >= 1:
38
- base_path = path_parts[0] # Usually 'v1'
39
- remaining_path = "/".join(path_parts[1:]) # The rest of the path
40
- else:
41
- base_path = ""
42
- remaining_path = ""
32
+ endpoint_data = {}
33
+ base_path = ""
34
+ remaining_path = ""
35
+
36
+ if config.model_config.endpoint_url:
37
+ # Parse endpoint_url into components
38
+ parsed_url = urlparse(config.model_config.endpoint_url)
39
+ path_parts = parsed_url.path.strip("/").split("/")
40
+
41
+ # Extract base_path and endpoint path
42
+ if len(path_parts) >= 1:
43
+ base_path = path_parts[0] # Usually 'v1'
44
+ remaining_path = "/".join(path_parts[1:]) # The rest of the path
45
+ else:
46
+ base_path = ""
47
+ remaining_path = ""
48
+
49
+ endpoint_data = {
50
+ "scheme": parsed_url.scheme,
51
+ "host": parsed_url.hostname,
52
+ "port": parsed_url.port or (443 if parsed_url.scheme == "https" else 80),
53
+ "base_path": f"/{base_path}",
54
+ }
43
55
 
44
56
  if config.model_config.paths:
45
57
  paths = config.model_config.paths.to_dict()
@@ -61,18 +73,15 @@ class ModelClient(BaseClient):
61
73
  "hosting_type": config.model_config.hosting_type,
62
74
  "model_source": config.model_config.model_source,
63
75
  "system_prompt": config.model_config.system_prompt,
64
- "endpoint": {
65
- "scheme": parsed_url.scheme,
66
- "host": parsed_url.hostname,
67
- "port": parsed_url.port or (443 if parsed_url.scheme == "https" else 80),
68
- "base_path": f"/{base_path}",
69
- },
76
+ "endpoint": endpoint_data,
70
77
  "paths": paths,
71
78
  "auth_data": {
72
79
  "header_name": config.model_config.auth_data.header_name,
73
80
  "header_prefix": config.model_config.auth_data.header_prefix,
74
81
  "space_after_prefix": config.model_config.auth_data.space_after_prefix,
75
82
  },
83
+ "model_auth_type": config.model_config.model_auth_type,
84
+ "model_jwt_config": config.model_config.model_jwt_config,
76
85
  "apikeys": [config.model_config.apikey] if config.model_config.apikey else [],
77
86
  "tools": config.model_config.tools,
78
87
  "input_modalities": [m.value if hasattr(m, 'value') else m for m in config.model_config.input_modalities],
@@ -149,92 +158,33 @@ class ModelClient(BaseClient):
149
158
  except Exception as e:
150
159
  return {"error": str(e)}
151
160
 
152
- def modify_model(self, config: ModelConfig, old_model_saved_name=None, old_model_version=None) -> ModelResponse:
161
+ def modify_model(self, config: ModelConfig | dict, old_model_saved_name=None, old_model_version=None) -> ModelResponse:
153
162
  """
154
163
  Modify an existing model in the system.
155
164
 
156
165
  Args:
157
- old_model_saved_name (str): The old saved name of the model to modify
158
- old_model_version (str): The old version of the model to modify
159
- config (ModelConfig): Configuration object containing model details
166
+ config (Union[ModelConfig, dict]): Configuration object or dictionary containing model details
167
+ old_model_saved_name (str, optional): The old saved name of the model to modify. Defaults to None.
168
+ old_model_version (str, optional): The old version of the model to modify. Defaults to None.
160
169
 
161
170
  Returns:
162
171
  dict: Response from the API containing the modified model details
163
172
  """
173
+
174
+ temp_config = config
175
+ if isinstance(temp_config, dict):
176
+ temp_config = ModelConfig.from_dict(temp_config)
177
+
164
178
  if old_model_saved_name is None:
165
- old_model_saved_name = config["model_saved_name"]
179
+ old_model_saved_name = temp_config.model_saved_name
166
180
 
167
181
  if old_model_version is None:
168
- old_model_version = config["model_version"]
182
+ old_model_version = temp_config.model_version
169
183
 
170
184
  headers = {"Content-Type": "application/json", "X-Enkrypt-Model": old_model_saved_name, "X-Enkrypt-Model-Version": old_model_version}
171
- # print(config)
172
- config = ModelConfig.from_dict(config)
173
- # Parse endpoint_url into components
174
- parsed_url = urlparse(config.model_config.endpoint_url)
175
- path_parts = parsed_url.path.strip("/").split("/")
176
-
177
- # Extract base_path and endpoint path
178
- if len(path_parts) >= 1:
179
- base_path = path_parts[0] # Usually 'v1'
180
- remaining_path = "/".join(path_parts[1:]) # The rest of the path
181
- else:
182
- base_path = ""
183
- remaining_path = ""
184
-
185
- if config.model_config.paths:
186
- paths = config.model_config.paths.to_dict()
187
- else:
188
- # Determine paths based on the endpoint
189
- paths = {
190
- "completions": (
191
- f"/{remaining_path.split('/')[-1]}" if remaining_path else ""
192
- ),
193
- "chat": f"/{remaining_path}" if remaining_path else "",
194
- }
195
-
196
- # Convert custom_headers to list of dictionaries
197
- custom_headers = [header.to_dict() for header in config.model_config.custom_headers]
185
+
186
+ payload = self.prepare_model_payload(temp_config)
198
187
 
199
- payload = {
200
- "model_saved_name": config.model_saved_name,
201
- "model_version": config.model_version,
202
- "testing_for": config.testing_for,
203
- "model_name": config.model_name,
204
- "certifications": config.certifications,
205
- "model_config": {
206
- "model_provider": config.model_config.model_provider,
207
- "hosting_type": config.model_config.hosting_type,
208
- "model_source": config.model_config.model_source,
209
- "system_prompt": config.model_config.system_prompt,
210
- "endpoint": {
211
- "scheme": parsed_url.scheme,
212
- "host": parsed_url.hostname,
213
- "port": parsed_url.port
214
- or (443 if parsed_url.scheme == "https" else 80),
215
- "base_path": f"/{base_path}", # Just v1
216
- },
217
- "paths": paths,
218
- "auth_data": {
219
- "header_name": config.model_config.auth_data.header_name,
220
- "header_prefix": config.model_config.auth_data.header_prefix,
221
- "space_after_prefix": config.model_config.auth_data.space_after_prefix,
222
- },
223
- "apikeys": (
224
- [config.model_config.apikey] if config.model_config.apikey else []
225
- ),
226
- "tools": config.model_config.tools,
227
- "input_modalities": [m.value if hasattr(m, 'value') else m for m in config.model_config.input_modalities],
228
- "output_modalities": [m.value if hasattr(m, 'value') else m for m in config.model_config.output_modalities],
229
- "custom_curl_command": config.model_config.custom_curl_command,
230
- "custom_headers": custom_headers,
231
- "custom_payload": config.model_config.custom_payload,
232
- "custom_response_content_type": config.model_config.custom_response_content_type,
233
- "custom_response_format": config.model_config.custom_response_format,
234
- "metadata": config.model_config.metadata,
235
- "default_request_options": config.model_config.default_request_options,
236
- },
237
- }
238
188
  try:
239
189
  response = self._request(
240
190
  "PATCH", "/models/modify-model", headers=headers, json=payload
@@ -367,6 +367,7 @@ class RedTeamClient(BaseClient):
367
367
  response = self._request("GET", "/redteam/get-task", headers=headers)
368
368
  if response.get("error"):
369
369
  raise RedTeamClientError(f"API Error: {str(response)}")
370
+ # print(f"RedTeamTaskDetails response: {response}")
370
371
  return RedTeamTaskDetails.from_dict(response["data"])
371
372
 
372
373
  def get_result_summary(self, task_id: str = None, test_name: str = None):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: enkryptai-sdk
3
- Version: 1.0.15
3
+ Version: 1.0.17
4
4
  Summary: A Python SDK with guardrails and red teaming functionality for API interactions
5
5
  Home-page: https://github.com/enkryptai/enkryptai-sdk
6
6
  Author: Enkrypt AI Team
@@ -526,9 +526,6 @@ sample_custom_redteam_model_config = {
526
526
 
527
527
  ```python Python
528
528
  sample_redteam_risk_mitigation_guardrails_policy_config = {
529
- "required_detectors": [
530
- "policy_violation"
531
- ],
532
529
  "redteam_summary": {
533
530
  "category": [
534
531
  {
File without changes
File without changes