enkryptai-sdk 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,17 +21,20 @@ class Tool(BaseDTO):
21
21
  "description": self.description
22
22
  }
23
23
 
24
+
24
25
  @dataclass
25
26
  class DatasetConfig(BaseDTO):
26
- dataset_name: str
27
- system_description: str
28
- policy_description: str = ""
29
- tools: List[Tool] = field(default_factory=list)
30
- info_pdf_url: str = ""
27
+ system_description: str = None
28
+ dataset_name: str = None
29
+ policy_description: str = None
30
+ risk_categories: str = None
31
+ tools: List[Tool] = None
32
+ info_pdf_url: str = None
31
33
  max_prompts: int = 100
32
34
  scenarios: int = 2
33
35
  categories: int = 2
34
36
  depth: int = 2
37
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
35
38
 
36
39
 
37
40
  @dataclass
@@ -51,6 +54,7 @@ class DatasetDataPoint(BaseDTO):
51
54
  category: str
52
55
  prompt: str
53
56
  source: str
57
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
54
58
 
55
59
 
56
60
  @dataclass
@@ -1,8 +1,7 @@
1
1
  from enum import Enum
2
2
  from .base import BaseDTO
3
3
  from dataclasses import dataclass, field
4
- from typing import List, Dict, Any, Set
5
-
4
+ from typing import List, Dict, Any, Set, Optional, BinaryIO
6
5
 
7
6
  class GuardrailsPIIModes(str, Enum):
8
7
  REQUEST = "request"
@@ -1334,13 +1333,13 @@ class GuardrailsPolicyData(BaseDTO):
1334
1333
 
1335
1334
 
1336
1335
  @dataclass
1337
- class GuardrailsaPolicyResponse(BaseDTO):
1336
+ class GuardrailsPolicyResponse(BaseDTO):
1338
1337
  message: str
1339
1338
  data: GuardrailsPolicyData
1340
1339
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
1341
1340
 
1342
1341
  @classmethod
1343
- def from_dict(cls, data: Dict[str, Any]) -> "GuardrailsaPolicyResponse":
1342
+ def from_dict(cls, data: Dict[str, Any]) -> "GuardrailsPolicyResponse":
1344
1343
  policy_data = data.get("data", {})
