azure-ai-evaluation 1.7.0__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of azure-ai-evaluation might be problematic. (The original page linked to further details, which are not included here.)
- azure/ai/evaluation/_common/onedp/operations/_operations.py +3 -1
- azure/ai/evaluation/_evaluate/_evaluate.py +4 -4
- azure/ai/evaluation/_evaluators/_rouge/_rouge.py +4 -4
- azure/ai/evaluation/_version.py +1 -1
- azure/ai/evaluation/red_team/_agent/__init__.py +3 -0
- azure/ai/evaluation/red_team/_agent/_agent_functions.py +264 -0
- azure/ai/evaluation/red_team/_agent/_agent_tools.py +503 -0
- azure/ai/evaluation/red_team/_agent/_agent_utils.py +69 -0
- azure/ai/evaluation/red_team/_agent/_semantic_kernel_plugin.py +237 -0
- azure/ai/evaluation/red_team/_attack_strategy.py +2 -0
- azure/ai/evaluation/red_team/_red_team.py +388 -78
- azure/ai/evaluation/red_team/_utils/_rai_service_eval_chat_target.py +121 -0
- azure/ai/evaluation/red_team/_utils/_rai_service_target.py +570 -0
- azure/ai/evaluation/red_team/_utils/_rai_service_true_false_scorer.py +108 -0
- azure/ai/evaluation/red_team/_utils/constants.py +5 -1
- azure/ai/evaluation/red_team/_utils/metric_mapping.py +2 -2
- azure/ai/evaluation/red_team/_utils/strategy_utils.py +2 -0
- azure/ai/evaluation/simulator/_adversarial_simulator.py +9 -2
- azure/ai/evaluation/simulator/_model_tools/_generated_rai_client.py +1 -0
- azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +15 -7
- {azure_ai_evaluation-1.7.0.dist-info → azure_ai_evaluation-1.8.0.dist-info}/METADATA +10 -1
- {azure_ai_evaluation-1.7.0.dist-info → azure_ai_evaluation-1.8.0.dist-info}/RECORD +25 -17
- {azure_ai_evaluation-1.7.0.dist-info → azure_ai_evaluation-1.8.0.dist-info}/NOTICE.txt +0 -0
- {azure_ai_evaluation-1.7.0.dist-info → azure_ai_evaluation-1.8.0.dist-info}/WHEEL +0 -0
- {azure_ai_evaluation-1.7.0.dist-info → azure_ai_evaluation-1.8.0.dist-info}/top_level.txt +0 -0
|
@@ -48,6 +48,9 @@ from azure.core.credentials import TokenCredential
|
|
|
48
48
|
from ._red_team_result import RedTeamResult, RedTeamingScorecard, RedTeamingParameters, ScanResult
|
|
49
49
|
from ._attack_strategy import AttackStrategy
|
|
50
50
|
from ._attack_objective_generator import RiskCategory, _AttackObjectiveGenerator
|
|
51
|
+
from ._utils._rai_service_target import AzureRAIServiceTarget
|
|
52
|
+
from ._utils._rai_service_true_false_scorer import AzureRAIServiceTrueFalseScorer
|
|
53
|
+
from ._utils._rai_service_eval_chat_target import RAIServiceEvalChatTarget
|
|
51
54
|
|
|
52
55
|
# PyRIT imports
|
|
53
56
|
from pyrit.common import initialize_pyrit, DUCK_DB
|
|
@@ -55,9 +58,11 @@ from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget
|
|
|
55
58
|
from pyrit.models import ChatMessage
|
|
56
59
|
from pyrit.memory import CentralMemory
|
|
57
60
|
from pyrit.orchestrator.single_turn.prompt_sending_orchestrator import PromptSendingOrchestrator
|
|
61
|
+
from pyrit.orchestrator.multi_turn.red_teaming_orchestrator import RedTeamingOrchestrator
|
|
58
62
|
from pyrit.orchestrator import Orchestrator
|
|
59
63
|
from pyrit.exceptions import PyritException
|
|
60
64
|
from pyrit.prompt_converter import PromptConverter, MathPromptConverter, Base64Converter, FlipConverter, MorseConverter, AnsiAttackConverter, AsciiArtConverter, AsciiSmugglerConverter, AtbashConverter, BinaryConverter, CaesarConverter, CharacterSpaceConverter, CharSwapGenerator, DiacriticConverter, LeetspeakConverter, UrlConverter, UnicodeSubstitutionConverter, UnicodeConfusableConverter, SuffixAppendConverter, StringJoinConverter, ROT13Converter
|
|
65
|
+
from pyrit.orchestrator.multi_turn.crescendo_orchestrator import CrescendoOrchestrator
|
|
61
66
|
|
|
62
67
|
# Retry imports
|
|
63
68
|
import httpx
|
|
@@ -243,7 +248,7 @@ class RedTeam:
|
|
|
243
248
|
self.scan_id = None
|
|
244
249
|
self.scan_output_dir = None
|
|
245
250
|
|
|
246
|
-
self.generated_rai_client = GeneratedRAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager.
|
|
251
|
+
self.generated_rai_client = GeneratedRAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager.credential) #type: ignore
|
|
247
252
|
|
|
248
253
|
# Initialize a cache for attack objectives by risk category and strategy
|
|
249
254
|
self.attack_objectives = {}
|
|
@@ -857,9 +862,11 @@ class RedTeam:
|
|
|
857
862
|
chat_target: PromptChatTarget,
|
|
858
863
|
all_prompts: List[str],
|
|
859
864
|
converter: Union[PromptConverter, List[PromptConverter]],
|
|
865
|
+
*,
|
|
860
866
|
strategy_name: str = "unknown",
|
|
861
|
-
|
|
862
|
-
|
|
867
|
+
risk_category_name: str = "unknown",
|
|
868
|
+
risk_category: Optional[RiskCategory] = None,
|
|
869
|
+
timeout: int = 120,
|
|
863
870
|
) -> Orchestrator:
|
|
864
871
|
"""Send prompts via the PromptSendingOrchestrator with optimized performance.
|
|
865
872
|
|
|
@@ -877,6 +884,8 @@ class RedTeam:
|
|
|
877
884
|
:type converter: Union[PromptConverter, List[PromptConverter]]
|
|
878
885
|
:param strategy_name: Name of the attack strategy being used
|
|
879
886
|
:type strategy_name: str
|
|
887
|
+
:param risk_category_name: Name of the risk category being evaluated
|
|
888
|
+
:type risk_category_name: str
|
|
880
889
|
:param risk_category: Risk category being evaluated
|
|
881
890
|
:type risk_category: str
|
|
882
891
|
:param timeout: Timeout in seconds for each prompt
|
|
@@ -884,10 +893,10 @@ class RedTeam:
|
|
|
884
893
|
:return: Configured and initialized orchestrator
|
|
885
894
|
:rtype: Orchestrator
|
|
886
895
|
"""
|
|
887
|
-
task_key = f"{strategy_name}_{
|
|
896
|
+
task_key = f"{strategy_name}_{risk_category_name}_orchestrator"
|
|
888
897
|
self.task_statuses[task_key] = TASK_STATUS["RUNNING"]
|
|
889
898
|
|
|
890
|
-
log_strategy_start(self.logger, strategy_name,
|
|
899
|
+
log_strategy_start(self.logger, strategy_name, risk_category_name)
|
|
891
900
|
|
|
892
901
|
# Create converter list from single converter or list of converters
|
|
893
902
|
converter_list = [converter] if converter and isinstance(converter, PromptConverter) else converter if converter else []
|
|
@@ -910,7 +919,7 @@ class RedTeam:
|
|
|
910
919
|
)
|
|
911
920
|
|
|
912
921
|
if not all_prompts:
|
|
913
|
-
self.logger.warning(f"No prompts provided to orchestrator for {strategy_name}/{
|
|
922
|
+
self.logger.warning(f"No prompts provided to orchestrator for {strategy_name}/{risk_category_name}")
|
|
914
923
|
self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
|
|
915
924
|
return orchestrator
|
|
916
925
|
|
|
@@ -930,15 +939,15 @@ class RedTeam:
|
|
|
930
939
|
else:
|
|
931
940
|
output_path = f"{base_path}{DATA_EXT}"
|
|
932
941
|
|
|
933
|
-
self.red_team_info[strategy_name][
|
|
942
|
+
self.red_team_info[strategy_name][risk_category_name]["data_file"] = output_path
|
|
934
943
|
|
|
935
944
|
# Process prompts concurrently within each batch
|
|
936
945
|
if len(all_prompts) > batch_size:
|
|
937
|
-
self.logger.debug(f"Processing {len(all_prompts)} prompts in batches of {batch_size} for {strategy_name}/{
|
|
946
|
+
self.logger.debug(f"Processing {len(all_prompts)} prompts in batches of {batch_size} for {strategy_name}/{risk_category_name}")
|
|
938
947
|
batches = [all_prompts[i:i + batch_size] for i in range(0, len(all_prompts), batch_size)]
|
|
939
948
|
|
|
940
949
|
for batch_idx, batch in enumerate(batches):
|
|
941
|
-
self.logger.debug(f"Processing batch {batch_idx+1}/{len(batches)} with {len(batch)} prompts for {strategy_name}/{
|
|
950
|
+
self.logger.debug(f"Processing batch {batch_idx+1}/{len(batches)} with {len(batch)} prompts for {strategy_name}/{risk_category_name}")
|
|
942
951
|
|
|
943
952
|
batch_start_time = datetime.now() # Send prompts in the batch concurrently with a timeout and retry logic
|
|
944
953
|
try: # Create retry decorator for this specific call with enhanced retry strategy
|
|
@@ -953,7 +962,7 @@ class RedTeam:
|
|
|
953
962
|
ConnectionError, TimeoutError, asyncio.TimeoutError, httpcore.ReadTimeout,
|
|
954
963
|
httpx.HTTPStatusError) as e:
|
|
955
964
|
# Log the error with enhanced information and allow retry logic to handle it
|
|
956
|
-
self.logger.warning(f"Network error in batch {batch_idx+1} for {strategy_name}/{
|
|
965
|
+
self.logger.warning(f"Network error in batch {batch_idx+1} for {strategy_name}/{risk_category_name}: {type(e).__name__}: {str(e)}")
|
|
957
966
|
# Add a small delay before retry to allow network recovery
|
|
958
967
|
await asyncio.sleep(1)
|
|
959
968
|
raise
|
|
@@ -961,32 +970,32 @@ class RedTeam:
|
|
|
961
970
|
# Execute the retry-enabled function
|
|
962
971
|
await send_batch_with_retry()
|
|
963
972
|
batch_duration = (datetime.now() - batch_start_time).total_seconds()
|
|
964
|
-
self.logger.debug(f"Successfully processed batch {batch_idx+1} for {strategy_name}/{
|
|
973
|
+
self.logger.debug(f"Successfully processed batch {batch_idx+1} for {strategy_name}/{risk_category_name} in {batch_duration:.2f} seconds")
|
|
965
974
|
|
|
966
975
|
# Print progress to console
|
|
967
976
|
if batch_idx < len(batches) - 1: # Don't print for the last batch
|
|
968
|
-
print(f"Strategy {strategy_name}, Risk {
|
|
977
|
+
print(f"Strategy {strategy_name}, Risk {risk_category_name}: Processed batch {batch_idx+1}/{len(batches)}")
|
|
969
978
|
|
|
970
979
|
except (asyncio.TimeoutError, tenacity.RetryError):
|
|
971
|
-
self.logger.warning(f"Batch {batch_idx+1} for {strategy_name}/{
|
|
972
|
-
self.logger.debug(f"Timeout: Strategy {strategy_name}, Risk {
|
|
973
|
-
print(f"⚠️ TIMEOUT: Strategy {strategy_name}, Risk {
|
|
980
|
+
self.logger.warning(f"Batch {batch_idx+1} for {strategy_name}/{risk_category_name} timed out after {timeout} seconds, continuing with partial results")
|
|
981
|
+
self.logger.debug(f"Timeout: Strategy {strategy_name}, Risk {risk_category_name}, Batch {batch_idx+1} after {timeout} seconds.", exc_info=True)
|
|
982
|
+
print(f"⚠️ TIMEOUT: Strategy {strategy_name}, Risk {risk_category_name}, Batch {batch_idx+1}")
|
|
974
983
|
# Set task status to TIMEOUT
|
|
975
|
-
batch_task_key = f"{strategy_name}_{
|
|
984
|
+
batch_task_key = f"{strategy_name}_{risk_category_name}_batch_{batch_idx+1}"
|
|
976
985
|
self.task_statuses[batch_task_key] = TASK_STATUS["TIMEOUT"]
|
|
977
|
-
self.red_team_info[strategy_name][
|
|
978
|
-
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=
|
|
986
|
+
self.red_team_info[strategy_name][risk_category_name]["status"] = TASK_STATUS["INCOMPLETE"]
|
|
987
|
+
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category_name, batch_idx=batch_idx+1)
|
|
979
988
|
# Continue with partial results rather than failing completely
|
|
980
989
|
continue
|
|
981
990
|
except Exception as e:
|
|
982
|
-
log_error(self.logger, f"Error processing batch {batch_idx+1}", e, f"{strategy_name}/{
|
|
983
|
-
self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {
|
|
984
|
-
self.red_team_info[strategy_name][
|
|
985
|
-
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=
|
|
991
|
+
log_error(self.logger, f"Error processing batch {batch_idx+1}", e, f"{strategy_name}/{risk_category_name}")
|
|
992
|
+
self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {risk_category_name}, Batch {batch_idx+1}: {str(e)}")
|
|
993
|
+
self.red_team_info[strategy_name][risk_category_name]["status"] = TASK_STATUS["INCOMPLETE"]
|
|
994
|
+
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category_name, batch_idx=batch_idx+1)
|
|
986
995
|
# Continue with other batches even if one fails
|
|
987
996
|
continue
|
|
988
997
|
else: # Small number of prompts, process all at once with a timeout and retry logic
|
|
989
|
-
self.logger.debug(f"Processing {len(all_prompts)} prompts in a single batch for {strategy_name}/{
|
|
998
|
+
self.logger.debug(f"Processing {len(all_prompts)} prompts in a single batch for {strategy_name}/{risk_category_name}")
|
|
990
999
|
batch_start_time = datetime.now()
|
|
991
1000
|
try: # Create retry decorator with enhanced retry strategy
|
|
992
1001
|
@retry(**self._create_retry_config()["network_retry"])
|
|
@@ -1000,7 +1009,7 @@ class RedTeam:
|
|
|
1000
1009
|
ConnectionError, TimeoutError, OSError, asyncio.TimeoutError, httpcore.ReadTimeout,
|
|
1001
1010
|
httpx.HTTPStatusError) as e:
|
|
1002
1011
|
# Enhanced error logging with type information and context
|
|
1003
|
-
self.logger.warning(f"Network error in single batch for {strategy_name}/{
|
|
1012
|
+
self.logger.warning(f"Network error in single batch for {strategy_name}/{risk_category_name}: {type(e).__name__}: {str(e)}")
|
|
1004
1013
|
# Add a small delay before retry to allow network recovery
|
|
1005
1014
|
await asyncio.sleep(2)
|
|
1006
1015
|
raise
|
|
@@ -1008,30 +1017,338 @@ class RedTeam:
|
|
|
1008
1017
|
# Execute the retry-enabled function
|
|
1009
1018
|
await send_all_with_retry()
|
|
1010
1019
|
batch_duration = (datetime.now() - batch_start_time).total_seconds()
|
|
1011
|
-
self.logger.debug(f"Successfully processed single batch for {strategy_name}/{
|
|
1020
|
+
self.logger.debug(f"Successfully processed single batch for {strategy_name}/{risk_category_name} in {batch_duration:.2f} seconds")
|
|
1012
1021
|
except (asyncio.TimeoutError, tenacity.RetryError):
|
|
1013
|
-
self.logger.warning(f"Prompt processing for {strategy_name}/{
|
|
1014
|
-
print(f"⚠️ TIMEOUT: Strategy {strategy_name}, Risk {
|
|
1022
|
+
self.logger.warning(f"Prompt processing for {strategy_name}/{risk_category_name} timed out after {timeout} seconds, continuing with partial results")
|
|
1023
|
+
print(f"⚠️ TIMEOUT: Strategy {strategy_name}, Risk {risk_category_name}")
|
|
1015
1024
|
# Set task status to TIMEOUT
|
|
1016
|
-
single_batch_task_key = f"{strategy_name}_{
|
|
1025
|
+
single_batch_task_key = f"{strategy_name}_{risk_category_name}_single_batch"
|
|
1017
1026
|
self.task_statuses[single_batch_task_key] = TASK_STATUS["TIMEOUT"]
|
|
1018
|
-
self.red_team_info[strategy_name][
|
|
1019
|
-
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=
|
|
1027
|
+
self.red_team_info[strategy_name][risk_category_name]["status"] = TASK_STATUS["INCOMPLETE"]
|
|
1028
|
+
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category_name, batch_idx=1)
|
|
1020
1029
|
except Exception as e:
|
|
1021
|
-
log_error(self.logger, "Error processing prompts", e, f"{strategy_name}/{
|
|
1022
|
-
self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {
|
|
1023
|
-
self.red_team_info[strategy_name][
|
|
1024
|
-
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=
|
|
1030
|
+
log_error(self.logger, "Error processing prompts", e, f"{strategy_name}/{risk_category_name}")
|
|
1031
|
+
self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {risk_category_name}: {str(e)}")
|
|
1032
|
+
self.red_team_info[strategy_name][risk_category_name]["status"] = TASK_STATUS["INCOMPLETE"]
|
|
1033
|
+
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category_name, batch_idx=1)
|
|
1025
1034
|
|
|
1026
1035
|
self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
|
|
1027
1036
|
return orchestrator
|
|
1028
1037
|
|
|
1029
1038
|
except Exception as e:
|
|
1030
|
-
log_error(self.logger, "Failed to initialize orchestrator", e, f"{strategy_name}/{
|
|
1031
|
-
self.logger.debug(f"CRITICAL: Failed to create orchestrator for {strategy_name}/{
|
|
1039
|
+
log_error(self.logger, "Failed to initialize orchestrator", e, f"{strategy_name}/{risk_category_name}")
|
|
1040
|
+
self.logger.debug(f"CRITICAL: Failed to create orchestrator for {strategy_name}/{risk_category_name}: {str(e)}")
|
|
1032
1041
|
self.task_statuses[task_key] = TASK_STATUS["FAILED"]
|
|
1033
1042
|
raise
|
|
1034
1043
|
|
|
1044
|
+
async def _multi_turn_orchestrator(
    self,
    chat_target: PromptChatTarget,
    all_prompts: List[str],
    converter: Union[PromptConverter, List[PromptConverter]],
    *,
    strategy_name: str = "unknown",
    risk_category_name: str = "unknown",
    risk_category: Optional[RiskCategory] = None,
    timeout: int = 120,
) -> Optional[Orchestrator]:
    """Send prompts via the RedTeamingOrchestrator, the simplest form of MultiTurnOrchestrator.

    A new RedTeamingOrchestrator is created for every prompt, because its adversarial
    chat target (AzureRAIServiceTarget) is constructed with that prompt as its objective.
    Each attack runs for at most ``max_turns`` turns and is judged by an
    AzureRAIServiceTrueFalseScorer (score feedback to the adversary is disabled).

    All conversations for this strategy/risk pair are tagged with one shared
    ``risk_strategy_path`` memory label, and that same path is recorded as
    ``self.red_team_info[strategy_name][risk_category_name]["data_file"]``.

    A per-prompt timeout or error does not abort the run. It is logged, the
    strategy/risk entry is marked INCOMPLETE, partial outputs are written via
    ``_write_pyrit_outputs_to_file``, and processing moves on to the next prompt.
    A failure while building the scorer/target/orchestrator marks the task FAILED
    and is re-raised.

    :param chat_target: The target to send prompts to
    :type chat_target: PromptChatTarget
    :param all_prompts: Attack objectives; each one gets its own multi-turn conversation
    :type all_prompts: List[str]
    :param converter: Prompt converter or list of converters to transform prompts
    :type converter: Union[PromptConverter, List[PromptConverter]]
    :param strategy_name: Name of the attack strategy being used
    :type strategy_name: str
    :param risk_category_name: Name of the risk category being evaluated
    :type risk_category_name: str
    :param risk_category: Risk category passed to the RAI-service scorer
    :type risk_category: Optional[RiskCategory]
    :param timeout: Timeout in seconds for each prompt's attack
    :type timeout: int
    :return: The orchestrator used for the last processed prompt, or None when
        ``all_prompts`` is empty
    :rtype: Optional[Orchestrator]
    """
    max_turns = 5  # Default cap on conversation turns per objective
    task_key = f"{strategy_name}_{risk_category_name}_orchestrator"
    self.task_statuses[task_key] = TASK_STATUS["RUNNING"]

    log_strategy_start(self.logger, strategy_name, risk_category_name)

    # Normalize the converter argument into a list, dropping None entries.
    converter_list: List[PromptConverter] = []
    if converter and isinstance(converter, PromptConverter):
        converter_list = [converter]
    elif converter and isinstance(converter, list):
        converter_list = [c for c in converter if c is not None]

    if converter_list:
        converter_names = [c.__class__.__name__ for c in converter_list]
        self.logger.debug(f"Using converters: {', '.join(converter_names)}")
    else:
        self.logger.debug("No converters specified")

    # Without this guard an empty prompt list would reach `return orchestrator`
    # with the name never bound (UnboundLocalError).
    if not all_prompts:
        self.logger.warning(f"No prompts provided to orchestrator for {strategy_name}/{risk_category_name}")
        self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
        return None

    # Generate the output path ONCE for the whole strategy/risk pair, as _crescendo_orchestrator
    # does. "data_file" holds a single path, so a per-prompt path would be overwritten on every
    # iteration. Earlier prompts' conversations would then carry a label that is no longer recorded.
    base_path = str(uuid.uuid4())

    # If scan output directory exists, place the file there
    if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
        output_path = os.path.join(self.scan_output_dir, f"{base_path}{DATA_EXT}")
    else:
        output_path = f"{base_path}{DATA_EXT}"

    self.red_team_info[strategy_name][risk_category_name]["data_file"] = output_path

    orchestrator: Optional[Orchestrator] = None
    for prompt_idx, prompt in enumerate(all_prompts):
        prompt_start_time = datetime.now()
        self.logger.debug(f"Processing prompt {prompt_idx+1}/{len(all_prompts)}")
        try:
            azure_rai_service_scorer = AzureRAIServiceTrueFalseScorer(
                client=self.generated_rai_client,
                api_version=None,
                model="gpt-4",
                logger=self.logger,
                credential=self.credential,
                risk_category=risk_category,
                azure_ai_project=self.azure_ai_project,
            )

            # The adversarial target is bound to this prompt's objective, hence one per prompt.
            azure_rai_service_target = AzureRAIServiceTarget(
                client=self.generated_rai_client,
                api_version=None,
                model="gpt-4",
                prompt_template_key="orchestrators/red_teaming/text_generation.yaml",
                objective=prompt,
                logger=self.logger,
                is_one_dp_project=self._one_dp_project,
            )

            orchestrator = RedTeamingOrchestrator(
                objective_target=chat_target,
                adversarial_chat=azure_rai_service_target,
                max_turns=max_turns,
                prompt_converters=converter_list,
                objective_scorer=azure_rai_service_scorer,
                use_score_as_feedback=False,
            )

            # Debug log the first few characters of the current prompt
            self.logger.debug(f"Current prompt (truncated): {prompt[:50]}...")

            try:  # Create retry decorator for this specific call with enhanced retry strategy
                @retry(**self._create_retry_config()["network_retry"])
                async def send_prompt_with_retry():
                    try:
                        return await asyncio.wait_for(
                            orchestrator.run_attack_async(
                                objective=prompt,
                                memory_labels={"risk_strategy_path": output_path, "batch": prompt_idx + 1},
                            ),
                            timeout=timeout,  # Use provided timeouts
                        )
                    except (httpx.ConnectTimeout, httpx.ReadTimeout, httpx.ConnectError, httpx.HTTPError,
                            ConnectionError, TimeoutError, asyncio.TimeoutError, httpcore.ReadTimeout,
                            httpx.HTTPStatusError) as e:
                        # Log the error with enhanced information and allow retry logic to handle it
                        self.logger.warning(f"Network error in prompt {prompt_idx+1} for {strategy_name}/{risk_category_name}: {type(e).__name__}: {str(e)}")
                        # Add a small delay before retry to allow network recovery
                        await asyncio.sleep(1)
                        raise

                # Execute the retry-enabled function; it is awaited within this iteration,
                # so the closure over the loop variables is safe.
                await send_prompt_with_retry()
                prompt_duration = (datetime.now() - prompt_start_time).total_seconds()
                self.logger.debug(f"Successfully processed prompt {prompt_idx+1} for {strategy_name}/{risk_category_name} in {prompt_duration:.2f} seconds")

                # Print progress to console
                if prompt_idx < len(all_prompts) - 1:  # Don't print for the last prompt
                    print(f"Strategy {strategy_name}, Risk {risk_category_name}: Processed prompt {prompt_idx+1}/{len(all_prompts)}")

            except (asyncio.TimeoutError, tenacity.RetryError):
                self.logger.warning(f"Batch {prompt_idx+1} for {strategy_name}/{risk_category_name} timed out after {timeout} seconds, continuing with partial results")
                self.logger.debug(f"Timeout: Strategy {strategy_name}, Risk {risk_category_name}, Batch {prompt_idx+1} after {timeout} seconds.", exc_info=True)
                print(f"⚠️ TIMEOUT: Strategy {strategy_name}, Risk {risk_category_name}, Batch {prompt_idx+1}")
                # Set task status to TIMEOUT
                batch_task_key = f"{strategy_name}_{risk_category_name}_prompt_{prompt_idx+1}"
                self.task_statuses[batch_task_key] = TASK_STATUS["TIMEOUT"]
                self.red_team_info[strategy_name][risk_category_name]["status"] = TASK_STATUS["INCOMPLETE"]
                # Flush partial results and move on to the next prompt
                self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category_name, batch_idx=prompt_idx + 1)
            except Exception as e:
                log_error(self.logger, f"Error processing prompt {prompt_idx+1}", e, f"{strategy_name}/{risk_category_name}")
                self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {risk_category_name}, Prompt {prompt_idx+1}: {str(e)}")
                self.red_team_info[strategy_name][risk_category_name]["status"] = TASK_STATUS["INCOMPLETE"]
                # Flush partial results and continue with the remaining prompts
                self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category_name, batch_idx=prompt_idx + 1)
        except Exception as e:
            # Construction of scorer/target/orchestrator failed: this is fatal for the task.
            log_error(self.logger, "Failed to initialize orchestrator", e, f"{strategy_name}/{risk_category_name}")
            self.logger.debug(f"CRITICAL: Failed to create orchestrator for {strategy_name}/{risk_category_name}: {str(e)}")
            self.task_statuses[task_key] = TASK_STATUS["FAILED"]
            raise

    self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
    return orchestrator
|
|
1201
|
+
|
|
1202
|
+
async def _crescendo_orchestrator(
|
|
1203
|
+
self,
|
|
1204
|
+
chat_target: PromptChatTarget,
|
|
1205
|
+
all_prompts: List[str],
|
|
1206
|
+
converter: Union[PromptConverter, List[PromptConverter]],
|
|
1207
|
+
*,
|
|
1208
|
+
strategy_name: str = "unknown",
|
|
1209
|
+
risk_category_name: str = "unknown",
|
|
1210
|
+
risk_category: Optional[RiskCategory] = None,
|
|
1211
|
+
timeout: int = 120,
|
|
1212
|
+
) -> Orchestrator:
|
|
1213
|
+
"""Send prompts via the CrescendoOrchestrator with optimized performance.
|
|
1214
|
+
|
|
1215
|
+
Creates and configures a PyRIT CrescendoOrchestrator to send prompts to the target
|
|
1216
|
+
model or function. The orchestrator handles prompt conversion using the specified converters,
|
|
1217
|
+
applies appropriate timeout settings, and manages the database engine for storing conversation
|
|
1218
|
+
results. This function provides centralized management for prompt-sending operations with proper
|
|
1219
|
+
error handling and performance optimizations.
|
|
1220
|
+
|
|
1221
|
+
:param chat_target: The target to send prompts to
|
|
1222
|
+
:type chat_target: PromptChatTarget
|
|
1223
|
+
:param all_prompts: List of prompts to process and send
|
|
1224
|
+
:type all_prompts: List[str]
|
|
1225
|
+
:param converter: Prompt converter or list of converters to transform prompts
|
|
1226
|
+
:type converter: Union[PromptConverter, List[PromptConverter]]
|
|
1227
|
+
:param strategy_name: Name of the attack strategy being used
|
|
1228
|
+
:type strategy_name: str
|
|
1229
|
+
:param risk_category: Risk category being evaluated
|
|
1230
|
+
:type risk_category: str
|
|
1231
|
+
:param timeout: Timeout in seconds for each prompt
|
|
1232
|
+
:type timeout: int
|
|
1233
|
+
:return: Configured and initialized orchestrator
|
|
1234
|
+
:rtype: Orchestrator
|
|
1235
|
+
"""
|
|
1236
|
+
max_turns = 10 # Set a default max turns value
|
|
1237
|
+
max_backtracks = 5
|
|
1238
|
+
task_key = f"{strategy_name}_{risk_category_name}_orchestrator"
|
|
1239
|
+
self.task_statuses[task_key] = TASK_STATUS["RUNNING"]
|
|
1240
|
+
|
|
1241
|
+
log_strategy_start(self.logger, strategy_name, risk_category_name)
|
|
1242
|
+
|
|
1243
|
+
# Initialize output path for memory labelling
|
|
1244
|
+
base_path = str(uuid.uuid4())
|
|
1245
|
+
|
|
1246
|
+
# If scan output directory exists, place the file there
|
|
1247
|
+
if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
|
|
1248
|
+
output_path = os.path.join(self.scan_output_dir, f"{base_path}{DATA_EXT}")
|
|
1249
|
+
else:
|
|
1250
|
+
output_path = f"{base_path}{DATA_EXT}"
|
|
1251
|
+
|
|
1252
|
+
self.red_team_info[strategy_name][risk_category_name]["data_file"] = output_path
|
|
1253
|
+
|
|
1254
|
+
for prompt_idx, prompt in enumerate(all_prompts):
|
|
1255
|
+
prompt_start_time = datetime.now()
|
|
1256
|
+
self.logger.debug(f"Processing prompt {prompt_idx+1}/{len(all_prompts)}")
|
|
1257
|
+
try:
|
|
1258
|
+
red_llm_scoring_target = RAIServiceEvalChatTarget(
|
|
1259
|
+
logger=self.logger,
|
|
1260
|
+
credential=self.credential,
|
|
1261
|
+
risk_category=risk_category,
|
|
1262
|
+
azure_ai_project=self.azure_ai_project,
|
|
1263
|
+
)
|
|
1264
|
+
|
|
1265
|
+
azure_rai_service_target = AzureRAIServiceTarget(
|
|
1266
|
+
client=self.generated_rai_client,
|
|
1267
|
+
api_version=None,
|
|
1268
|
+
model="gpt-4",
|
|
1269
|
+
prompt_template_key="orchestrators/crescendo/crescendo_variant_1.yaml",
|
|
1270
|
+
objective=prompt,
|
|
1271
|
+
logger=self.logger,
|
|
1272
|
+
crescendo_format=True,
|
|
1273
|
+
is_one_dp_project=self._one_dp_project,
|
|
1274
|
+
)
|
|
1275
|
+
|
|
1276
|
+
orchestrator = CrescendoOrchestrator(
|
|
1277
|
+
objective_target=chat_target,
|
|
1278
|
+
adversarial_chat=azure_rai_service_target,
|
|
1279
|
+
max_turns=max_turns,
|
|
1280
|
+
scoring_target=red_llm_scoring_target,
|
|
1281
|
+
max_backtracks=max_backtracks,
|
|
1282
|
+
)
|
|
1283
|
+
|
|
1284
|
+
orchestrator._objective_scorer = AzureRAIServiceTrueFalseScorer(
|
|
1285
|
+
client=self.generated_rai_client,
|
|
1286
|
+
api_version=None,
|
|
1287
|
+
model="gpt-4",
|
|
1288
|
+
# objective=prompt,
|
|
1289
|
+
logger=self.logger,
|
|
1290
|
+
credential=self.credential,
|
|
1291
|
+
risk_category=risk_category,
|
|
1292
|
+
azure_ai_project=self.azure_ai_project,
|
|
1293
|
+
)
|
|
1294
|
+
|
|
1295
|
+
# Debug log the first few characters of the current prompt
|
|
1296
|
+
self.logger.debug(f"Current prompt (truncated): {prompt[:50]}...")
|
|
1297
|
+
|
|
1298
|
+
try: # Create retry decorator for this specific call with enhanced retry strategy
|
|
1299
|
+
@retry(**self._create_retry_config()["network_retry"])
|
|
1300
|
+
async def send_prompt_with_retry():
|
|
1301
|
+
try:
|
|
1302
|
+
return await asyncio.wait_for(
|
|
1303
|
+
orchestrator.run_attack_async(objective=prompt, memory_labels={"risk_strategy_path": output_path, "batch": prompt_idx+1}),
|
|
1304
|
+
timeout=timeout # Use provided timeouts
|
|
1305
|
+
)
|
|
1306
|
+
except (httpx.ConnectTimeout, httpx.ReadTimeout, httpx.ConnectError, httpx.HTTPError,
|
|
1307
|
+
ConnectionError, TimeoutError, asyncio.TimeoutError, httpcore.ReadTimeout,
|
|
1308
|
+
httpx.HTTPStatusError) as e:
|
|
1309
|
+
# Log the error with enhanced information and allow retry logic to handle it
|
|
1310
|
+
self.logger.warning(f"Network error in prompt {prompt_idx+1} for {strategy_name}/{risk_category_name}: {type(e).__name__}: {str(e)}")
|
|
1311
|
+
# Add a small delay before retry to allow network recovery
|
|
1312
|
+
await asyncio.sleep(1)
|
|
1313
|
+
raise
|
|
1314
|
+
|
|
1315
|
+
# Execute the retry-enabled function
|
|
1316
|
+
await send_prompt_with_retry()
|
|
1317
|
+
prompt_duration = (datetime.now() - prompt_start_time).total_seconds()
|
|
1318
|
+
self.logger.debug(f"Successfully processed prompt {prompt_idx+1} for {strategy_name}/{risk_category_name} in {prompt_duration:.2f} seconds")
|
|
1319
|
+
|
|
1320
|
+
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category_name, batch_idx=prompt_idx+1)
|
|
1321
|
+
|
|
1322
|
+
# Print progress to console
|
|
1323
|
+
if prompt_idx < len(all_prompts) - 1: # Don't print for the last prompt
|
|
1324
|
+
print(f"Strategy {strategy_name}, Risk {risk_category_name}: Processed prompt {prompt_idx+1}/{len(all_prompts)}")
|
|
1325
|
+
|
|
1326
|
+
except (asyncio.TimeoutError, tenacity.RetryError):
|
|
1327
|
+
self.logger.warning(f"Batch {prompt_idx+1} for {strategy_name}/{risk_category_name} timed out after {timeout} seconds, continuing with partial results")
|
|
1328
|
+
self.logger.debug(f"Timeout: Strategy {strategy_name}, Risk {risk_category_name}, Batch {prompt_idx+1} after {timeout} seconds.", exc_info=True)
|
|
1329
|
+
print(f"⚠️ TIMEOUT: Strategy {strategy_name}, Risk {risk_category_name}, Batch {prompt_idx+1}")
|
|
1330
|
+
# Set task status to TIMEOUT
|
|
1331
|
+
batch_task_key = f"{strategy_name}_{risk_category_name}_prompt_{prompt_idx+1}"
|
|
1332
|
+
self.task_statuses[batch_task_key] = TASK_STATUS["TIMEOUT"]
|
|
1333
|
+
self.red_team_info[strategy_name][risk_category_name]["status"] = TASK_STATUS["INCOMPLETE"]
|
|
1334
|
+
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category_name, batch_idx=prompt_idx+1)
|
|
1335
|
+
# Continue with partial results rather than failing completely
|
|
1336
|
+
continue
|
|
1337
|
+
except Exception as e:
|
|
1338
|
+
log_error(self.logger, f"Error processing prompt {prompt_idx+1}", e, f"{strategy_name}/{risk_category_name}")
|
|
1339
|
+
self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {risk_category_name}, Prompt {prompt_idx+1}: {str(e)}")
|
|
1340
|
+
self.red_team_info[strategy_name][risk_category_name]["status"] = TASK_STATUS["INCOMPLETE"]
|
|
1341
|
+
self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category_name, batch_idx=prompt_idx+1)
|
|
1342
|
+
# Continue with other batches even if one fails
|
|
1343
|
+
continue
|
|
1344
|
+
except Exception as e:
|
|
1345
|
+
log_error(self.logger, "Failed to initialize orchestrator", e, f"{strategy_name}/{risk_category_name}")
|
|
1346
|
+
self.logger.debug(f"CRITICAL: Failed to create orchestrator for {strategy_name}/{risk_category_name}: {str(e)}")
|
|
1347
|
+
self.task_statuses[task_key] = TASK_STATUS["FAILED"]
|
|
1348
|
+
raise
|
|
1349
|
+
self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
|
|
1350
|
+
return orchestrator
|
|
1351
|
+
|
|
1035
1352
|
def _write_pyrit_outputs_to_file(self,*, orchestrator: Orchestrator, strategy_name: str, risk_category: str, batch_idx: Optional[int] = None) -> str:
|
|
1036
1353
|
"""Write PyRIT outputs to a file with a name based on orchestrator, strategy, and risk category.
|
|
1037
1354
|
|
|
@@ -1074,6 +1391,9 @@ class RedTeam:
|
|
|
1074
1391
|
#Convert to json lines
|
|
1075
1392
|
json_lines = ""
|
|
1076
1393
|
for conversation in conversations: # each conversation is a List[ChatMessage]
|
|
1394
|
+
if conversation[0].role == "system":
|
|
1395
|
+
# Skip system messages in the output
|
|
1396
|
+
continue
|
|
1077
1397
|
json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"
|
|
1078
1398
|
with Path(output_path).open("w") as f:
|
|
1079
1399
|
f.writelines(json_lines)
|
|
@@ -1087,7 +1407,11 @@ class RedTeam:
|
|
|
1087
1407
|
self.logger.debug(f"Creating new file: {output_path}")
|
|
1088
1408
|
#Convert to json lines
|
|
1089
1409
|
json_lines = ""
|
|
1410
|
+
|
|
1090
1411
|
for conversation in conversations: # each conversation is a List[ChatMessage]
|
|
1412
|
+
if conversation[0].role == "system":
|
|
1413
|
+
# Skip system messages in the output
|
|
1414
|
+
continue
|
|
1091
1415
|
json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"
|
|
1092
1416
|
with Path(output_path).open("w") as f:
|
|
1093
1417
|
f.writelines(json_lines)
|
|
@@ -1111,32 +1435,31 @@ class RedTeam:
|
|
|
1111
1435
|
from ._utils.strategy_utils import get_chat_target
|
|
1112
1436
|
return get_chat_target(target)
|
|
1113
1437
|
|
|
1438
|
+
|
|
1114
1439
|
# Replace with utility function
|
|
1115
|
-
def
|
|
1116
|
-
"""Get appropriate orchestrator functions for the specified attack
|
|
1440
|
+
def _get_orchestrator_for_attack_strategy(self, attack_strategy: Union[AttackStrategy, List[AttackStrategy]]) -> Callable:
|
|
1441
|
+
"""Get appropriate orchestrator functions for the specified attack strategy.
|
|
1117
1442
|
|
|
1118
|
-
Determines which orchestrator functions should be used based on the attack strategies.
|
|
1443
|
+
Determines which orchestrator functions should be used based on the attack strategies, max turns.
|
|
1119
1444
|
Returns a list of callable functions that can create orchestrators configured for the
|
|
1120
1445
|
specified strategies. This function is crucial for mapping strategies to the appropriate
|
|
1121
1446
|
execution environment.
|
|
1122
1447
|
|
|
1123
1448
|
:param attack_strategy: List of attack strategies to get orchestrators for
|
|
1124
|
-
:type attack_strategy:
|
|
1449
|
+
:type attack_strategy: Union[AttackStrategy, List[AttackStrategy]]
|
|
1125
1450
|
:return: List of callable functions that create appropriately configured orchestrators
|
|
1126
1451
|
:rtype: List[Callable]
|
|
1127
1452
|
"""
|
|
1128
1453
|
# We need to modify this to use our actual _prompt_sending_orchestrator since the utility function can't access it
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
elif AttackStrategy.
|
|
1134
|
-
|
|
1135
|
-
elif AttackStrategy.
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
call_to_orchestrators.extend([self._prompt_sending_orchestrator])
|
|
1139
|
-
return call_to_orchestrators
|
|
1454
|
+
if isinstance(attack_strategy, list):
|
|
1455
|
+
if AttackStrategy.MultiTurn in attack_strategy or AttackStrategy.Crescendo in attack_strategy:
|
|
1456
|
+
self.logger.error("MultiTurn and Crescendo strategies are not supported in composed attacks.")
|
|
1457
|
+
raise ValueError("MultiTurn and Crescendo strategies are not supported in composed attacks.")
|
|
1458
|
+
elif AttackStrategy.MultiTurn == attack_strategy:
|
|
1459
|
+
return self._multi_turn_orchestrator
|
|
1460
|
+
elif AttackStrategy.Crescendo == attack_strategy:
|
|
1461
|
+
return self._crescendo_orchestrator
|
|
1462
|
+
return self._prompt_sending_orchestrator
|
|
1140
1463
|
|
|
1141
1464
|
# Replace with utility function
|
|
1142
1465
|
def _get_attack_success(self, result: str) -> bool:
|
|
@@ -1678,11 +2001,6 @@ class RedTeam:
|
|
|
1678
2001
|
metric_name = get_metric_from_risk_category(risk_category)
|
|
1679
2002
|
self.logger.debug(f"Using metric '{metric_name}' for risk category '{risk_category.value}'")
|
|
1680
2003
|
|
|
1681
|
-
# Convert hate_unfairness to hate_fairness for RAI service compatibility
|
|
1682
|
-
if metric_name == "hate_unfairness":
|
|
1683
|
-
metric_name = "hate_fairness"
|
|
1684
|
-
self.logger.debug(f"Converted metric name to '{metric_name}' for compatibility with RAI service")
|
|
1685
|
-
|
|
1686
2004
|
# Load all conversations from the data file
|
|
1687
2005
|
conversations = []
|
|
1688
2006
|
try:
|
|
@@ -1736,8 +2054,6 @@ class RedTeam:
|
|
|
1736
2054
|
|
|
1737
2055
|
async def _process_attack(
|
|
1738
2056
|
self,
|
|
1739
|
-
target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
|
|
1740
|
-
call_orchestrator: Callable,
|
|
1741
2057
|
strategy: Union[AttackStrategy, List[AttackStrategy]],
|
|
1742
2058
|
risk_category: RiskCategory,
|
|
1743
2059
|
all_prompts: List[str],
|
|
@@ -1756,10 +2072,6 @@ class RedTeam:
|
|
|
1756
2072
|
appropriate converter, saving results to files, and optionally evaluating the results.
|
|
1757
2073
|
The function handles progress tracking, logging, and error handling throughout the process.
|
|
1758
2074
|
|
|
1759
|
-
:param target: The target model or function to scan
|
|
1760
|
-
:type target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration, PromptChatTarget]
|
|
1761
|
-
:param call_orchestrator: Function to call to create an orchestrator
|
|
1762
|
-
:type call_orchestrator: Callable
|
|
1763
2075
|
:param strategy: The attack strategy to use
|
|
1764
2076
|
:type strategy: Union[AttackStrategy, List[AttackStrategy]]
|
|
1765
2077
|
:param risk_category: The risk category to evaluate
|
|
@@ -1793,9 +2105,10 @@ class RedTeam:
|
|
|
1793
2105
|
log_strategy_start(self.logger, strategy_name, risk_category.value)
|
|
1794
2106
|
|
|
1795
2107
|
converter = self._get_converter_for_strategy(strategy)
|
|
2108
|
+
call_orchestrator = self._get_orchestrator_for_attack_strategy(strategy)
|
|
1796
2109
|
try:
|
|
1797
2110
|
self.logger.debug(f"Calling orchestrator for {strategy_name} strategy")
|
|
1798
|
-
orchestrator = await call_orchestrator(self.chat_target, all_prompts, converter, strategy_name, risk_category.value, timeout)
|
|
2111
|
+
orchestrator = await call_orchestrator(chat_target=self.chat_target, all_prompts=all_prompts, converter=converter, strategy_name=strategy_name, risk_category=risk_category, risk_category_name=risk_category.value, timeout=timeout)
|
|
1799
2112
|
except PyritException as e:
|
|
1800
2113
|
log_error(self.logger, f"Error calling orchestrator for {strategy_name} strategy", e)
|
|
1801
2114
|
self.logger.debug(f"Orchestrator error for {strategy_name}/{risk_category.value}: {str(e)}")
|
|
@@ -1869,7 +2182,6 @@ class RedTeam:
|
|
|
1869
2182
|
target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration, PromptChatTarget],
|
|
1870
2183
|
*,
|
|
1871
2184
|
scan_name: Optional[str] = None,
|
|
1872
|
-
num_turns : int = 1,
|
|
1873
2185
|
attack_strategies: List[Union[AttackStrategy, List[AttackStrategy]]] = [],
|
|
1874
2186
|
skip_upload: bool = False,
|
|
1875
2187
|
output_path: Optional[Union[str, os.PathLike]] = None,
|
|
@@ -1886,8 +2198,6 @@ class RedTeam:
|
|
|
1886
2198
|
:type target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration, PromptChatTarget]
|
|
1887
2199
|
:param scan_name: Optional name for the evaluation
|
|
1888
2200
|
:type scan_name: Optional[str]
|
|
1889
|
-
:param num_turns: Number of conversation turns to use in the scan
|
|
1890
|
-
:type num_turns: int
|
|
1891
2201
|
:param attack_strategies: List of attack strategies to use
|
|
1892
2202
|
:type attack_strategies: List[Union[AttackStrategy, List[AttackStrategy]]]
|
|
1893
2203
|
:param skip_upload: Flag to determine if the scan results should be uploaded
|
|
@@ -2057,15 +2367,17 @@ class RedTeam:
|
|
|
2057
2367
|
flattened_attack_strategies = self._get_flattened_attack_strategies(attack_strategies)
|
|
2058
2368
|
self.logger.info(f"Using {len(flattened_attack_strategies)} attack strategies")
|
|
2059
2369
|
self.logger.info(f"Found {len(flattened_attack_strategies)} attack strategies")
|
|
2060
|
-
|
|
2061
|
-
|
|
2062
|
-
|
|
2063
|
-
|
|
2064
|
-
|
|
2065
|
-
|
|
2370
|
+
|
|
2371
|
+
if len(flattened_attack_strategies) > 2 and (AttackStrategy.MultiTurn in flattened_attack_strategies or AttackStrategy.Crescendo in flattened_attack_strategies):
|
|
2372
|
+
self.logger.warning("MultiTurn and Crescendo strategies are not compatible with multiple attack strategies.")
|
|
2373
|
+
print("⚠️ Warning: MultiTurn and Crescendo strategies are not compatible with multiple attack strategies.")
|
|
2374
|
+
raise ValueError("MultiTurn and Crescendo strategies are not compatible with multiple attack strategies.")
|
|
2375
|
+
|
|
2376
|
+
# Calculate total tasks: #risk_categories * #converters
|
|
2377
|
+
self.total_tasks = len(self.risk_categories) * len(flattened_attack_strategies)
|
|
2066
2378
|
# Show task count for user awareness
|
|
2067
2379
|
print(f"📋 Planning {self.total_tasks} total tasks")
|
|
2068
|
-
self.logger.info(f"Total tasks: {self.total_tasks} ({len(self.risk_categories)} risk categories * {len(flattened_attack_strategies)} strategies
|
|
2380
|
+
self.logger.info(f"Total tasks: {self.total_tasks} ({len(self.risk_categories)} risk categories * {len(flattened_attack_strategies)} strategies)")
|
|
2069
2381
|
|
|
2070
2382
|
# Initialize our tracking dictionary early with empty structures
|
|
2071
2383
|
# This ensures we have a place to store results even if tasks fail
|
|
@@ -2151,9 +2463,9 @@ class RedTeam:
|
|
|
2151
2463
|
|
|
2152
2464
|
# Create all tasks for parallel processing
|
|
2153
2465
|
orchestrator_tasks = []
|
|
2154
|
-
combinations = list(itertools.product(
|
|
2466
|
+
combinations = list(itertools.product(flattened_attack_strategies, self.risk_categories))
|
|
2155
2467
|
|
|
2156
|
-
for combo_idx, (
|
|
2468
|
+
for combo_idx, (strategy, risk_category) in enumerate(combinations):
|
|
2157
2469
|
strategy_name = self._get_strategy_name(strategy)
|
|
2158
2470
|
objectives = all_objectives[strategy_name][risk_category.value]
|
|
2159
2471
|
|
|
@@ -2165,12 +2477,10 @@ class RedTeam:
|
|
|
2165
2477
|
progress_bar.update(1)
|
|
2166
2478
|
continue
|
|
2167
2479
|
|
|
2168
|
-
self.logger.debug(f"[{combo_idx+1}/{len(combinations)}] Creating task: {
|
|
2480
|
+
self.logger.debug(f"[{combo_idx+1}/{len(combinations)}] Creating task: {strategy_name} + {risk_category.value}")
|
|
2169
2481
|
|
|
2170
2482
|
orchestrator_tasks.append(
|
|
2171
2483
|
self._process_attack(
|
|
2172
|
-
target=target,
|
|
2173
|
-
call_orchestrator=call_orchestrator,
|
|
2174
2484
|
all_prompts=objectives,
|
|
2175
2485
|
strategy=strategy,
|
|
2176
2486
|
progress_bar=progress_bar,
|