enkryptai-sdk 1.0.25__py3-none-any.whl → 1.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,6 +9,7 @@ from .models import ModelConfig
9
9
  from .guardrails import GuardrailDetectors
10
10
  from .common import ModelAuthTypeEnum, CustomHeader, ModelJwtConfig
11
11
 
12
+
12
13
  # Risk mitigation does not support all detectors, so we need a separate enum for the ones it does support.
13
14
  class RiskGuardrailDetectorsEnum(str, Enum):
14
15
  NSFW = "nsfw"
@@ -27,17 +28,13 @@ class RiskGuardrailDetectorsEnum(str, Enum):
27
28
  class RedteamHealthResponse(BaseDTO):
28
29
  status: str
29
30
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
30
-
31
+
31
32
  @classmethod
32
33
  def from_dict(cls, data: Dict[str, Any]) -> "RedteamHealthResponse":
33
- return cls(
34
- status=data.get("status", "")
35
- )
36
-
34
+ return cls(status=data.get("status", ""))
35
+
37
36
  def to_dict(self) -> Dict[str, Any]:
38
- return {
39
- "status": self.status
40
- }
37
+ return {"status": self.status}
41
38
 
42
39
 
43
40
  @dataclass
@@ -55,7 +52,7 @@ class RedTeamResponse(BaseDTO):
55
52
  message=data.get("message"),
56
53
  data=data.get("data"),
57
54
  )
58
-
55
+
59
56
  def to_dict(self) -> Dict:
60
57
  return super().to_dict()
61
58
 
@@ -77,7 +74,7 @@ class RedTeamTaskDetailsModelConfig(BaseDTO):
77
74
  system_prompt=data.get("system_prompt"),
78
75
  model_version=data.get("model_version"),
79
76
  )
80
-
77
+
81
78
  def to_dict(self) -> Dict:
82
79
  return {
83
80
  "system_prompt": self.system_prompt,
@@ -108,9 +105,11 @@ class RedTeamTaskDetails(BaseDTO):
108
105
  status=data.get("status"),
109
106
  test_name=data.get("test_name"),
110
107
  task_id=data.get("task_id"),
111
- model_config=RedTeamTaskDetailsModelConfig.from_dict(data.get("model_config", {})),
108
+ model_config=RedTeamTaskDetailsModelConfig.from_dict(
109
+ data.get("model_config", {})
110
+ ),
112
111
  )
113
-
112
+
114
113
  def to_dict(self) -> Dict:
115
114
  return {
116
115
  "created_at": self.created_at,
@@ -191,8 +190,10 @@ class ResultSummary(BaseDTO):
191
190
  for key, value in item.items():
192
191
  result[key] = StatisticItem.from_dict(value)
193
192
  return result
194
-
195
- def convert_stat_test_type_list(stat_list: List[Dict]) -> Dict[str, StatisticItemWithTestType]:
193
+
194
+ def convert_stat_test_type_list(
195
+ stat_list: List[Dict],
196
+ ) -> Dict[str, StatisticItemWithTestType]:
196
197
  result = {}
197
198
  for item in stat_list:
198
199
  for key, value in item.items():
@@ -221,8 +222,10 @@ class ResultSummary(BaseDTO):
221
222
  def to_dict(self) -> Dict:
222
223
  def convert_stat_dict(stat_dict: Dict[str, StatisticItem]) -> List[Dict]:
223
224
  return [{key: value.to_dict()} for key, value in stat_dict.items()]
224
-
225
- def convert_stat_test_type_dict(stat_dict: Dict[str, StatisticItemWithTestType]) -> List[Dict]:
225
+
226
+ def convert_stat_test_type_dict(
227
+ stat_dict: Dict[str, StatisticItemWithTestType],
228
+ ) -> List[Dict]:
226
229
  return [{key: value.to_dict()} for key, value in stat_dict.items()]
227
230
 
228
231
  d = super().to_dict()
@@ -248,9 +251,11 @@ class RedTeamResultSummary(BaseDTO):
248
251
  def from_dict(cls, data: Dict) -> "RedTeamResultSummary":
249
252
  if not data or "summary" not in data:
250
253
  return cls(summary=ResultSummary.from_dict({}))
251
-
254
+
252
255
  if "task_status" in data:
253
- return cls(summary=ResultSummary.from_dict({}), task_status=data["task_status"])
256
+ return cls(
257
+ summary=ResultSummary.from_dict({}), task_status=data["task_status"]
258
+ )
254
259
 
255
260
  return cls(summary=ResultSummary.from_dict(data["summary"]))
256
261
 
@@ -295,22 +300,22 @@ class RedTeamResultDetails(BaseDTO):
295
300
  def from_dict(cls, data: Dict) -> "RedTeamResultDetails":
296
301
  if not data or "details" not in data:
297
302
  return cls(details=[])
298
-
303
+
299
304
  if "task_status" in data:
300
305
  return cls(details=[], task_status=data["task_status"])
301
306
 
302
307
  # details = []
303
308
  # for result in data["details"]:
304
- # Convert eval_tokens dict to TestEvalTokens object
305
- # eval_tokens = TestEvalTokens(**result["eval_tokens"])
309
+ # Convert eval_tokens dict to TestEvalTokens object
310
+ # eval_tokens = TestEvalTokens(**result["eval_tokens"])
306
311
 
307
- # Create a copy of the result dict and replace eval_tokens
308
- # result_copy = dict(result["details"])
309
- # result_copy["eval_tokens"] = eval_tokens
312
+ # Create a copy of the result dict and replace eval_tokens
313
+ # result_copy = dict(result["details"])
314
+ # result_copy["eval_tokens"] = eval_tokens
310
315
 
311
- # Create TestResult object
312
- # test_result = TestResult(**result_copy)
313
- # details.append(test_result)
316
+ # Create TestResult object
317
+ # test_result = TestResult(**result_copy)
318
+ # details.append(test_result)
314
319
 
315
320
  return cls(details=data["details"])
316
321
 
@@ -336,8 +341,19 @@ class AttackMethods(BaseDTO):
336
341
  basic: List[str] = field(default_factory=lambda: ["basic"])
337
342
  advanced: Dict[str, List[str]] = field(
338
343
  default_factory=lambda: {
339
- "static": ["masking", "figstep", "hades","encoding", "single_shot", "echo", "speed", "pitch", "reverb", "noise" ],
340
- "dynamic": ["iterative","jood"]
344
+ "static": [
345
+ "masking",
346
+ "figstep",
347
+ "hades",
348
+ "encoding",
349
+ "single_shot",
350
+ "echo",
351
+ "speed",
352
+ "pitch",
353
+ "reverb",
354
+ "noise",
355
+ ],
356
+ "dynamic": ["iterative", "jood"],
341
357
  }
342
358
  )
343
359
 
@@ -450,7 +466,7 @@ class OutputModality(str, Enum):
450
466
  # audio = "audio"
451
467
  # video = "video"
452
468
  # code = "code"
453
-
469
+
454
470
 
455
471
  @dataclass
456
472
  class TargetModelConfiguration(BaseDTO):
@@ -478,13 +494,17 @@ class TargetModelConfiguration(BaseDTO):
478
494
  def from_dict(cls, data: dict):
479
495
  data = data.copy()
480
496
  if "custom_headers" in data:
481
- data["custom_headers"] = [CustomHeader.from_dict(header) for header in data["custom_headers"]]
497
+ data["custom_headers"] = [
498
+ CustomHeader.from_dict(header) for header in data["custom_headers"]
499
+ ]
482
500
  if "model_auth_type" in data:
483
501
  data["model_auth_type"] = ModelAuthTypeEnum(data["model_auth_type"])
484
502
  if "model_jwt_config" in data:
485
- data["model_jwt_config"] = ModelJwtConfig.from_dict(data["model_jwt_config"])
503
+ data["model_jwt_config"] = ModelJwtConfig.from_dict(
504
+ data["model_jwt_config"]
505
+ )
486
506
  return cls(**data)
487
-
507
+
488
508
  def to_dict(self) -> dict:
489
509
  d = asdict(self)
490
510
  d["model_auth_type"] = self.model_auth_type.value
@@ -523,9 +543,8 @@ class RedTeamModelHealthConfigV3(BaseDTO):
523
543
  V3 format for model health check that accepts endpoint_configuration
524
544
  similar to add_custom_task.
525
545
  """
526
- endpoint_configuration: ModelConfig = field(
527
- default_factory=ModelConfig
528
- )
546
+
547
+ endpoint_configuration: ModelConfig = field(default_factory=ModelConfig)
529
548
 
530
549
  def to_dict(self) -> dict:
531
550
  d = asdict(self)
@@ -535,25 +554,27 @@ class RedTeamModelHealthConfigV3(BaseDTO):
535
554
  @classmethod
536
555
  def from_dict(cls, data: dict):
537
556
  data = data.copy()
538
- endpoint_config = ModelConfig.from_dict(
539
- data.pop("endpoint_configuration", {})
540
- )
557
+ endpoint_config = ModelConfig.from_dict(data.pop("endpoint_configuration", {}))
541
558
  return cls(
542
559
  endpoint_configuration=endpoint_config,
543
560
  )
544
-
561
+
545
562
  def to_target_model_configuration(self) -> TargetModelConfiguration:
546
563
  """
547
564
  Convert endpoint_configuration to target_model_configuration format.
548
565
  This enables the V3 format to be compatible with the existing backend API.
549
566
  """
550
567
  model_config = self.endpoint_configuration.model_config
551
-
568
+
552
569
  return TargetModelConfiguration(
553
570
  testing_for=self.endpoint_configuration.testing_for,
554
571
  system_prompt=model_config.system_prompt,
555
572
  model_source=model_config.model_source,
556
- model_provider=model_config.model_provider.value if hasattr(model_config.model_provider, 'value') else model_config.model_provider,
573
+ model_provider=(
574
+ model_config.model_provider.value
575
+ if hasattr(model_config.model_provider, "value")
576
+ else model_config.model_provider
577
+ ),
557
578
  model_endpoint_url=model_config.endpoint_url,
558
579
  rate_per_min=model_config.rate_per_min,
559
580
  model_name=self.endpoint_configuration.model_name,
@@ -561,8 +582,14 @@ class RedTeamModelHealthConfigV3(BaseDTO):
561
582
  model_auth_type=model_config.model_auth_type,
562
583
  model_jwt_config=model_config.model_jwt_config,
563
584
  model_api_key=model_config.apikey,
564
- input_modalities=[InputModality(m) if isinstance(m, str) else m for m in model_config.input_modalities],
565
- output_modalities=[OutputModality(m) if isinstance(m, str) else m for m in model_config.output_modalities],
585
+ input_modalities=[
586
+ InputModality(m) if isinstance(m, str) else m
587
+ for m in model_config.input_modalities
588
+ ],
589
+ output_modalities=[
590
+ OutputModality(m) if isinstance(m, str) else m
591
+ for m in model_config.output_modalities
592
+ ],
566
593
  custom_curl_command=model_config.custom_curl_command,
567
594
  custom_headers=model_config.custom_headers,
568
595
  custom_payload=model_config.custom_payload,
@@ -578,22 +605,22 @@ class RedteamModelHealthResponse(BaseDTO):
578
605
  error: str
579
606
  data: Optional[Dict[str, Any]] = field(default_factory=dict)
580
607
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
581
-
608
+
582
609
  @classmethod
583
610
  def from_dict(cls, data: Dict[str, Any]) -> "RedteamModelHealthResponse":
584
611
  return cls(
585
612
  status=data.get("status", ""),
586
613
  message=data.get("message", ""),
587
614
  data=data.get("data", {}),
588
- error=data.get("error", "")
615
+ error=data.get("error", ""),
589
616
  )
590
-
617
+
591
618
  def to_dict(self) -> Dict[str, Any]:
592
619
  return {
593
620
  "status": self.status,
594
621
  "message": self.message,
595
622
  "data": self.data,
596
- "error": self.error
623
+ "error": self.error,
597
624
  }
598
625
 
599
626
 
@@ -671,20 +698,22 @@ class RedTeamCustomConfig(BaseDTO):
671
698
  redteam_test_configurations: RedTeamTestConfigurations = field(
672
699
  default_factory=RedTeamTestConfigurations
673
700
  )
674
- dataset_configuration: DatasetConfig = field(
675
- default_factory=DatasetConfig
676
- )
677
- endpoint_configuration: ModelConfig = field(
678
- default_factory=ModelConfig
679
- )
701
+ dataset_configuration: Optional[DatasetConfig] = None
702
+ endpoint_configuration: Optional[ModelConfig] = None
680
703
 
681
704
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
682
705
 
683
706
  def to_dict(self) -> dict:
684
707
  d = asdict(self)
685
708
  d["redteam_test_configurations"] = self.redteam_test_configurations.to_dict()
686
- d["dataset_configuration"] = self.dataset_configuration.to_dict()
687
- d["endpoint_configuration"] = self.endpoint_configuration.to_dict()
709
+ if self.dataset_configuration is not None:
710
+ d["dataset_configuration"] = self.dataset_configuration.to_dict()
711
+ else:
712
+ d.pop("dataset_configuration", None)
713
+ if self.endpoint_configuration is not None:
714
+ d["endpoint_configuration"] = self.endpoint_configuration.to_dict()
715
+ else:
716
+ d.pop("endpoint_configuration", None)
688
717
  return d
689
718
 
690
719
  @classmethod
@@ -693,12 +722,18 @@ class RedTeamCustomConfig(BaseDTO):
693
722
  test_configs = RedTeamTestConfigurations.from_dict(
694
723
  data.pop("redteam_test_configurations", {})
695
724
  )
696
- dataset_config = DatasetConfig.from_dict(
697
- data.pop("dataset_configuration", {})
698
- )
699
- endpoint_config = ModelConfig.from_dict(
700
- data.pop("endpoint_configuration", {})
701
- )
725
+ dataset_config = None
726
+ if "dataset_configuration" in data and data["dataset_configuration"]:
727
+ dataset_config = DatasetConfig.from_dict(data.pop("dataset_configuration"))
728
+ else:
729
+ data.pop("dataset_configuration", None)
730
+
731
+ endpoint_config = None
732
+ if "endpoint_configuration" in data and data["endpoint_configuration"]:
733
+ endpoint_config = ModelConfig.from_dict(data.pop("endpoint_configuration"))
734
+ else:
735
+ data.pop("endpoint_configuration", None)
736
+
702
737
  return cls(
703
738
  **data,
704
739
  redteam_test_configurations=test_configs,
@@ -706,6 +741,7 @@ class RedTeamCustomConfig(BaseDTO):
706
741
  endpoint_configuration=endpoint_config,
707
742
  )
708
743
 
744
+
709
745
  @dataclass
710
746
  class RedTeamCustomConfigWithSavedModel(BaseDTO):
711
747
  test_name: str = "Test Name"
@@ -714,16 +750,17 @@ class RedTeamCustomConfigWithSavedModel(BaseDTO):
714
750
  redteam_test_configurations: RedTeamTestConfigurations = field(
715
751
  default_factory=RedTeamTestConfigurations
716
752
  )
717
- dataset_configuration: DatasetConfig = field(
718
- default_factory=DatasetConfig
719
- )
753
+ dataset_configuration: Optional[DatasetConfig] = None
720
754
 
721
755
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
722
756
 
723
757
  def to_dict(self) -> dict:
724
758
  d = asdict(self)
725
759
  d["redteam_test_configurations"] = self.redteam_test_configurations.to_dict()
726
- d["dataset_configuration"] = self.dataset_configuration.to_dict()
760
+ if self.dataset_configuration is not None:
761
+ d["dataset_configuration"] = self.dataset_configuration.to_dict()
762
+ else:
763
+ d.pop("dataset_configuration", None)
727
764
  return d
728
765
 
729
766
  @classmethod
@@ -732,9 +769,12 @@ class RedTeamCustomConfigWithSavedModel(BaseDTO):
732
769
  test_configs = RedTeamTestConfigurations.from_dict(
733
770
  data.pop("redteam_test_configurations", {})
734
771
  )
735
- dataset_config = DatasetConfig.from_dict(
736
- data.pop("dataset_configuration", {})
737
- )
772
+ dataset_config = None
773
+ if "dataset_configuration" in data and data["dataset_configuration"]:
774
+ dataset_config = DatasetConfig.from_dict(data.pop("dataset_configuration"))
775
+ else:
776
+ data.pop("dataset_configuration", None)
777
+
738
778
  return cls(
739
779
  **data,
740
780
  redteam_test_configurations=test_configs,
@@ -749,7 +789,7 @@ class RedTeamTaskList(BaseDTO):
749
789
  def to_dataframe(self) -> pd.DataFrame:
750
790
  data = [task for task in self.tasks]
751
791
  return pd.DataFrame(data)
752
-
792
+
753
793
 
754
794
  @dataclass
755
795
  class RedTeamRiskMitigationGuardrailsPolicyConfig(BaseDTO):
@@ -793,21 +833,21 @@ class RedTeamRiskMitigationGuardrailsPolicyResponse(BaseDTO):
793
833
 
794
834
  def to_dict(self) -> dict:
795
835
  policy_dict = self.guardrails_policy.to_dict()
796
-
836
+
797
837
  # Remove detector entries that are disabled and have no other config
798
838
  final_policy_dict = {}
799
839
  for key, value in policy_dict.items():
800
840
  if isinstance(value, dict):
801
841
  # Check if 'enabled' is the only key and its value is False
802
- if list(value.keys()) == ['enabled'] and not value['enabled']:
842
+ if list(value.keys()) == ["enabled"] and not value["enabled"]:
803
843
  continue
804
844
  # Check for empty detectors that only have 'enabled': False
805
845
  if not value.get("enabled", True) and len(value) == 1:
806
846
  continue
807
847
  # check for other empty values
808
- if not any(v for k, v in value.items() if k != 'enabled'):
809
- if not value.get('enabled'):
810
- continue
848
+ if not any(v for k, v in value.items() if k != "enabled"):
849
+ if not value.get("enabled"):
850
+ continue
811
851
  final_policy_dict[key] = value
812
852
 
813
853
  return {
@@ -862,21 +902,18 @@ class RedTeamRiskMitigationSystemPromptResponse(BaseDTO):
862
902
  "message": self.message,
863
903
  }
864
904
 
905
+
865
906
  @dataclass
866
907
  class RedTeamKeyFinding(BaseDTO):
867
908
  text: str
868
909
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
869
-
910
+
870
911
  @classmethod
871
912
  def from_dict(cls, data: Dict[str, Any]) -> "RedTeamKeyFinding":
872
- return cls(
873
- text=data.get("text", "")
874
- )
875
-
913
+ return cls(text=data.get("text", ""))
914
+
876
915
  def to_dict(self) -> Dict[str, Any]:
877
- result = {
878
- "text": self.text
879
- }
916
+ result = {"text": self.text}
880
917
  result.update(self._extra_fields)
881
918
  return result
882
919
 
@@ -886,21 +923,20 @@ class RedTeamFindingsResponse(BaseDTO):
886
923
  key_findings: List[RedTeamKeyFinding] = field(default_factory=list)
887
924
  message: str = ""
888
925
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
889
-
926
+
890
927
  @classmethod
891
928
  def from_dict(cls, data: Dict[str, Any]) -> "RedTeamFindingsResponse":
892
929
  key_findings_data = data.get("key_findings", [])
893
- key_findings = [RedTeamKeyFinding.from_dict(finding) for finding in key_findings_data]
894
-
895
- return cls(
896
- key_findings=key_findings,
897
- message=data.get("message", "")
898
- )
899
-
930
+ key_findings = [
931
+ RedTeamKeyFinding.from_dict(finding) for finding in key_findings_data
932
+ ]
933
+
934
+ return cls(key_findings=key_findings, message=data.get("message", ""))
935
+
900
936
  def to_dict(self) -> Dict[str, Any]:
901
937
  result = {
902
938
  "key_findings": [finding.to_dict() for finding in self.key_findings],
903
- "message": self.message
939
+ "message": self.message,
904
940
  }
905
941
  result.update(self._extra_fields)
906
942
  return result
@@ -912,20 +948,20 @@ class RedTeamDownloadLinkResponse(BaseDTO):
912
948
  expiry: str = ""
913
949
  expires_at: str = ""
914
950
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
915
-
951
+
916
952
  @classmethod
917
953
  def from_dict(cls, data: Dict[str, Any]) -> "RedTeamDownloadLinkResponse":
918
954
  return cls(
919
955
  link=data.get("link", ""),
920
956
  expiry=data.get("expiry", ""),
921
- expires_at=data.get("expires_at", "")
957
+ expires_at=data.get("expires_at", ""),
922
958
  )
923
-
959
+
924
960
  def to_dict(self) -> Dict[str, Any]:
925
961
  result = {
926
962
  "link": self.link,
927
963
  "expiry": self.expiry,
928
- "expires_at": self.expires_at
964
+ "expires_at": self.expires_at,
929
965
  }
930
966
  result.update(self._extra_fields)
931
967
  return result
@@ -944,6 +980,7 @@ class AttackMethodsV3(BaseDTO):
944
980
  }
945
981
  }
946
982
  """
983
+
947
984
  _data: Dict[str, Dict[str, Dict[str, Any]]] = field(default_factory=dict)
948
985
 
949
986
  def to_dict(self) -> dict:
@@ -970,13 +1007,15 @@ class TestConfigV3(BaseDTO):
970
1007
  attack_methods = AttackMethodsV3.from_dict(data.get("attack_methods", {}))
971
1008
  return cls(
972
1009
  sample_percentage=data.get("sample_percentage", 5),
973
- attack_methods=attack_methods
1010
+ attack_methods=attack_methods,
974
1011
  )
975
1012
 
976
1013
 
977
1014
  @dataclass
978
1015
  class RedTeamTestConfigurationsV3(BaseDTO):
979
1016
  """V3 format for red team test configurations with nested attack methods"""
1017
+
1018
+ version: str = "3.0"
980
1019
  # Basic tests
981
1020
  bias_test: TestConfigV3 = field(default=None)
982
1021
  cbrn_test: TestConfigV3 = field(default=None)
@@ -1019,7 +1058,12 @@ class RedTeamTestConfigurationsV3(BaseDTO):
1019
1058
 
1020
1059
  @classmethod
1021
1060
  def from_dict(cls, data: dict):
1022
- return cls(**{k: TestConfigV3.from_dict(v) if isinstance(v, dict) else v for k, v in data.items()})
1061
+ return cls(
1062
+ **{
1063
+ k: TestConfigV3.from_dict(v) if isinstance(v, dict) else v
1064
+ for k, v in data.items()
1065
+ }
1066
+ )
1023
1067
 
1024
1068
 
1025
1069
  @dataclass
@@ -1030,20 +1074,22 @@ class RedTeamCustomConfigV3(BaseDTO):
1030
1074
  redteam_test_configurations: RedTeamTestConfigurationsV3 = field(
1031
1075
  default_factory=RedTeamTestConfigurationsV3
1032
1076
  )
1033
- dataset_configuration: DatasetConfig = field(
1034
- default_factory=DatasetConfig
1035
- )
1036
- endpoint_configuration: ModelConfig = field(
1037
- default_factory=ModelConfig
1038
- )
1077
+ dataset_configuration: Optional[DatasetConfig] = None
1078
+ endpoint_configuration: Optional[ModelConfig] = None
1039
1079
 
1040
1080
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
1041
1081
 
1042
1082
  def to_dict(self) -> dict:
1043
1083
  d = asdict(self)
1044
1084
  d["redteam_test_configurations"] = self.redteam_test_configurations.to_dict()
1045
- d["dataset_configuration"] = self.dataset_configuration.to_dict()
1046
- d["endpoint_configuration"] = self.endpoint_configuration.to_dict()
1085
+ if self.dataset_configuration is not None:
1086
+ d["dataset_configuration"] = self.dataset_configuration.to_dict()
1087
+ else:
1088
+ d.pop("dataset_configuration", None)
1089
+ if self.endpoint_configuration is not None:
1090
+ d["endpoint_configuration"] = self.endpoint_configuration.to_dict()
1091
+ else:
1092
+ d.pop("endpoint_configuration", None)
1047
1093
  return d
1048
1094
 
1049
1095
  @classmethod
@@ -1052,12 +1098,18 @@ class RedTeamCustomConfigV3(BaseDTO):
1052
1098
  test_configs = RedTeamTestConfigurationsV3.from_dict(
1053
1099
  data.pop("redteam_test_configurations", {})
1054
1100
  )
1055
- dataset_config = DatasetConfig.from_dict(
1056
- data.pop("dataset_configuration", {})
1057
- )
1058
- endpoint_config = ModelConfig.from_dict(
1059
- data.pop("endpoint_configuration", {})
1060
- )
1101
+ dataset_config = None
1102
+ if "dataset_configuration" in data and data["dataset_configuration"]:
1103
+ dataset_config = DatasetConfig.from_dict(data.pop("dataset_configuration"))
1104
+ else:
1105
+ data.pop("dataset_configuration", None)
1106
+
1107
+ endpoint_config = None
1108
+ if "endpoint_configuration" in data and data["endpoint_configuration"]:
1109
+ endpoint_config = ModelConfig.from_dict(data.pop("endpoint_configuration"))
1110
+ else:
1111
+ data.pop("endpoint_configuration", None)
1112
+
1061
1113
  return cls(
1062
1114
  **data,
1063
1115
  redteam_test_configurations=test_configs,
@@ -1074,16 +1126,17 @@ class RedTeamCustomConfigWithSavedModelV3(BaseDTO):
1074
1126
  redteam_test_configurations: RedTeamTestConfigurationsV3 = field(
1075
1127
  default_factory=RedTeamTestConfigurationsV3
1076
1128
  )
1077
- dataset_configuration: DatasetConfig = field(
1078
- default_factory=DatasetConfig
1079
- )
1129
+ dataset_configuration: Optional[DatasetConfig] = None
1080
1130
 
1081
1131
  _extra_fields: Dict[str, Any] = field(default_factory=dict)
1082
1132
 
1083
1133
  def to_dict(self) -> dict:
1084
1134
  d = asdict(self)
1085
1135
  d["redteam_test_configurations"] = self.redteam_test_configurations.to_dict()
1086
- d["dataset_configuration"] = self.dataset_configuration.to_dict()
1136
+ if self.dataset_configuration is not None:
1137
+ d["dataset_configuration"] = self.dataset_configuration.to_dict()
1138
+ else:
1139
+ d.pop("dataset_configuration", None)
1087
1140
  return d
1088
1141
 
1089
1142
  @classmethod
@@ -1092,9 +1145,12 @@ class RedTeamCustomConfigWithSavedModelV3(BaseDTO):
1092
1145
  test_configs = RedTeamTestConfigurationsV3.from_dict(
1093
1146
  data.pop("redteam_test_configurations", {})
1094
1147
  )
1095
- dataset_config = DatasetConfig.from_dict(
1096
- data.pop("dataset_configuration", {})
1097
- )
1148
+ dataset_config = None
1149
+ if "dataset_configuration" in data and data["dataset_configuration"]:
1150
+ dataset_config = DatasetConfig.from_dict(data.pop("dataset_configuration"))
1151
+ else:
1152
+ data.pop("dataset_configuration", None)
1153
+
1098
1154
  return cls(
1099
1155
  **data,
1100
1156
  redteam_test_configurations=test_configs,
@@ -1110,4 +1166,6 @@ DEFAULT_CUSTOM_REDTEAM_CONFIG = RedTeamCustomConfig()
1110
1166
  DEFAULT_CUSTOM_REDTEAM_CONFIG_WITH_SAVED_MODEL = RedTeamCustomConfigWithSavedModel()
1111
1167
 
1112
1168
  DEFAULT_CUSTOM_REDTEAM_CONFIG_V3 = RedTeamCustomConfigV3()
1113
- DEFAULT_CUSTOM_REDTEAM_CONFIG_WITH_SAVED_MODEL_V3 = RedTeamCustomConfigWithSavedModelV3()
1169
+ DEFAULT_CUSTOM_REDTEAM_CONFIG_WITH_SAVED_MODEL_V3 = (
1170
+ RedTeamCustomConfigWithSavedModelV3()
1171
+ )