1345
1344
  return cls(
1346
1345
  message=data.get("message", ""),
@@ -1454,3 +1453,109 @@ class GuardrailsListPoliciesResponse(BaseDTO):
1454
1453
  result.update(self._extra_fields)
1455
1454
  return result
1456
1455
 
1456
+
1457
+ # -------------------------------------
1458
+ # Guardrails Policy Atomizer
1459
+ # -------------------------------------
1460
+
1461
+ @dataclass
1462
+ class GuardrailsPolicyAtomizerRequest(BaseDTO):
1463
+ text: Optional[str] = None
1464
+ file: Optional[BinaryIO] = None
1465
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
1466
+
1467
+ @classmethod
1468
+ def from_dict(cls, data: Dict[str, Any]) -> "GuardrailsPolicyAtomizerRequest":
1469
+ return cls(
1470
+ file=data.get("file", None),
1471
+ text=data.get("text", None)
1472
+ )
1473
+
1474
+ def to_dict(self) -> Dict[str, Any]:
1475
+ result = {}
1476
+ if self.file:
1477
+ result["file"] = self.file
1478
+ if self.text:
1479
+ result["text"] = self.text
1480
+ result.update(self._extra_fields)
1481
+ return result
1482
+
1483
+ def validate(self) -> bool:
1484
+ """
1485
+ Validate that either file or text is provided, but not both.
1486
+
1487
+ Returns:
1488
+ bool: True if valid, False otherwise
1489
+ """
1490
+ return bool(self.file) != bool(self.text) # XOR - only one should be True
1491
+
1492
+
1493
+ @dataclass
1494
+ class GuardrailsPolicyAtomizerResponse(BaseDTO):
1495
+ status: str = ""
1496
+ message: str = ""
1497
+ source: str = ""
1498
+ filename: str = ""
1499
+ total_rules: int = 0
1500
+ policy_rules: str = ""
1501
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
1502
+
1503
+ @classmethod
1504
+ def from_dict(cls, data: Dict[str, Any]) -> "GuardrailsPolicyAtomizerResponse":
1505
+ return cls(
1506
+ status=data.get("status", "success"),
1507
+ message=data.get("message", ""),
1508
+ source=data.get("source", ""),
1509
+ filename=data.get("filename", ""),
1510
+ total_rules=data.get("total_rules", 0),
1511
+ policy_rules=data.get("policy_rules", "")
1512
+ )
1513
+
1514
+ def to_dict(self) -> Dict[str, Any]:
1515
+ result = {
1516
+ "status": self.status,
1517
+ "message": self.message,
1518
+ "source": self.source,
1519
+ "filename": self.filename,
1520
+ "total_rules": self.total_rules,
1521
+ "policy_rules": self.policy_rules
1522
+ }
1523
+ result.update(self._extra_fields)
1524
+ return result
1525
+
1526
+ def is_successful(self) -> bool:
1527
+ """
1528
+ Check if the atomization was successful.
1529
+
1530
+ Returns:
1531
+ bool: True if status is "success", False otherwise
1532
+ """
1533
+ return self.status == "success"
1534
+
1535
+ def get_rules_list(self) -> List[str]:
1536
+ """
1537
+ Get the policy rules as a list of strings.
1538
+
1539
+ Returns:
1540
+ List[str]: List of individual policy rules
1541
+ """
1542
+ if not self.policy_rules:
1543
+ return []
1544
+ return [rule.strip() for rule in self.policy_rules.split('\n') if rule.strip()]
1545
+
1546
+ def __str__(self) -> str:
1547
+ """
1548
+ String representation of the response.
1549
+
1550
+ Returns:
1551
+ str: A formatted string showing the atomization results
1552
+ """
1553
+ source_info = f"File: {self.filename}" if self.source == "upload" else "Source: Text input"
1554
+ return (
1555
+ f"Policy Atomizer Response:\n"
1556
+ f"Status: {self.status}\n"
1557
+ f"{source_info}\n"
1558
+ f"Total Rules: {self.total_rules}\n"
1559
+ f"Message: {self.message}"
1560
+ )
1561
+
@@ -142,8 +142,8 @@ class CustomHeader(BaseDTO):
142
142
 
143
143
  @dataclass
144
144
  class ModelConfigDetails(BaseDTO):
145
- model_id: str = ""
146
- model_source: str = ""
145
+ model_id: str = None
146
+ model_source: str = None
147
147
  # model_provider: str = "openai"
148
148
  model_provider: ModelProviders = ModelProviders.OPENAI
149
149
  model_api_value: str = ""
@@ -259,14 +259,14 @@ class ModelConfigDetails(BaseDTO):
259
259
 
260
260
  @dataclass
261
261
  class ModelConfig(BaseDTO):
262
- created_at: str = ""
263
- updated_at: str = ""
264
- model_id: str = ""
265
- model_saved_name: str = "Model Name"
266
- model_version: str = "v1"
262
+ created_at: str = None
263
+ updated_at: str = None
264
+ model_id: str = None
265
+ model_saved_name: str = None
266
+ model_version: str = None
267
267
  testing_for: str = "foundationModels"
268
268
  # modality: Modality = Modality.TEXT
269
- project_name: str = ""
269
+ project_name: str = None
270
270
  model_name: Optional[str] = "gpt-4o-mini"
271
271
  certifications: List[str] = field(default_factory=list)
272
272
  model_config: ModelConfigDetails = field(default_factory=ModelConfigDetails)
@@ -3,6 +3,8 @@ from enum import Enum
3
3
  from .base import BaseDTO
4
4
  from typing import Dict, List, Optional, Any
5
5
  from dataclasses import dataclass, field, asdict
6
+ from .datasets import DatasetConfig
7
+ from .models import ModelConfig
6
8
 
7
9
 
8
10
  @dataclass
@@ -24,6 +26,7 @@ class RedteamHealthResponse(BaseDTO):
24
26
 
25
27
  @dataclass
26
28
  class RedTeamResponse(BaseDTO):
29
+ status: Optional[str] = None
27
30
  task_id: Optional[str] = None
28
31
  message: Optional[str] = None
29
32
  data: Optional[Dict[str, Any]] = None
@@ -300,7 +303,7 @@ class RedTeamTestConfigurations(BaseDTO):
300
303
  adv_info_test: TestConfig = field(default=None)
301
304
  adv_bias_test: TestConfig = field(default=None)
302
305
  adv_command_test: TestConfig = field(default=None)
303
- # custom_test: TestConfig = field(default=None)
306
+ custom_test: TestConfig = field(default=None)
304
307
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
305
308
 
306
309
  @classmethod
@@ -482,6 +485,83 @@ class RedTeamConfigWithSavedModel(BaseDTO):
482
485
  )
483
486
 
484
487
 
488
+ @dataclass
489
+ class RedTeamCustomConfig(BaseDTO):
490
+ test_name: str = "Test Name"
491
+
492
+ redteam_test_configurations: RedTeamTestConfigurations = field(
493
+ default_factory=RedTeamTestConfigurations
494
+ )
495
+ dataset_configuration: DatasetConfig = field(
496
+ default_factory=DatasetConfig
497
+ )
498
+ endpoint_configuration: ModelConfig = field(
499
+ default_factory=ModelConfig
500
+ )
501
+
502
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
503
+
504
+ def to_dict(self) -> dict:
505
+ d = asdict(self)
506
+ d["redteam_test_configurations"] = self.redteam_test_configurations.to_dict()
507
+ d["dataset_configuration"] = self.dataset_configuration.to_dict()
508
+ d["endpoint_configuration"] = self.endpoint_configuration.to_dict()
509
+ return d
510
+
511
+ @classmethod
512
+ def from_dict(cls, data: dict):
513
+ data = data.copy()
514
+ test_configs = RedTeamTestConfigurations.from_dict(
515
+ data.pop("redteam_test_configurations", {})
516
+ )
517
+ dataset_config = DatasetConfig.from_dict(
518
+ data.pop("dataset_configuration", {})
519
+ )
520
+ endpoint_config = ModelConfig.from_dict(
521
+ data.pop("endpoint_configuration", {})
522
+ )
523
+ return cls(
524
+ **data,
525
+ redteam_test_configurations=test_configs,
526
+ dataset_configuration=dataset_config,
527
+ endpoint_configuration=endpoint_config,
528
+ )
529
+
530
+ @dataclass
531
+ class RedTeamCustomConfigWithSavedModel(BaseDTO):
532
+ test_name: str = "Test Name"
533
+
534
+ redteam_test_configurations: RedTeamTestConfigurations = field(
535
+ default_factory=RedTeamTestConfigurations
536
+ )
537
+ dataset_configuration: DatasetConfig = field(
538
+ default_factory=DatasetConfig
539
+ )
540
+
541
+ _extra_fields: Dict[str, Any] = field(default_factory=dict)
542
+
543
+ def to_dict(self) -> dict:
544
+ d = asdict(self)
545
+ d["redteam_test_configurations"] = self.redteam_test_configurations.to_dict()
546
+ d["dataset_configuration"] = self.dataset_configuration.to_dict()
547
+ return d
548
+
549
+ @classmethod
550
+ def from_dict(cls, data: dict):
551
+ data = data.copy()
552
+ test_configs = RedTeamTestConfigurations.from_dict(
553
+ data.pop("redteam_test_configurations", {})
554
+ )
555
+ dataset_config = DatasetConfig.from_dict(
556
+ data.pop("dataset_configuration", {})
557
+ )
558
+ return cls(
559
+ **data,
560
+ redteam_test_configurations=test_configs,
561
+ dataset_configuration=dataset_config,
562
+ )
563
+
564
+
485
565
  @dataclass
486
566
  class RedTeamTaskList(BaseDTO):
487
567
  tasks: List[RedTeamTaskDetails] = field(default_factory=list)
@@ -494,3 +574,6 @@ class RedTeamTaskList(BaseDTO):
494
574
  # Default configurations
495
575
  DEFAULT_REDTEAM_CONFIG = RedTeamConfig()
496
576
  DEFAULT_REDTEAM_CONFIG_WITH_SAVED_MODEL = RedTeamConfigWithSavedModel()
577
+
578
+ DEFAULT_CUSTOM_REDTEAM_CONFIG = RedTeamCustomConfig()
579
+ DEFAULT_CUSTOM_REDTEAM_CONFIG_WITH_SAVED_MODEL = RedTeamCustomConfigWithSavedModel()
@@ -1,3 +1,4 @@
1
+ import os
1
2
  # import requests
2
3
  from .base import BaseClient
3
4
  from .config import GuardrailsConfig
@@ -24,10 +25,12 @@ from .dto import (
24
25
  GuardrailsRelevancyResponse,
25
26
  # GuardrailsPolicyRequest,
26
27
  GuardrailsPolicyData,
27
- GuardrailsaPolicyResponse,
28
+ GuardrailsPolicyResponse,
28
29
  # GuardrailsDeletePolicyData,
29
30
  GuardrailsDeletePolicyResponse,
30
31
  GuardrailsListPoliciesResponse,
32
+ GuardrailsPolicyAtomizerRequest,
33
+ GuardrailsPolicyAtomizerResponse,
31
34
  )
32
35
 
33
36
  # ---------------------------------------
@@ -92,7 +95,7 @@ class GuardrailsClient(BaseClient):
92
95
  try:
93
96
  response = self._request("GET", "/guardrails/health")
94
97
  if response.get("error"):
95
- raise GuardrailsClientError(response["error"])
98
+ raise GuardrailsClientError(f"API Error: {str(response)}")
96
99
  return GuardrailsHealthResponse.from_dict(response)
97
100
  except Exception as e:
98
101
  raise GuardrailsClientError(str(e))
@@ -104,7 +107,7 @@ class GuardrailsClient(BaseClient):
104
107
  try:
105
108
  response = self._request("GET", "/guardrails/status")
106
109
  if response.get("error"):
107
- raise GuardrailsClientError(response["error"])
110
+ raise GuardrailsClientError(f"API Error: {str(response)}")
108
111
  return GuardrailsHealthResponse.from_dict(response)
109
112
  except Exception as e:
110
113
  raise GuardrailsClientError(str(e))
@@ -116,7 +119,7 @@ class GuardrailsClient(BaseClient):
116
119
  try:
117
120
  response = self._request("GET", "/guardrails/models")
118
121
  if response.get("error"):
119
- raise GuardrailsClientError(response["error"])
122
+ raise GuardrailsClientError(f"API Error: {str(response)}")
120
123
  return GuardrailsModelsResponse.from_dict(response)
121
124
  except Exception as e:
122
125
  raise GuardrailsClientError(str(e))
@@ -152,7 +155,7 @@ class GuardrailsClient(BaseClient):
152
155
  try:
153
156
  response = self._request("POST", "/guardrails/detect", json=payload)
154
157
  if response.get("error"):
155
- raise GuardrailsClientError(response["error"])
158
+ raise GuardrailsClientError(f"API Error: {str(response)}")
156
159
  return GuardrailsDetectResponse.from_dict(response)
157
160
  except Exception as e:
158
161
  raise GuardrailsClientError(str(e))
@@ -188,7 +191,7 @@ class GuardrailsClient(BaseClient):
188
191
  try:
189
192
  response = self._request("POST", "/guardrails/batch/detect", json=payload)
190
193
  if isinstance(response, dict) and response.get("error"):
191
- raise GuardrailsClientError(response["error"])
194
+ raise GuardrailsClientError(f"API Error: {str(response)}")
192
195
  return GuardrailsBatchDetectResponse.from_dict(response)
193
196
  except Exception as e:
194
197
  raise GuardrailsClientError(str(e))
@@ -207,7 +210,7 @@ class GuardrailsClient(BaseClient):
207
210
  try:
208
211
  response = self._request("POST", "/guardrails/pii", json=payload)
209
212
  if response.get("error"):
210
- raise GuardrailsClientError(response["error"])
213
+ raise GuardrailsClientError(f"API Error: {str(response)}")
211
214
  return GuardrailsPIIResponse.from_dict(response)
212
215
  except Exception as e:
213
216
  raise GuardrailsClientError(str(e))
@@ -225,7 +228,7 @@ class GuardrailsClient(BaseClient):
225
228
  try:
226
229
  response = self._request("POST", "/guardrails/hallucination", json=payload)
227
230
  if response.get("error"):
228
- raise GuardrailsClientError(response["error"])
231
+ raise GuardrailsClientError(f"API Error: {str(response)}")
229
232
  return GuardrailsHallucinationResponse.from_dict(response)
230
233
  except Exception as e:
231
234
  raise GuardrailsClientError(str(e))
@@ -242,7 +245,7 @@ class GuardrailsClient(BaseClient):
242
245
  try:
243
246
  response = self._request("POST", "/guardrails/adherence", json=payload)
244
247
  if response.get("error"):
245
- raise GuardrailsClientError(response["error"])
248
+ raise GuardrailsClientError(f"API Error: {str(response)}")
246
249
  return GuardrailsAdherenceResponse.from_dict(response)
247
250
  except Exception as e:
248
251
  raise GuardrailsClientError(str(e))
@@ -259,7 +262,7 @@ class GuardrailsClient(BaseClient):
259
262
  try:
260
263
  response = self._request("POST", "/guardrails/relevancy", json=payload)
261
264
  if response.get("error"):
262
- raise GuardrailsClientError(response["error"])
265
+ raise GuardrailsClientError(f"API Error: {str(response)}")
263
266
  return GuardrailsRelevancyResponse.from_dict(response)
264
267
  except Exception as e:
265
268
  raise GuardrailsClientError(str(e))
@@ -293,8 +296,8 @@ class GuardrailsClient(BaseClient):
293
296
  try:
294
297
  response = self._request("POST", "/guardrails/add-policy", json=payload)
295
298
  if response.get("error"):
296
- raise GuardrailsClientError(response["error"])
297
- return GuardrailsaPolicyResponse.from_dict(response)
299
+ raise GuardrailsClientError(f"API Error: {str(response)}")
300
+ return GuardrailsPolicyResponse.from_dict(response)
298
301
  except Exception as e:
299
302
  raise GuardrailsClientError(str(e))
300
303
 
@@ -307,7 +310,7 @@ class GuardrailsClient(BaseClient):
307
310
  try:
308
311
  response = self._request("GET", "/guardrails/get-policy", headers=headers)
309
312
  if response.get("error"):
310
- raise GuardrailsClientError(response["error"])
313
+ raise GuardrailsClientError(f"API Error: {str(response)}")
311
314
  return GuardrailsPolicyData.from_dict(response)
312
315
  except Exception as e:
313
316
  raise GuardrailsClientError(str(e))
@@ -335,8 +338,8 @@ class GuardrailsClient(BaseClient):
335
338
  try:
336
339
  response = self._request("PATCH", "/guardrails/modify-policy", headers=headers, json=payload)
337
340
  if response.get("error"):
338
- raise GuardrailsClientError(response["error"])
339
- return GuardrailsaPolicyResponse.from_dict(response)
341
+ raise GuardrailsClientError(f"API Error: {str(response)}")
342
+ return GuardrailsPolicyResponse.from_dict(response)
340
343
  except Exception as e:
341
344
  raise GuardrailsClientError(str(e))
342
345
 
@@ -349,7 +352,7 @@ class GuardrailsClient(BaseClient):
349
352
  try:
350
353
  response = self._request("DELETE", "/guardrails/delete-policy", headers=headers)
351
354
  if response.get("error"):
352
- raise GuardrailsClientError(response["error"])
355
+ raise GuardrailsClientError(f"API Error: {str(response)}")
353
356
  return GuardrailsDeletePolicyResponse.from_dict(response)
354
357
  except Exception as e:
355
358
  raise GuardrailsClientError(str(e))
@@ -364,7 +367,7 @@ class GuardrailsClient(BaseClient):
364
367
  try:
365
368
  response = self._request("POST", "/guardrails/policy/detect", headers=headers, json=payload)
366
369
  if response.get("error"):
367
- raise GuardrailsClientError(response["error"])
370
+ raise GuardrailsClientError(f"API Error: {str(response)}")
368
371
  return GuardrailsDetectResponse.from_dict(response)
369
372
  except Exception as e:
370
373
  raise GuardrailsClientError(str(e))
@@ -377,7 +380,69 @@ class GuardrailsClient(BaseClient):
377
380
  try:
378
381
  response = self._request("GET", "/guardrails/list-policies")
379
382
  if isinstance(response, dict) and response.get("error"):
380
- raise GuardrailsClientError(response["error"])
383
+ raise GuardrailsClientError(f"API Error: {str(response)}")
381
384
  return GuardrailsListPoliciesResponse.from_dict(response)
382
385
  except Exception as e:
383
386
  raise GuardrailsClientError(str(e))
387
+
388
+ def atomize_policy(self, file=None, text=None):
389
+ """
390
+ Atomize a policy from either a file or text input.
391
+
392
+ Parameters:
393
+ - file (str, optional): Path to the policy file
394
+ - text (str, optional): Policy text content
395
+
396
+ Returns:
397
+ - GuardrailsPolicyAtomizerResponse
398
+
399
+ Raises:
400
+ - GuardrailsClientError: If validation fails or API returns an error
401
+ """
402
+ try:
403
+ # Create and validate request
404
+ request = GuardrailsPolicyAtomizerRequest(file=file, text=text)
405
+ if not request.validate():
406
+ raise GuardrailsClientError("Invalid request: Must provide either file or text. Not both.")
407
+
408
+ # Prepare the request based on input type
409
+ if file:
410
+ # Normalize file path and check existence
411
+ file_path = os.path.abspath(file)
412
+ file_name = os.path.basename(file_path)
413
+ print(f"File name: {file_name}")
414
+ print(f"Reading file: {file_path}")
415
+
416
+ if not os.path.exists(file_path):
417
+ raise GuardrailsClientError(f"File not found: {file_path}")
418
+
419
+ # Check file extension
420
+ if not file_path.lower().endswith('.pdf'):
421
+ raise GuardrailsClientError("Only PDF files are supported")
422
+
423
+ with open(file_path, 'rb') as f:
424
+ file_content = f.read()
425
+ # Create form data with filename
426
+ form_data = {
427
+ 'file': (file_name, file_content, 'application/pdf')
428
+ }
429
+ response = self._request(
430
+ "POST",
431
+ "/guardrails/policy-atomizer",
432
+ form_data=form_data
433
+ )
434
+ else:
435
+ form_data = {'text': text}
436
+ response = self._request(
437
+ "POST",
438
+ "/guardrails/policy-atomizer",
439
+ form_data=form_data
440
+ )
441
+
442
+ if isinstance(response, dict) and response.get("error"):
443
+ raise GuardrailsClientError(f"API Error: {str(response)}")
444
+
445
+ return GuardrailsPolicyAtomizerResponse.from_dict(response)
446
+ except Exception as e:
447
+ raise GuardrailsClientError(str(e))
448
+
enkryptai_sdk/models.py CHANGED
@@ -15,18 +15,20 @@ class ModelClient(BaseClient):
15
15
  def __init__(self, api_key: str, base_url: str = "https://api.enkryptai.com:443"):
16
16
  super().__init__(api_key, base_url)
17
17
 
18
- def add_model(self, config: ModelConfig) -> ModelResponse:
18
+ @staticmethod
19
+ def prepare_model_payload(config: ModelConfig | dict, is_custom: bool = False) -> dict:
19
20
  """
20
- Add a new model configuration to the system.
21
-
21
+ Prepare the payload for model operations from a config object.
22
+
22
23
  Args:
23
- config (ModelConfig): Configuration object containing model details
24
-
24
+ config (Union[ModelConfig, dict]): Configuration object or dictionary containing model details
25
+
25
26
  Returns:
26
- dict: Response from the API containing the added model details
27
+ dict: Processed payload ready for API submission
27
28
  """
28
- headers = {"Content-Type": "application/json"}
29
- config = ModelConfig.from_dict(config)
29
+ if isinstance(config, dict):
30
+ config = ModelConfig.from_dict(config)
31
+
30
32
  # Parse endpoint_url into components
31
33
  parsed_url = urlparse(config.model_config.endpoint_url)
32
34
  path_parts = parsed_url.path.strip("/").split("/")
@@ -43,15 +45,11 @@ class ModelClient(BaseClient):
43
45
  paths = config.model_config.paths.to_dict()
44
46
  else:
45
47
  paths = {
46
- "completions": (
47
- f"/{remaining_path.split('/')[-1]}" if remaining_path else ""
48
- ),
48
+ "completions": f"/{remaining_path.split('/')[-1]}" if remaining_path else "",
49
49
  "chat": f"/{remaining_path}" if remaining_path else "",
50
50
  }
51
51
 
52
52
  payload = {
53
- "model_saved_name": config.model_saved_name,
54
- "model_version": config.model_version,
55
53
  "testing_for": config.testing_for,
56
54
  "model_name": config.model_name,
57
55
  "certifications": config.certifications,
@@ -63,9 +61,8 @@ class ModelClient(BaseClient):
63
61
  "endpoint": {
64
62
  "scheme": parsed_url.scheme,
65
63
  "host": parsed_url.hostname,
66
- "port": parsed_url.port
67
- or (443 if parsed_url.scheme == "https" else 80),
68
- "base_path": f"/{base_path}", # Just v1
64
+ "port": parsed_url.port or (443 if parsed_url.scheme == "https" else 80),
65
+ "base_path": f"/{base_path}",
69
66
  },
70
67
  "paths": paths,
71
68
  "auth_data": {
@@ -73,9 +70,7 @@ class ModelClient(BaseClient):
73
70
  "header_prefix": config.model_config.auth_data.header_prefix,
74
71
  "space_after_prefix": config.model_config.auth_data.space_after_prefix,
75
72
  },
76
- "apikeys": (
77
- [config.model_config.apikey] if config.model_config.apikey else []
78
- ),
73
+ "apikeys": [config.model_config.apikey] if config.model_config.apikey else [],
79
74
  "tools": config.model_config.tools,
80
75
  "input_modalities": config.model_config.input_modalities,
81
76
  "output_modalities": config.model_config.output_modalities,
@@ -87,12 +82,32 @@ class ModelClient(BaseClient):
87
82
  "default_request_options": config.model_config.default_request_options,
88
83
  },
89
84
  }
85
+
86
+ if not is_custom:
87
+ payload["model_saved_name"] = config.model_saved_name
88
+ payload["model_version"] = config.model_version
89
+
90
+ return payload
91
+
92
+ def add_model(self, config: ModelConfig) -> ModelResponse:
93
+ """
94
+ Add a new model configuration to the system.
95
+
96
+ Args:
97
+ config (ModelConfig): Configuration object containing model details
98
+
99
+ Returns:
100
+ dict: Response from the API containing the added model details
101
+ """
102
+ headers = {"Content-Type": "application/json"}
103
+ payload = self.prepare_model_payload(config)
104
+
90
105
  try:
91
106
  response = self._request(
92
107
  "POST", "/models/add-model", headers=headers, json=payload
93
108
  )
94
109
  if response.get("error"):
95
- raise ModelClientError(response["error"])
110
+ raise ModelClientError(f"API Error: {str(response)}")
96
111
  return ModelResponse.from_dict(response)
97
112
  except Exception as e:
98
113
  raise ModelClientError(str(e))
@@ -112,7 +127,7 @@ class ModelClient(BaseClient):
112
127
  response = self._request("GET", "/models/get-model", headers=headers)
113
128
  # print(response)
114
129
  if response.get("error"):
115
- raise ModelClientError(response["error"])
130
+ raise ModelClientError(f"API Error: {str(response)}")
116
131
  return ModelConfig.from_dict(response)
117
132
 
118
133
  def get_model_list(self):
@@ -125,7 +140,7 @@ class ModelClient(BaseClient):
125
140
  try:
126
141
  response = self._request("GET", "/models/list-models")
127
142
  if isinstance(response, dict) and response.get("error"):
128
- raise ModelClientError(response["error"])
143
+ raise ModelClientError(f"API Error: {str(response)}")
129
144
  return ModelCollection.from_dict(response)
130
145
  except Exception as e:
131
146
  return {"error": str(e)}
@@ -217,7 +232,7 @@ class ModelClient(BaseClient):
217
232
  "PATCH", "/models/modify-model", headers=headers, json=payload
218
233
  )
219
234
  if response.get("error"):
220
- raise ModelClientError(response["error"])
235
+ raise ModelClientError(f"API Error: {str(response)}")
221
236
  return ModelResponse.from_dict(response)
222
237
  except Exception as e:
223
238
  raise ModelClientError(str(e))
@@ -238,7 +253,7 @@ class ModelClient(BaseClient):
238
253
  try:
239
254
  response = self._request("DELETE", "/models/delete-model", headers=headers)
240
255
  if response.get("error"):
241
- raise ModelClientError(response["error"])
256
+ raise ModelClientError(f"API Error: {str(response)}")
242
257
  return ModelResponse.from_dict(response)
243
258
  except Exception as e:
244
259
  raise ModelClientError(str(e))