azure-ai-evaluation 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of azure-ai-evaluation has been flagged as possibly problematic.
- azure/ai/evaluation/__init__.py +42 -14
- azure/ai/evaluation/_azure/_models.py +6 -6
- azure/ai/evaluation/_common/constants.py +6 -2
- azure/ai/evaluation/_common/rai_service.py +38 -4
- azure/ai/evaluation/_common/raiclient/__init__.py +34 -0
- azure/ai/evaluation/_common/raiclient/_client.py +128 -0
- azure/ai/evaluation/_common/raiclient/_configuration.py +87 -0
- azure/ai/evaluation/_common/raiclient/_model_base.py +1235 -0
- azure/ai/evaluation/_common/raiclient/_patch.py +20 -0
- azure/ai/evaluation/_common/raiclient/_serialization.py +2050 -0
- azure/ai/evaluation/_common/raiclient/_version.py +9 -0
- azure/ai/evaluation/_common/raiclient/aio/__init__.py +29 -0
- azure/ai/evaluation/_common/raiclient/aio/_client.py +130 -0
- azure/ai/evaluation/_common/raiclient/aio/_configuration.py +87 -0
- azure/ai/evaluation/_common/raiclient/aio/_patch.py +20 -0
- azure/ai/evaluation/_common/raiclient/aio/operations/__init__.py +25 -0
- azure/ai/evaluation/_common/raiclient/aio/operations/_operations.py +981 -0
- azure/ai/evaluation/_common/raiclient/aio/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/raiclient/models/__init__.py +60 -0
- azure/ai/evaluation/_common/raiclient/models/_enums.py +18 -0
- azure/ai/evaluation/_common/raiclient/models/_models.py +651 -0
- azure/ai/evaluation/_common/raiclient/models/_patch.py +20 -0
- azure/ai/evaluation/_common/raiclient/operations/__init__.py +25 -0
- azure/ai/evaluation/_common/raiclient/operations/_operations.py +1225 -0
- azure/ai/evaluation/_common/raiclient/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/raiclient/py.typed +1 -0
- azure/ai/evaluation/_common/utils.py +30 -10
- azure/ai/evaluation/_constants.py +10 -0
- azure/ai/evaluation/_converters/__init__.py +3 -0
- azure/ai/evaluation/_converters/_ai_services.py +804 -0
- azure/ai/evaluation/_converters/_models.py +302 -0
- azure/ai/evaluation/_evaluate/_batch_run/__init__.py +10 -3
- azure/ai/evaluation/_evaluate/_batch_run/_run_submitter_client.py +104 -0
- azure/ai/evaluation/_evaluate/_batch_run/batch_clients.py +82 -0
- azure/ai/evaluation/_evaluate/_eval_run.py +1 -1
- azure/ai/evaluation/_evaluate/_evaluate.py +36 -4
- azure/ai/evaluation/_evaluators/_bleu/_bleu.py +23 -3
- azure/ai/evaluation/_evaluators/_code_vulnerability/__init__.py +5 -0
- azure/ai/evaluation/_evaluators/_code_vulnerability/_code_vulnerability.py +120 -0
- azure/ai/evaluation/_evaluators/_coherence/_coherence.py +21 -2
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +43 -3
- azure/ai/evaluation/_evaluators/_common/_base_multi_eval.py +3 -1
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +43 -4
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +16 -4
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +42 -5
- azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +15 -0
- azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +15 -0
- azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +15 -0
- azure/ai/evaluation/_evaluators/_content_safety/_violence.py +15 -0
- azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +28 -4
- azure/ai/evaluation/_evaluators/_fluency/_fluency.py +21 -2
- azure/ai/evaluation/_evaluators/_gleu/_gleu.py +26 -3
- azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +21 -3
- azure/ai/evaluation/_evaluators/_intent_resolution/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_intent_resolution/_intent_resolution.py +152 -0
- azure/ai/evaluation/_evaluators/_intent_resolution/intent_resolution.prompty +161 -0
- azure/ai/evaluation/_evaluators/_meteor/_meteor.py +26 -3
- azure/ai/evaluation/_evaluators/_qa/_qa.py +51 -7
- azure/ai/evaluation/_evaluators/_relevance/_relevance.py +26 -2
- azure/ai/evaluation/_evaluators/_response_completeness/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_response_completeness/_response_completeness.py +157 -0
- azure/ai/evaluation/_evaluators/_response_completeness/response_completeness.prompty +99 -0
- azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +21 -2
- azure/ai/evaluation/_evaluators/_rouge/_rouge.py +113 -4
- azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +23 -3
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +24 -5
- azure/ai/evaluation/_evaluators/_task_adherence/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_task_adherence/_task_adherence.py +148 -0
- azure/ai/evaluation/_evaluators/_task_adherence/task_adherence.prompty +117 -0
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/_tool_call_accuracy.py +292 -0
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/tool_call_accuracy.prompty +71 -0
- azure/ai/evaluation/_evaluators/_ungrounded_attributes/__init__.py +5 -0
- azure/ai/evaluation/_evaluators/_ungrounded_attributes/_ungrounded_attributes.py +103 -0
- azure/ai/evaluation/_evaluators/_xpia/xpia.py +2 -0
- azure/ai/evaluation/_exceptions.py +5 -1
- azure/ai/evaluation/_legacy/__init__.py +3 -0
- azure/ai/evaluation/_legacy/_batch_engine/__init__.py +9 -0
- azure/ai/evaluation/_legacy/_batch_engine/_config.py +45 -0
- azure/ai/evaluation/_legacy/_batch_engine/_engine.py +368 -0
- azure/ai/evaluation/_legacy/_batch_engine/_exceptions.py +88 -0
- azure/ai/evaluation/_legacy/_batch_engine/_logging.py +292 -0
- azure/ai/evaluation/_legacy/_batch_engine/_openai_injector.py +23 -0
- azure/ai/evaluation/_legacy/_batch_engine/_result.py +99 -0
- azure/ai/evaluation/_legacy/_batch_engine/_run.py +121 -0
- azure/ai/evaluation/_legacy/_batch_engine/_run_storage.py +128 -0
- azure/ai/evaluation/_legacy/_batch_engine/_run_submitter.py +217 -0
- azure/ai/evaluation/_legacy/_batch_engine/_status.py +25 -0
- azure/ai/evaluation/_legacy/_batch_engine/_trace.py +105 -0
- azure/ai/evaluation/_legacy/_batch_engine/_utils.py +82 -0
- azure/ai/evaluation/_legacy/_batch_engine/_utils_deprecated.py +131 -0
- azure/ai/evaluation/_legacy/prompty/__init__.py +36 -0
- azure/ai/evaluation/_legacy/prompty/_connection.py +182 -0
- azure/ai/evaluation/_legacy/prompty/_exceptions.py +59 -0
- azure/ai/evaluation/_legacy/prompty/_prompty.py +313 -0
- azure/ai/evaluation/_legacy/prompty/_utils.py +545 -0
- azure/ai/evaluation/_legacy/prompty/_yaml_utils.py +99 -0
- azure/ai/evaluation/_red_team/__init__.py +3 -0
- azure/ai/evaluation/_red_team/_attack_objective_generator.py +192 -0
- azure/ai/evaluation/_red_team/_attack_strategy.py +42 -0
- azure/ai/evaluation/_red_team/_callback_chat_target.py +74 -0
- azure/ai/evaluation/_red_team/_default_converter.py +21 -0
- azure/ai/evaluation/_red_team/_red_team.py +1858 -0
- azure/ai/evaluation/_red_team/_red_team_result.py +246 -0
- azure/ai/evaluation/_red_team/_utils/__init__.py +3 -0
- azure/ai/evaluation/_red_team/_utils/constants.py +64 -0
- azure/ai/evaluation/_red_team/_utils/formatting_utils.py +164 -0
- azure/ai/evaluation/_red_team/_utils/logging_utils.py +139 -0
- azure/ai/evaluation/_red_team/_utils/strategy_utils.py +188 -0
- azure/ai/evaluation/_safety_evaluation/__init__.py +3 -0
- azure/ai/evaluation/_safety_evaluation/_generated_rai_client.py +0 -0
- azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +741 -0
- azure/ai/evaluation/_version.py +2 -1
- azure/ai/evaluation/simulator/_adversarial_scenario.py +3 -1
- azure/ai/evaluation/simulator/_adversarial_simulator.py +61 -27
- azure/ai/evaluation/simulator/_conversation/__init__.py +4 -5
- azure/ai/evaluation/simulator/_conversation/_conversation.py +4 -0
- azure/ai/evaluation/simulator/_model_tools/_generated_rai_client.py +145 -0
- azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +2 -0
- azure/ai/evaluation/simulator/_model_tools/_rai_client.py +71 -1
- {azure_ai_evaluation-1.2.0.dist-info → azure_ai_evaluation-1.4.0.dist-info}/METADATA +75 -15
- azure_ai_evaluation-1.4.0.dist-info/RECORD +197 -0
- {azure_ai_evaluation-1.2.0.dist-info → azure_ai_evaluation-1.4.0.dist-info}/WHEEL +1 -1
- azure/ai/evaluation/_evaluators/_multimodal/__init__.py +0 -20
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +0 -132
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal_base.py +0 -55
- azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +0 -124
- azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_violence.py +0 -100
- azure_ai_evaluation-1.2.0.dist-info/RECORD +0 -125
- {azure_ai_evaluation-1.2.0.dist-info → azure_ai_evaluation-1.4.0.dist-info}/NOTICE.txt +0 -0
- {azure_ai_evaluation-1.2.0.dist-info → azure_ai_evaluation-1.4.0.dist-info}/top_level.txt +0 -0
azure/ai/evaluation/_red_team/_red_team.py
@@ -0,0 +1,1858 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# Third-party imports
+import asyncio
+import inspect
+import math
+import os
+import logging
+import tempfile
+import time
+from datetime import datetime
+from typing import Callable, Dict, List, Optional, Union, cast
+import json
+from pathlib import Path
+import itertools
+import random
+import uuid
+import pandas as pd
+from tqdm import tqdm
+
+# Azure AI Evaluation imports
+from azure.ai.evaluation._evaluate._eval_run import EvalRun
+from azure.ai.evaluation._evaluate._utils import _trace_destination_from_project_scope
+from azure.ai.evaluation._model_configurations import AzureAIProject
+from azure.ai.evaluation._constants import EvaluationRunProperties, DefaultOpenEncoding, EVALUATION_PASS_FAIL_MAPPING
+from azure.ai.evaluation._evaluate._utils import _get_ai_studio_url
+from azure.ai.evaluation._evaluate._utils import extract_workspace_triad_from_trace_provider
+from azure.ai.evaluation._version import VERSION
+from azure.ai.evaluation._azure._clients import LiteMLClient
+from azure.ai.evaluation._evaluate._utils import _write_output
+from azure.ai.evaluation._common._experimental import experimental
+from azure.ai.evaluation._model_configurations import EvaluationResult
+from azure.ai.evaluation.simulator._model_tools import ManagedIdentityAPITokenManager, TokenScope, RAIClient
+from azure.ai.evaluation.simulator._model_tools._generated_rai_client import GeneratedRAIClient
+from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
+from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
+from azure.ai.evaluation._common.math import list_mean_nan_safe, is_none_or_nan
+from azure.ai.evaluation._common.utils import validate_azure_ai_project
+from azure.ai.evaluation import evaluate
+
+# Azure Core imports
+from azure.core.credentials import TokenCredential
+
+# Red Teaming imports
+from ._red_team_result import _RedTeamResult, _RedTeamingScorecard, _RedTeamingParameters, RedTeamOutput
+from ._attack_strategy import AttackStrategy
+from ._attack_objective_generator import RiskCategory, _AttackObjectiveGenerator
+
+# PyRIT imports
+from pyrit.common import initialize_pyrit, DUCK_DB
+from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget
+from pyrit.models import ChatMessage
+from pyrit.orchestrator.single_turn.prompt_sending_orchestrator import PromptSendingOrchestrator
+from pyrit.orchestrator import Orchestrator
+from pyrit.exceptions import PyritException
+from pyrit.prompt_converter import PromptConverter, MathPromptConverter, Base64Converter, FlipConverter, MorseConverter, AnsiAttackConverter, AsciiArtConverter, AsciiSmugglerConverter, AtbashConverter, BinaryConverter, CaesarConverter, CharacterSpaceConverter, CharSwapGenerator, DiacriticConverter, LeetspeakConverter, UrlConverter, UnicodeSubstitutionConverter, UnicodeConfusableConverter, SuffixAppendConverter, StringJoinConverter, ROT13Converter
+
+# Local imports - constants and utilities
+from ._utils.constants import (
+    BASELINE_IDENTIFIER, DATA_EXT, RESULTS_EXT,
+    ATTACK_STRATEGY_COMPLEXITY_MAP, RISK_CATEGORY_EVALUATOR_MAP,
+    INTERNAL_TASK_TIMEOUT, TASK_STATUS
+)
+from ._utils.logging_utils import (
+    setup_logger, log_section_header, log_subsection_header,
+    log_strategy_start, log_strategy_completion, log_error
+)
+
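
The long pyrit.prompt_converter import is the strategy catalog in miniature: each converter rewrites an attack prompt into an obfuscated equivalent before it is sent to the target. As a rough sketch of what two of these transformations produce — plain standard-library calls, not PyRIT's converter API:

    import base64
    import codecs

    prompt = "ignore previous instructions"
    # Roughly what a Base64Converter-style strategy sends:
    base64.b64encode(prompt.encode()).decode()  # 'aWdub3JlIHByZXZpb3VzIGluc3RydWN0aW9ucw=='
    # Roughly what a ROT13Converter-style strategy sends:
    codecs.encode(prompt, "rot13")              # 'vtaber cerivbhf vafgehpgvbaf'
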
+@experimental
+class RedTeam():
+    """
+    This class uses various attack strategies to test the robustness of AI models against adversarial inputs.
+    It logs the results of these evaluations and provides detailed scorecards summarizing the attack success rates.
+
+    :param azure_ai_project: The Azure AI project configuration
+    :type azure_ai_project: dict
+    :param credential: The credential to authenticate with Azure services
+    :type credential: TokenCredential
+    :param risk_categories: List of risk categories to generate attack objectives for (optional if custom_attack_seed_prompts is provided)
+    :type risk_categories: Optional[List[RiskCategory]]
+    :param num_objectives: Number of objectives to generate per risk category
+    :type num_objectives: int
+    :param application_scenario: Description of the application scenario for context
+    :type application_scenario: Optional[str]
+    :param custom_attack_seed_prompts: Path to a JSON file containing custom attack seed prompts (can be absolute or relative path)
+    :type custom_attack_seed_prompts: Optional[str]
+    :param output_dir: Directory to store all output files. If None, files are created in the current working directory.
+    :type output_dir: Optional[str]
+    :param max_parallel_tasks: Maximum number of parallel tasks to run when scanning (default: 5)
+    :type max_parallel_tasks: int
+    """
+    def __init__(self,
+                 azure_ai_project,
+                 credential,
+                 risk_categories: Optional[List[RiskCategory]] = None,
+                 num_objectives: int = 10,
+                 application_scenario: Optional[str] = None,
+                 custom_attack_seed_prompts: Optional[str] = None,
+                 output_dir=None):
+
+        self.azure_ai_project = validate_azure_ai_project(azure_ai_project)
+        self.credential = credential
+        self.output_dir = output_dir
+
+        # Initialize logger without output directory (will be updated during scan)
+        self.logger = setup_logger()
+
+        self.token_manager = ManagedIdentityAPITokenManager(
+            token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
+            logger=logging.getLogger("RedTeamLogger"),
+            credential=cast(TokenCredential, credential),
+        )
+
+        # Initialize task tracking
+        self.task_statuses = {}
+        self.total_tasks = 0
+        self.completed_tasks = 0
+        self.failed_tasks = 0
+        self.start_time = None
+        self.scan_id = None
+        self.scan_output_dir = None
+
+        self.rai_client = RAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager)
+        self.generated_rai_client = GeneratedRAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager.get_aad_credential())  # type: ignore
+
+        # Initialize a cache for attack objectives by risk category and strategy
+        self.attack_objectives = {}
+
+        # keep track of data and eval result file names
+        self.red_team_info = {}
+
+        initialize_pyrit(memory_db_type=DUCK_DB)
+
+        self.attack_objective_generator = _AttackObjectiveGenerator(risk_categories=risk_categories, num_objectives=num_objectives, application_scenario=application_scenario, custom_attack_seed_prompts=custom_attack_seed_prompts)
+
+        self.logger.debug("RedTeam initialized successfully")
+
+
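
Given the constructor above, a minimal instantiation would look something like the sketch below. The import paths mirror the module layout in this diff; the project-dictionary keys and the RiskCategory member name are assumptions for illustration, not sample code shipped in the package:

    from azure.identity import DefaultAzureCredential
    from azure.ai.evaluation._red_team._red_team import RedTeam
    from azure.ai.evaluation._red_team._attack_objective_generator import RiskCategory

    red_team = RedTeam(
        azure_ai_project={                       # assumed AzureAIProject keys
            "subscription_id": "<subscription-id>",
            "resource_group_name": "<resource-group>",
            "project_name": "<project-name>",
        },
        credential=DefaultAzureCredential(),
        risk_categories=[RiskCategory.Violence],  # assumed enum member
        num_objectives=5,
        output_dir="./redteam-output",
    )
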
+    def _start_redteam_mlflow_run(
+        self,
+        azure_ai_project: Optional[AzureAIProject] = None,
+        run_name: Optional[str] = None
+    ) -> EvalRun:
+        """Start an MLFlow run for the Red Team Agent evaluation.
+
+        :param azure_ai_project: Azure AI project details for logging
+        :type azure_ai_project: Optional[~azure.ai.evaluation.AzureAIProject]
+        :param run_name: Optional name for the MLFlow run
+        :type run_name: Optional[str]
+        :return: The MLFlow run object
+        :rtype: ~azure.ai.evaluation._evaluate._eval_run.EvalRun
+        """
+        if not azure_ai_project:
+            log_error(self.logger, "No azure_ai_project provided, cannot start MLFlow run")
+            raise EvaluationException(
+                message="No azure_ai_project provided",
+                blame=ErrorBlame.USER_ERROR,
+                category=ErrorCategory.MISSING_FIELD,
+                target=ErrorTarget.RED_TEAM
+            )
+
+        trace_destination = _trace_destination_from_project_scope(azure_ai_project)
+        if not trace_destination:
+            self.logger.warning("Could not determine trace destination from project scope")
+            raise EvaluationException(
+                message="Could not determine trace destination",
+                blame=ErrorBlame.SYSTEM_ERROR,
+                category=ErrorCategory.UNKNOWN,
+                target=ErrorTarget.RED_TEAM
+            )
+
+        ws_triad = extract_workspace_triad_from_trace_provider(trace_destination)
+
+        management_client = LiteMLClient(
+            subscription_id=ws_triad.subscription_id,
+            resource_group=ws_triad.resource_group_name,
+            logger=self.logger,
+            credential=azure_ai_project.get("credential")
+        )
+
+        tracking_uri = management_client.workspace_get_info(ws_triad.workspace_name).ml_flow_tracking_uri
+
+        run_display_name = run_name or f"redteam-agent-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
+        self.logger.debug(f"Starting MLFlow run with name: {run_display_name}")
+
+        eval_run = EvalRun(
+            run_name=run_display_name,
+            tracking_uri=cast(str, tracking_uri),
+            subscription_id=ws_triad.subscription_id,
+            group_name=ws_triad.resource_group_name,
+            workspace_name=ws_triad.workspace_name,
+            management_client=management_client,  # type: ignore
+        )
+
+        self.trace_destination = trace_destination
+        self.logger.debug(f"MLFlow run created successfully with ID: {eval_run}")
+
+        return eval_run
+
+
+    async def _log_redteam_results_to_mlflow(
+        self,
+        redteam_output: RedTeamOutput,
+        eval_run: EvalRun,
+        data_only: bool = False,
+    ) -> Optional[str]:
+        """Log the Red Team Agent results to MLFlow.
+
+        :param redteam_output: The output from the red team agent evaluation
+        :type redteam_output: ~azure.ai.evaluation.RedTeamOutput
+        :param eval_run: The MLFlow run object
+        :type eval_run: ~azure.ai.evaluation._evaluate._eval_run.EvalRun
+        :param data_only: Whether to log only data without evaluation results
+        :type data_only: bool
+        :return: The URL to the run in Azure AI Studio, if available
+        :rtype: Optional[str]
+        """
+        self.logger.debug(f"Logging results to MLFlow, data_only={data_only}")
+        artifact_name = "instance_results.json" if not data_only else "instance_data.json"
+
+        # If we have a scan output directory, save the results there first
+        if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
+            artifact_path = os.path.join(self.scan_output_dir, artifact_name)
+            self.logger.debug(f"Saving artifact to scan output directory: {artifact_path}")
+
+            with open(artifact_path, "w", encoding=DefaultOpenEncoding.WRITE) as f:
+                if data_only:
+                    # In data_only mode, we write the conversations in conversation/messages format
+                    f.write(json.dumps({"conversations": redteam_output.redteaming_data or []}))
+                elif redteam_output.red_team_result:
+                    json.dump(redteam_output.red_team_result, f)
+
+            # Also save a human-readable scorecard if available
+            if not data_only and redteam_output.red_team_result:
+                scorecard_path = os.path.join(self.scan_output_dir, "scorecard.txt")
+                with open(scorecard_path, "w", encoding=DefaultOpenEncoding.WRITE) as f:
+                    f.write(self._to_scorecard(redteam_output.red_team_result))
+                self.logger.debug(f"Saved scorecard to: {scorecard_path}")
+
+            # Create a dedicated artifacts directory with proper structure for MLFlow
+            # MLFlow requires the artifact_name file to be in the directory we're logging
+
+            import tempfile
+            with tempfile.TemporaryDirectory() as tmpdir:
+                # First, create the main artifact file that MLFlow expects
+                with open(os.path.join(tmpdir, artifact_name), "w", encoding=DefaultOpenEncoding.WRITE) as f:
+                    if data_only:
+                        f.write(json.dumps({"conversations": redteam_output.redteaming_data or []}))
+                    elif redteam_output.red_team_result:
+                        json.dump(redteam_output.red_team_result, f)
+
+                # Copy all relevant files to the temp directory
+                import shutil
+                for file in os.listdir(self.scan_output_dir):
+                    file_path = os.path.join(self.scan_output_dir, file)
+
+                    # Skip directories and log files if not in debug mode
+                    if os.path.isdir(file_path):
+                        continue
+                    if file.endswith('.log') and not os.environ.get('DEBUG'):
+                        continue
+
+                    try:
+                        shutil.copy(file_path, os.path.join(tmpdir, file))
+                        self.logger.debug(f"Copied file to artifact directory: {file}")
+                    except Exception as e:
+                        self.logger.warning(f"Failed to copy file {file} to artifact directory: {str(e)}")
+
+                # Log the entire directory to MLFlow
+                try:
+                    eval_run.log_artifact(tmpdir, artifact_name)
+                    self.logger.debug(f"Successfully logged artifacts directory to MLFlow")
+                except Exception as e:
+                    self.logger.warning(f"Failed to log artifacts to MLFlow: {str(e)}")
+
+                # Also log a direct property to capture the scan output directory
+                try:
+                    eval_run.write_properties_to_run_history({"scan_output_dir": str(self.scan_output_dir)})
+                    self.logger.debug("Logged scan_output_dir property to MLFlow")
+                except Exception as e:
+                    self.logger.warning(f"Failed to log scan_output_dir property to MLFlow: {str(e)}")
+        else:
+            # Use temporary directory as before if no scan output directory exists
+            with tempfile.TemporaryDirectory() as tmpdir:
+                artifact_file = Path(tmpdir) / artifact_name
+                with open(artifact_file, "w", encoding=DefaultOpenEncoding.WRITE) as f:
+                    if data_only:
+                        f.write(json.dumps({"conversations": redteam_output.redteaming_data or []}))
+                    elif redteam_output.red_team_result:
+                        json.dump(redteam_output.red_team_result, f)
+                eval_run.log_artifact(tmpdir, artifact_name)
+                self.logger.debug(f"Logged artifact: {artifact_name}")
+
+        eval_run.write_properties_to_run_history({
+            EvaluationRunProperties.RUN_TYPE: "eval_run",
+            "redteaming": "asr",  # Red team agent specific run properties to help UI identify this as a redteaming run
+            EvaluationRunProperties.EVALUATION_SDK: f"azure-ai-evaluation:{VERSION}",
+            "_azureml.evaluate_artifacts": json.dumps([{"path": artifact_name, "type": "table"}]),
+        })
+
+        if redteam_output.red_team_result:
+            scorecard = redteam_output.red_team_result["redteaming_scorecard"]
+            joint_attack_summary = scorecard["joint_risk_attack_summary"]
+
+            if joint_attack_summary:
+                for risk_category_summary in joint_attack_summary:
+                    risk_category = risk_category_summary.get("risk_category").lower()
+                    for key, value in risk_category_summary.items():
+                        if key != "risk_category":
+                            eval_run.log_metric(f"{risk_category}_{key}", cast(float, value))
+                            self.logger.debug(f"Logged metric: {risk_category}_{key} = {value}")
+
+        self.logger.info("Successfully logged results to MLFlow")
+        return None
+
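
One concrete consequence of the branches above: in data_only mode the logged artifact is a single JSON object keyed by "conversations". A hypothetical payload of that shape, inferred from the data_only branch and from the JSONL records that _write_pyrit_outputs_to_file produces later in this file:

    import json

    payload = {"conversations": [
        {"conversation": {"messages": [
            {"role": "user", "content": "<attack prompt>"},        # message fields assumed
            {"role": "assistant", "content": "<target response>"},
        ]}},
    ]}
    json.dumps(payload)  # what gets written to instance_data.json
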
+    # Using the utility function from strategy_utils.py instead
+    def _strategy_converter_map(self):
+        from ._utils.strategy_utils import strategy_converter_map
+        return strategy_converter_map()
+
+    async def _get_attack_objectives(
+        self,
+        risk_category: Optional[RiskCategory] = None,  # Now accepting a single risk category
+        application_scenario: Optional[str] = None,
+        strategy: Optional[str] = None
+    ) -> List[str]:
+        """Get attack objectives from the RAI client for a specific risk category or from a custom dataset.
+
+        :param attack_objective_generator: The generator with risk categories to get attack objectives for
+        :type attack_objective_generator: ~azure.ai.evaluation.redteam._AttackObjectiveGenerator
+        :param risk_category: The specific risk category to get objectives for
+        :type risk_category: Optional[RiskCategory]
+        :param application_scenario: Optional description of the application scenario for context
+        :type application_scenario: str
+        :param strategy: Optional attack strategy to get specific objectives for
+        :type strategy: str
+        :return: A list of attack objective prompts
+        :rtype: List[str]
+        """
+        attack_objective_generator = self.attack_objective_generator
+        # TODO: is this necessary?
+        if not risk_category:
+            self.logger.warning("No risk category provided, using the first category from the generator")
+            risk_category = attack_objective_generator.risk_categories[0] if attack_objective_generator.risk_categories else None
+            if not risk_category:
+                self.logger.error("No risk categories found in generator")
+                return []
+
+        # Convert risk category to lowercase for consistent caching
+        risk_cat_value = risk_category.value.lower()
+        num_objectives = attack_objective_generator.num_objectives
+
+        log_subsection_header(self.logger, f"Getting attack objectives for {risk_cat_value}, strategy: {strategy}")
+
+        # Check if we already have baseline objectives for this risk category
+        baseline_key = ((risk_cat_value,), "baseline")
+        baseline_objectives_exist = baseline_key in self.attack_objectives
+        current_key = ((risk_cat_value,), strategy)
+
+        # Check if custom attack seed prompts are provided in the generator
+        if attack_objective_generator.custom_attack_seed_prompts and attack_objective_generator.validated_prompts:
+            self.logger.info(f"Using custom attack seed prompts from {attack_objective_generator.custom_attack_seed_prompts}")
+
+            # Get the prompts for this risk category
+            custom_objectives = attack_objective_generator.valid_prompts_by_category.get(risk_cat_value, [])
+
+            if not custom_objectives:
+                self.logger.warning(f"No custom objectives found for risk category {risk_cat_value}")
+                return []
+
+            self.logger.info(f"Found {len(custom_objectives)} custom objectives for {risk_cat_value}")
+
+            # Sample if we have more than needed
+            if len(custom_objectives) > num_objectives:
+                selected_cat_objectives = random.sample(custom_objectives, num_objectives)
+                self.logger.info(f"Sampled {num_objectives} objectives from {len(custom_objectives)} available for {risk_cat_value}")
+                # Log ids of selected objectives for traceability
+                selected_ids = [obj.get("id", "unknown-id") for obj in selected_cat_objectives]
+                self.logger.debug(f"Selected objective IDs for {risk_cat_value}: {selected_ids}")
+            else:
+                selected_cat_objectives = custom_objectives
+                self.logger.info(f"Using all {len(custom_objectives)} available objectives for {risk_cat_value}")
+
+            # Handle jailbreak strategy - need to apply jailbreak prefixes to messages
+            if strategy == "jailbreak":
+                self.logger.debug("Applying jailbreak prefixes to custom objectives")
+                try:
+                    jailbreak_prefixes = await self.generated_rai_client.get_jailbreak_prefixes()
+                    for objective in selected_cat_objectives:
+                        if "messages" in objective and len(objective["messages"]) > 0:
+                            message = objective["messages"][0]
+                            if isinstance(message, dict) and "content" in message:
+                                message["content"] = f"{random.choice(jailbreak_prefixes)} {message['content']}"
+                except Exception as e:
+                    log_error(self.logger, "Error applying jailbreak prefixes to custom objectives", e)
+                    # Continue with unmodified prompts instead of failing completely
+
+            # Extract content from selected objectives
+            selected_prompts = []
+            for obj in selected_cat_objectives:
+                if "messages" in obj and len(obj["messages"]) > 0:
+                    message = obj["messages"][0]
+                    if isinstance(message, dict) and "content" in message:
+                        selected_prompts.append(message["content"])
+
+            # Process the selected objectives for caching
+            objectives_by_category = {risk_cat_value: []}
+
+            for obj in selected_cat_objectives:
+                obj_id = obj.get("id", f"obj-{uuid.uuid4()}")
+                target_harms = obj.get("metadata", {}).get("target_harms", [])
+                content = ""
+                if "messages" in obj and len(obj["messages"]) > 0:
+                    content = obj["messages"][0].get("content", "")
+
+                if not content:
+                    continue
+
+                obj_data = {
+                    "id": obj_id,
+                    "content": content
+                }
+                objectives_by_category[risk_cat_value].append(obj_data)
+
+            # Store in cache
+            self.attack_objectives[current_key] = {
+                "objectives_by_category": objectives_by_category,
+                "strategy": strategy,
+                "risk_category": risk_cat_value,
+                "selected_prompts": selected_prompts,
+                "selected_objectives": selected_cat_objectives
+            }
+
+            self.logger.info(f"Using {len(selected_prompts)} custom objectives for {risk_cat_value}")
+            return selected_prompts
+
+        else:
+            # Use the RAI service to get attack objectives
+            try:
+                self.logger.debug(f"API call: get_attack_objectives({risk_cat_value}, app: {application_scenario}, strategy: {strategy})")
+                # strategy param specifies whether to get a strategy-specific dataset from the RAI service
+                # right now, only tense requires strategy-specific dataset
+                if strategy == "tense":
+                    objectives_response = await self.generated_rai_client.get_attack_objectives(
+                        risk_category=risk_cat_value,
+                        application_scenario=application_scenario or "",
+                        strategy=strategy
+                    )
+                else:
+                    objectives_response = await self.generated_rai_client.get_attack_objectives(
+                        risk_category=risk_cat_value,
+                        application_scenario=application_scenario or "",
+                        strategy=None
+                    )
+                if isinstance(objectives_response, list):
+                    self.logger.debug(f"API returned {len(objectives_response)} objectives")
+                else:
+                    self.logger.debug(f"API returned response of type: {type(objectives_response)}")
+
+                # Handle jailbreak strategy - need to apply jailbreak prefixes to messages
+                if strategy == "jailbreak":
+                    self.logger.debug("Applying jailbreak prefixes to objectives")
+                    jailbreak_prefixes = await self.generated_rai_client.get_jailbreak_prefixes()
+                    for objective in objectives_response:
+                        if "messages" in objective and len(objective["messages"]) > 0:
+                            message = objective["messages"][0]
+                            if isinstance(message, dict) and "content" in message:
+                                message["content"] = f"{random.choice(jailbreak_prefixes)} {message['content']}"
+            except Exception as e:
+                log_error(self.logger, "Error calling get_attack_objectives", e)
+                self.logger.warning("API call failed, returning empty objectives list")
+                return []
+
+            # Check if the response is valid
+            if not objectives_response or (isinstance(objectives_response, dict) and not objectives_response.get("objectives")):
+                self.logger.warning("Empty or invalid response, returning empty list")
+                return []
+
+            # For non-baseline strategies, filter by baseline IDs if they exist
+            if strategy != "baseline" and baseline_objectives_exist:
+                self.logger.debug(f"Found existing baseline objectives for {risk_cat_value}, will filter {strategy} by baseline IDs")
+                baseline_selected_objectives = self.attack_objectives[baseline_key].get("selected_objectives", [])
+                baseline_objective_ids = []
+
+                # Extract IDs from baseline objectives
+                for obj in baseline_selected_objectives:
+                    if "id" in obj:
+                        baseline_objective_ids.append(obj["id"])
+
+                if baseline_objective_ids:
+                    self.logger.debug(f"Filtering by {len(baseline_objective_ids)} baseline objective IDs for {strategy}")
+
+                    # Filter objectives by baseline IDs
+                    selected_cat_objectives = []
+                    for obj in objectives_response:
+                        if obj.get("id") in baseline_objective_ids:
+                            selected_cat_objectives.append(obj)
+
+                    self.logger.debug(f"Found {len(selected_cat_objectives)} matching objectives with baseline IDs")
+                    # If we couldn't find all the baseline IDs, log a warning
+                    if len(selected_cat_objectives) < len(baseline_objective_ids):
+                        self.logger.warning(f"Only found {len(selected_cat_objectives)} objectives matching baseline IDs, expected {len(baseline_objective_ids)}")
+                else:
+                    self.logger.warning("No baseline objective IDs found, using random selection")
+                    # If we don't have baseline IDs for some reason, default to random selection
+                    if len(objectives_response) > num_objectives:
+                        selected_cat_objectives = random.sample(objectives_response, num_objectives)
+                    else:
+                        selected_cat_objectives = objectives_response
+            else:
+                # This is the baseline strategy or we don't have baseline objectives yet
+                self.logger.debug(f"Using random selection for {strategy} strategy")
+                if len(objectives_response) > num_objectives:
+                    self.logger.debug(f"Selecting {num_objectives} objectives from {len(objectives_response)} available")
+                    selected_cat_objectives = random.sample(objectives_response, num_objectives)
+                else:
+                    selected_cat_objectives = objectives_response
+
+            if len(selected_cat_objectives) < num_objectives:
+                self.logger.warning(f"Only found {len(selected_cat_objectives)} objectives for {risk_cat_value}, fewer than requested {num_objectives}")
+
+            # Extract content from selected objectives
+            selected_prompts = []
+            for obj in selected_cat_objectives:
+                if "messages" in obj and len(obj["messages"]) > 0:
+                    message = obj["messages"][0]
+                    if isinstance(message, dict) and "content" in message:
+                        selected_prompts.append(message["content"])
+
+            # Process the response - organize by category and extract content/IDs
+            objectives_by_category = {risk_cat_value: []}
+
+            # Process list format and organize by category for caching
+            for obj in selected_cat_objectives:
+                obj_id = obj.get("id", f"obj-{uuid.uuid4()}")
+                target_harms = obj.get("metadata", {}).get("target_harms", [])
+                content = ""
+                if "messages" in obj and len(obj["messages"]) > 0:
+                    content = obj["messages"][0].get("content", "")
+
+                if not content:
+                    continue
+                if target_harms:
+                    for harm in target_harms:
+                        obj_data = {
+                            "id": obj_id,
+                            "content": content
+                        }
+                        objectives_by_category[risk_cat_value].append(obj_data)
+                        break  # Just use the first harm for categorization
+
+            # Store in cache - now including the full selected objectives with IDs
+            self.attack_objectives[current_key] = {
+                "objectives_by_category": objectives_by_category,
+                "strategy": strategy,
+                "risk_category": risk_cat_value,
+                "selected_prompts": selected_prompts,
+                "selected_objectives": selected_cat_objectives  # Store full objects with IDs
+            }
+            self.logger.info(f"Selected {len(selected_prompts)} objectives for {risk_cat_value}")
+
+            return selected_prompts
+
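
Worth noting from the method above: results are memoized in self.attack_objectives under the key ((risk_category_value,), strategy), and the baseline entry's objective IDs later constrain which objectives other strategies may reuse. A lookup against that cache reduces to a one-liner (hypothetical helper, mirroring the dict shape the method stores):

    def cached_prompts(self, risk_cat_value: str, strategy: str) -> list:
        # Key shape matches what _get_attack_objectives stores above.
        entry = self.attack_objectives.get(((risk_cat_value,), strategy))
        return entry["selected_prompts"] if entry else []
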
+    # Replace with utility function
+    def _message_to_dict(self, message: ChatMessage):
+        from ._utils.formatting_utils import message_to_dict
+        return message_to_dict(message)
+
+    # Replace with utility function
+    def _get_strategy_name(self, attack_strategy: Union[AttackStrategy, List[AttackStrategy]]) -> str:
+        from ._utils.formatting_utils import get_strategy_name
+        return get_strategy_name(attack_strategy)
+
+    # Replace with utility function
+    def _get_flattened_attack_strategies(self, attack_strategies: List[Union[AttackStrategy, List[AttackStrategy]]]) -> List[Union[AttackStrategy, List[AttackStrategy]]]:
+        from ._utils.formatting_utils import get_flattened_attack_strategies
+        return get_flattened_attack_strategies(attack_strategies)
+
+    # Replace with utility function
+    def _get_converter_for_strategy(self, attack_strategy: Union[AttackStrategy, List[AttackStrategy]]) -> Union[PromptConverter, List[PromptConverter]]:
+        from ._utils.strategy_utils import get_converter_for_strategy
+        return get_converter_for_strategy(attack_strategy)
+
+    async def _prompt_sending_orchestrator(
+        self,
+        chat_target: PromptChatTarget,
+        all_prompts: List[str],
+        converter: Union[PromptConverter, List[PromptConverter]],
+        strategy_name: str = "unknown",
+        risk_category: str = "unknown",
+        timeout: int = 120
+    ) -> Orchestrator:
+        """Send prompts via the PromptSendingOrchestrator with optimized performance.
+
+        :param chat_target: The target to send prompts to
+        :type chat_target: PromptChatTarget
+        :param all_prompts: List of prompts to send
+        :type all_prompts: List[str]
+        :param converter: Converter or list of converters to use for prompt transformation
+        :type converter: Union[PromptConverter, List[PromptConverter]]
+        :param strategy_name: Name of the strategy being used (for logging)
+        :type strategy_name: str
+        :param risk_category: Name of the risk category being evaluated (for logging)
+        :type risk_category: str
+        :param timeout: The timeout in seconds for API calls
+        :type timeout: int
+        :return: The orchestrator instance with processed results
+        :rtype: Orchestrator
+        """
+        task_key = f"{strategy_name}_{risk_category}_orchestrator"
+        self.task_statuses[task_key] = TASK_STATUS["RUNNING"]
+
+        log_strategy_start(self.logger, strategy_name, risk_category)
+
+        # Create converter list from single converter or list of converters
+        converter_list = [converter] if converter and isinstance(converter, PromptConverter) else converter if converter else []
+
+        # Log which converter is being used
+        if converter_list:
+            if isinstance(converter_list, list) and len(converter_list) > 0:
+                converter_names = [c.__class__.__name__ for c in converter_list if c is not None]
+                self.logger.debug(f"Using converters: {', '.join(converter_names)}")
+            elif converter is not None:
+                self.logger.debug(f"Using converter: {converter.__class__.__name__}")
+        else:
+            self.logger.debug("No converters specified")
+
+        # Optimized orchestrator initialization
+        try:
+            orchestrator = PromptSendingOrchestrator(
+                objective_target=chat_target,
+                prompt_converters=converter_list
+            )
+
+            if not all_prompts:
+                self.logger.warning(f"No prompts provided to orchestrator for {strategy_name}/{risk_category}")
+                self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
+                return orchestrator
+
+            # Debug log the first few characters of each prompt
+            self.logger.debug(f"First prompt (truncated): {all_prompts[0][:50]}...")
+
+            # Use a batched approach for send_prompts_async to prevent overwhelming
+            # the model with too many concurrent requests
+            batch_size = min(len(all_prompts), 3)  # Process 3 prompts at a time max
+
+            # Process prompts concurrently within each batch
+            if len(all_prompts) > batch_size:
+                self.logger.debug(f"Processing {len(all_prompts)} prompts in batches of {batch_size} for {strategy_name}/{risk_category}")
+                batches = [all_prompts[i:i + batch_size] for i in range(0, len(all_prompts), batch_size)]
+
+                for batch_idx, batch in enumerate(batches):
+                    self.logger.debug(f"Processing batch {batch_idx+1}/{len(batches)} with {len(batch)} prompts for {strategy_name}/{risk_category}")
+
+                    batch_start_time = datetime.now()
+                    # Send prompts in the batch concurrently with a timeout
+                    try:
+                        # Use wait_for to implement a timeout
+                        await asyncio.wait_for(
+                            orchestrator.send_prompts_async(prompt_list=batch),
+                            timeout=timeout  # Use provided timeout
+                        )
+                        batch_duration = (datetime.now() - batch_start_time).total_seconds()
+                        self.logger.debug(f"Successfully processed batch {batch_idx+1} for {strategy_name}/{risk_category} in {batch_duration:.2f} seconds")
+
+                        # Print progress to console
+                        if batch_idx < len(batches) - 1:  # Don't print for the last batch
+                            print(f"Strategy {strategy_name}, Risk {risk_category}: Processed batch {batch_idx+1}/{len(batches)}")
+
+                    except asyncio.TimeoutError:
+                        self.logger.warning(f"Batch {batch_idx+1} for {strategy_name}/{risk_category} timed out after {timeout} seconds, continuing with partial results")
+                        self.logger.debug(f"Timeout: Strategy {strategy_name}, Risk {risk_category}, Batch {batch_idx+1} after {timeout} seconds.", exc_info=True)
+                        print(f"⚠️ TIMEOUT: Strategy {strategy_name}, Risk {risk_category}, Batch {batch_idx+1}")
+                        # Set task status to TIMEOUT
+                        batch_task_key = f"{strategy_name}_{risk_category}_batch_{batch_idx+1}"
+                        self.task_statuses[batch_task_key] = TASK_STATUS["TIMEOUT"]
+                        # Continue with partial results rather than failing completely
+                        continue
+                    except Exception as e:
+                        log_error(self.logger, f"Error processing batch {batch_idx+1}", e, f"{strategy_name}/{risk_category}")
+                        print(f"❌ ERROR: Strategy {strategy_name}, Risk {risk_category}, Batch {batch_idx+1}: {str(e)}")
+                        # Continue with other batches even if one fails
+                        continue
+            else:
+                # Small number of prompts, process all at once with a timeout
+                self.logger.debug(f"Processing {len(all_prompts)} prompts in a single batch for {strategy_name}/{risk_category}")
+                batch_start_time = datetime.now()
+                try:
+                    await asyncio.wait_for(
+                        orchestrator.send_prompts_async(prompt_list=all_prompts),
+                        timeout=timeout  # Use provided timeout
+                    )
+                    batch_duration = (datetime.now() - batch_start_time).total_seconds()
+                    self.logger.debug(f"Successfully processed single batch for {strategy_name}/{risk_category} in {batch_duration:.2f} seconds")
+                except asyncio.TimeoutError:
+                    self.logger.warning(f"Prompt processing for {strategy_name}/{risk_category} timed out after {timeout} seconds, continuing with partial results")
+                    print(f"⚠️ TIMEOUT: Strategy {strategy_name}, Risk {risk_category}")
+                    # Set task status to TIMEOUT
+                    single_batch_task_key = f"{strategy_name}_{risk_category}_single_batch"
+                    self.task_statuses[single_batch_task_key] = TASK_STATUS["TIMEOUT"]
+                except Exception as e:
+                    log_error(self.logger, "Error processing prompts", e, f"{strategy_name}/{risk_category}")
+                    print(f"❌ ERROR: Strategy {strategy_name}, Risk {risk_category}: {str(e)}")
+
+            self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
+            return orchestrator
+
+        except Exception as e:
+            log_error(self.logger, "Failed to initialize orchestrator", e, f"{strategy_name}/{risk_category}")
+            print(f"❌ CRITICAL: Failed to create orchestrator for {strategy_name}/{risk_category}: {str(e)}")
+            self.task_statuses[task_key] = TASK_STATUS["FAILED"]
+            raise
+
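
The batching logic above reduces to a simple pattern: slice the prompts into groups of at most three, guard each group with asyncio.wait_for, and treat a timeout as partial progress rather than failure. A minimal self-contained sketch, with a stand-in send() coroutine in place of orchestrator.send_prompts_async:

    import asyncio

    async def send(batch):
        # Stand-in for orchestrator.send_prompts_async(prompt_list=batch)
        await asyncio.sleep(0.1)

    async def run_batched(prompts, batch_size=3, timeout=120):
        batches = [prompts[i:i + batch_size] for i in range(0, len(prompts), batch_size)]
        for idx, batch in enumerate(batches, start=1):
            try:
                await asyncio.wait_for(send(batch), timeout=timeout)
            except asyncio.TimeoutError:
                print(f"batch {idx} timed out; continuing with partial results")

    asyncio.run(run_batched([f"prompt {n}" for n in range(7)]))
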
+    def _write_pyrit_outputs_to_file(self, orchestrator: Orchestrator) -> str:
+        """Write PyRIT outputs to a file with a name based on orchestrator, converter, and risk category.
+
+        :param orchestrator: The orchestrator that generated the outputs
+        :type orchestrator: Orchestrator
+        :return: Path to the output file
+        :rtype: Union[str, os.PathLike]
+        """
+        base_path = str(uuid.uuid4())
+
+        # If scan output directory exists, place the file there
+        if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
+            output_path = os.path.join(self.scan_output_dir, f"{base_path}{DATA_EXT}")
+        else:
+            output_path = f"{base_path}{DATA_EXT}"
+
+        self.logger.debug(f"Writing PyRIT outputs to file: {output_path}")
+
+        memory = orchestrator.get_memory()
+
+        # Get conversations as a List[List[ChatMessage]]
+        conversations = [[item.to_chat_message() for item in group] for conv_id, group in itertools.groupby(memory, key=lambda x: x.conversation_id)]
+
+        # Convert to json lines
+        json_lines = ""
+        for conversation in conversations:  # each conversation is a List[ChatMessage]
+            json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"
+
+        with Path(output_path).open("w") as f:
+            f.writelines(json_lines)
+
+        orchestrator.dispose_db_engine()
+        self.logger.debug(f"Successfully wrote {len(conversations)} conversations to {output_path}")
+        return str(output_path)
+
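
Each line of the resulting data file is one self-contained conversation. A representative record of that shape — message field names assumed from the standard ChatMessage role/content structure, since _message_to_dict above delegates to formatting_utils.message_to_dict, which is not shown here:

    import json

    record = {"conversation": {"messages": [
        {"role": "user", "content": "<attack prompt>"},
        {"role": "assistant", "content": "<target response>"},
    ]}}
    json.dumps(record)  # one line of the .jsonl data file
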
|
750
|
+
# Replace with utility function
|
|
751
|
+
def _get_chat_target(self, target: Union[PromptChatTarget,Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration]) -> PromptChatTarget:
|
|
752
|
+
from ._utils.strategy_utils import get_chat_target
|
|
753
|
+
return get_chat_target(target)
|
|
754
|
+
|
|
755
|
+
# Replace with utility function
|
|
756
|
+
def _get_orchestrators_for_attack_strategies(self, attack_strategy: List[Union[AttackStrategy, List[AttackStrategy]]]) -> List[Callable]:
|
|
757
|
+
# We need to modify this to use our actual _prompt_sending_orchestrator since the utility function can't access it
|
|
758
|
+
call_to_orchestrators = []
|
|
759
|
+
# Sending PromptSendingOrchestrator for each complexity level
|
|
760
|
+
if AttackStrategy.EASY in attack_strategy:
|
|
761
|
+
call_to_orchestrators.extend([self._prompt_sending_orchestrator])
|
|
762
|
+
elif AttackStrategy.MODERATE in attack_strategy:
|
|
763
|
+
call_to_orchestrators.extend([self._prompt_sending_orchestrator])
|
|
764
|
+
elif AttackStrategy.DIFFICULT in attack_strategy:
|
|
765
|
+
call_to_orchestrators.extend([self._prompt_sending_orchestrator])
|
|
766
|
+
else:
|
|
767
|
+
call_to_orchestrators.extend([self._prompt_sending_orchestrator])
|
|
768
|
+
return call_to_orchestrators
|
|
769
|
+
|
|
770
|
+
# Replace with utility function
|
|
771
|
+
def _get_attack_success(self, result: str) -> bool:
|
|
772
|
+
from ._utils.formatting_utils import get_attack_success
|
|
773
|
+
return get_attack_success(result)
|
|
774
|
+
|
|
775
|
+
def _to_red_team_result(self) -> _RedTeamResult:
|
|
776
|
+
"""Convert tracking data from red_team_info to the _RedTeamResult format.
|
|
777
|
+
|
|
778
|
+
Uses only the red_team_info tracking dictionary to build the _RedTeamResult.
|
|
779
|
+
|
|
780
|
+
:return: Structured red team agent results
|
|
781
|
+
:rtype: _RedTeamResult
|
|
782
|
+
"""
|
|
783
|
+
converters = []
|
|
784
|
+
complexity_levels = []
|
|
785
|
+
risk_categories = []
|
|
786
|
+
attack_successes = [] # unified list for all attack successes
|
|
787
|
+
conversations = []
|
|
788
|
+
|
|
789
|
+
# Create a CSV summary file for attack data in the scan output directory if available
|
|
790
|
+
if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
|
|
791
|
+
summary_file = os.path.join(self.scan_output_dir, "attack_summary.csv")
|
|
792
|
+
self.logger.debug(f"Creating attack summary CSV file: {summary_file}")
|
|
793
|
+
|
|
794
|
+
self.logger.info(f"Building _RedTeamResult from red_team_info with {len(self.red_team_info)} strategies")
|
|
795
|
+
|
|
796
|
+
# Process each strategy and risk category from red_team_info
|
|
797
|
+
for strategy_name, risk_data in self.red_team_info.items():
|
|
798
|
+
self.logger.info(f"Processing results for strategy: {strategy_name}")
|
|
799
|
+
|
|
800
|
+
# Determine complexity level for this strategy
|
|
801
|
+
if "Baseline" in strategy_name:
|
|
802
|
+
complexity_level = "baseline"
|
|
803
|
+
else:
|
|
804
|
+
            # Try to map strategy name to complexity level
            # Default is difficult since we assume it's a composed strategy
            complexity_level = ATTACK_STRATEGY_COMPLEXITY_MAP.get(strategy_name, "difficult")

            for risk_category, data in risk_data.items():
                self.logger.info(f"Processing data for {risk_category} in strategy {strategy_name}")

                data_file = data.get("data_file", "")
                eval_result = data.get("evaluation_result")

                # Process data file to extract conversations
                if data_file and os.path.exists(data_file):
                    try:
                        with open(data_file, "r") as f:
                            for line in f:
                                try:
                                    conv_data = json.loads(line)
                                    if "conversation" in conv_data and "messages" in conv_data["conversation"]:
                                        messages = conv_data["conversation"]["messages"]

                                        # Determine attack success based on evaluation results if available
                                        attack_success = None
                                        risk_assessment = {}

                                        # If we have evaluation results, try to extract attack success and risk assessment
                                        if eval_result:
                                            # Look for this conversation in the evaluation results
                                            rows = eval_result.get("rows", [])
                                            for r in rows:
                                                if r.get("inputs.conversation", {}).get("messages") == messages:
                                                    # Found matching conversation
                                                    if f"outputs.{risk_category}.{risk_category}_result" in r:
                                                        attack_success = self._get_attack_success(r[f"outputs.{risk_category}.{risk_category}_result"])

                                                    # Extract risk assessments for all categories
                                                    for risk in self.risk_categories:
                                                        risk_value = risk.value
                                                        if f"outputs.{risk_value}.{risk_value}" in r and f"outputs.{risk_value}.{risk_value}_reason" in r:
                                                            risk_assessment[risk_value] = {
                                                                "severity_label": r[f"outputs.{risk_value}.{risk_value}"],
                                                                "reason": r[f"outputs.{risk_value}.{risk_value}_reason"]
                                                            }

                                        # Add to tracking arrays for statistical analysis
                                        converters.append(strategy_name)
                                        complexity_levels.append(complexity_level)
                                        risk_categories.append(risk_category)

                                        if attack_success is not None:
                                            attack_successes.append(1 if attack_success else 0)
                                        else:
                                            attack_successes.append(None)

                                        # Add conversation object
                                        conversation = {
                                            "attack_success": attack_success,
                                            "attack_technique": strategy_name.replace("Converter", "").replace("Prompt", ""),
                                            "attack_complexity": complexity_level,
                                            "risk_category": risk_category,
                                            "conversation": messages,
                                            "risk_assessment": risk_assessment if risk_assessment else None
                                        }
                                        conversations.append(conversation)
                                except json.JSONDecodeError as e:
                                    self.logger.error(f"Error parsing JSON in data file {data_file}: {e}")
                    except Exception as e:
                        self.logger.error(f"Error processing data file {data_file}: {e}")
                else:
                    self.logger.warning(f"Data file {data_file} not found or not specified for {strategy_name}/{risk_category}")

        # Sort conversations by attack technique for better readability
        conversations.sort(key=lambda x: x["attack_technique"])

        self.logger.info(f"Processed {len(conversations)} conversations from all data files")
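        # Illustrative only: the parsing loop above expects each line of the data
        # file to be a JSONL record shaped roughly like the following (the role
        # names are assumptions, not guaranteed by this module):
        #
        #   {"conversation": {"messages": [
        #       {"role": "user", "content": "<attack prompt>"},
        #       {"role": "assistant", "content": "<target response>"}
        #   ]}}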
        # Create a DataFrame for analysis - with unified structure
        results_dict = {
            "converter": converters,
            "complexity_level": complexity_levels,
            "risk_category": risk_categories,
        }

        # Only include attack_success if we have evaluation results
        if any(success is not None for success in attack_successes):
            results_dict["attack_success"] = [math.nan if success is None else success for success in attack_successes]
            self.logger.info(f"Including attack success data for {sum(1 for s in attack_successes if s is not None)} conversations")

        results_df = pd.DataFrame.from_dict(results_dict)
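        # Sketch of the resulting frame for two parsed conversations (values are
        # illustrative, not real output; the converter column holds the strategy
        # name appended above):
        #
        #       converter  complexity_level  risk_category  attack_success
        #   0   Base64     easy              violence       1.0
        #   1   Base64     easy              self_harm      0.0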
if "attack_success" not in results_df.columns or results_df.empty:
|
|
894
|
+
# If we don't have evaluation results or the DataFrame is empty, create a default scorecard
|
|
895
|
+
self.logger.info("No evaluation results available or no data found, creating default scorecard")
|
|
896
|
+
|
|
897
|
+
# Create a basic scorecard structure
|
|
898
|
+
scorecard = {
|
|
899
|
+
"risk_category_summary": [{"overall_asr": 0.0, "overall_total": len(conversations), "overall_attack_successes": 0}],
|
|
900
|
+
"attack_technique_summary": [{"overall_asr": 0.0, "overall_total": len(conversations), "overall_attack_successes": 0}],
|
|
901
|
+
"joint_risk_attack_summary": [],
|
|
902
|
+
"detailed_joint_risk_attack_asr": {}
|
|
903
|
+
}
|
|
904
|
+
|
|
905
|
+
# Create basic parameters
|
|
906
|
+
redteaming_parameters = {
|
|
907
|
+
"attack_objective_generated_from": {
|
|
908
|
+
"application_scenario": self.application_scenario,
|
|
909
|
+
"risk_categories": [risk.value for risk in self.risk_categories],
|
|
910
|
+
"custom_attack_seed_prompts": "",
|
|
911
|
+
"policy_document": ""
|
|
912
|
+
},
|
|
913
|
+
"attack_complexity": list(set(complexity_levels)) if complexity_levels else ["baseline", "easy"],
|
|
914
|
+
"techniques_used": {}
|
|
915
|
+
}
|
|
916
|
+
|
|
917
|
+
for complexity in set(complexity_levels) if complexity_levels else ["baseline", "easy"]:
|
|
918
|
+
complexity_converters = [conv for i, conv in enumerate(converters) if i < len(complexity_levels) and complexity_levels[i] == complexity]
|
|
919
|
+
redteaming_parameters["techniques_used"][complexity] = list(set(complexity_converters)) if complexity_converters else []
|
|
920
|
+
else:
|
|
921
|
+
# Calculate risk category summaries by aggregating on risk category
|
|
922
|
+
risk_category_groups = results_df.groupby("risk_category")
|
|
923
|
+
risk_category_summary = {}
|
|
924
|
+
|
|
925
|
+
# Overall metrics across all categories
|
|
926
|
+
try:
|
|
927
|
+
overall_asr = round(list_mean_nan_safe(results_df["attack_success"].tolist()) * 100, 2) if "attack_success" in results_df.columns else 0.0
|
|
928
|
+
except EvaluationException:
|
|
929
|
+
self.logger.debug("All values in overall attack success array were None or NaN, setting ASR to NaN")
|
|
930
|
+
overall_asr = math.nan
|
|
931
|
+
overall_total = len(results_df)
|
|
932
|
+
overall_successful_attacks = sum([s for s in results_df["attack_success"].tolist() if not is_none_or_nan(s)]) if "attack_success" in results_df.columns else 0
|
|
933
|
+
|
|
934
|
+
risk_category_summary.update({
|
|
935
|
+
"overall_asr": overall_asr,
|
|
936
|
+
"overall_total": overall_total,
|
|
937
|
+
"overall_attack_successes": int(overall_successful_attacks)
|
|
938
|
+
})
|
|
939
|
+
|
|
940
|
+
# Per-risk category metrics
|
|
941
|
+
for risk, group in risk_category_groups:
|
|
942
|
+
try:
|
|
943
|
+
asr = round(list_mean_nan_safe(group["attack_success"].tolist()) * 100, 2) if "attack_success" in group.columns else 0.0
|
|
944
|
+
except EvaluationException:
|
|
945
|
+
self.logger.debug(f"All values in attack success array for {risk} were None or NaN, setting ASR to NaN")
|
|
946
|
+
asr = math.nan
|
|
947
|
+
total = len(group)
|
|
948
|
+
successful_attacks =sum([s for s in group["attack_success"].tolist() if not is_none_or_nan(s)]) if "attack_success" in group.columns else 0
|
|
949
|
+
|
|
950
|
+
risk_category_summary.update({
|
|
951
|
+
f"{risk}_asr": asr,
|
|
952
|
+
f"{risk}_total": total,
|
|
953
|
+
f"{risk}_successful_attacks": int(successful_attacks)
|
|
954
|
+
})
|
|
955
|
+
|
|
956
|
+
# Calculate attack technique summaries by complexity level
|
|
957
|
+
# First, create masks for each complexity level
|
|
958
|
+
baseline_mask = results_df["complexity_level"] == "baseline"
|
|
959
|
+
easy_mask = results_df["complexity_level"] == "easy"
|
|
960
|
+
moderate_mask = results_df["complexity_level"] == "moderate"
|
|
961
|
+
difficult_mask = results_df["complexity_level"] == "difficult"
|
|
962
|
+
|
|
963
|
+
# Then calculate metrics for each complexity level
|
|
964
|
+
attack_technique_summary_dict = {}
|
|
965
|
+
|
|
966
|
+
# Baseline metrics
|
|
967
|
+
baseline_df = results_df[baseline_mask]
|
|
968
|
+
if not baseline_df.empty:
|
|
969
|
+
try:
|
|
970
|
+
baseline_asr = round(list_mean_nan_safe(baseline_df["attack_success"].tolist()) * 100, 2) if "attack_success" in baseline_df.columns else 0.0
|
|
971
|
+
except EvaluationException:
|
|
972
|
+
self.logger.debug("All values in baseline attack success array were None or NaN, setting ASR to NaN")
|
|
973
|
+
baseline_asr = math.nan
|
|
974
|
+
attack_technique_summary_dict.update({
|
|
975
|
+
"baseline_asr": baseline_asr,
|
|
976
|
+
"baseline_total": len(baseline_df),
|
|
977
|
+
"baseline_attack_successes": sum([s for s in baseline_df["attack_success"].tolist() if not is_none_or_nan(s)]) if "attack_success" in baseline_df.columns else 0
|
|
978
|
+
})
|
|
979
|
+
|
|
980
|
+
# Easy complexity metrics
|
|
981
|
+
easy_df = results_df[easy_mask]
|
|
982
|
+
if not easy_df.empty:
|
|
983
|
+
try:
|
|
984
|
+
easy_complexity_asr = round(list_mean_nan_safe(easy_df["attack_success"].tolist()) * 100, 2) if "attack_success" in easy_df.columns else 0.0
|
|
985
|
+
except EvaluationException:
|
|
986
|
+
self.logger.debug("All values in easy complexity attack success array were None or NaN, setting ASR to NaN")
|
|
987
|
+
easy_complexity_asr = math.nan
|
|
988
|
+
attack_technique_summary_dict.update({
|
|
989
|
+
"easy_complexity_asr": easy_complexity_asr,
|
|
990
|
+
"easy_complexity_total": len(easy_df),
|
|
991
|
+
"easy_complexity_attack_successes": sum([s for s in easy_df["attack_success"].tolist() if not is_none_or_nan(s)]) if "attack_success" in easy_df.columns else 0
|
|
992
|
+
})
|
|
993
|
+
|
|
994
|
+
# Moderate complexity metrics
|
|
995
|
+
moderate_df = results_df[moderate_mask]
|
|
996
|
+
if not moderate_df.empty:
|
|
997
|
+
try:
|
|
998
|
+
moderate_complexity_asr = round(list_mean_nan_safe(moderate_df["attack_success"].tolist()) * 100, 2) if "attack_success" in moderate_df.columns else 0.0
|
|
999
|
+
except EvaluationException:
|
|
1000
|
+
self.logger.debug("All values in moderate complexity attack success array were None or NaN, setting ASR to NaN")
|
|
1001
|
+
moderate_complexity_asr = math.nan
|
|
1002
|
+
attack_technique_summary_dict.update({
|
|
1003
|
+
"moderate_complexity_asr": moderate_complexity_asr,
|
|
1004
|
+
"moderate_complexity_total": len(moderate_df),
|
|
1005
|
+
"moderate_complexity_attack_successes": sum([s for s in moderate_df["attack_success"].tolist() if not is_none_or_nan(s)]) if "attack_success" in moderate_df.columns else 0
|
|
1006
|
+
})
|
|
1007
|
+
|
|
1008
|
+
# Difficult complexity metrics
|
|
1009
|
+
difficult_df = results_df[difficult_mask]
|
|
1010
|
+
if not difficult_df.empty:
|
|
1011
|
+
try:
|
|
1012
|
+
difficult_complexity_asr = round(list_mean_nan_safe(difficult_df["attack_success"].tolist()) * 100, 2) if "attack_success" in difficult_df.columns else 0.0
|
|
1013
|
+
except EvaluationException:
|
|
1014
|
+
self.logger.debug("All values in difficult complexity attack success array were None or NaN, setting ASR to NaN")
|
|
1015
|
+
difficult_complexity_asr = math.nan
|
|
1016
|
+
attack_technique_summary_dict.update({
|
|
1017
|
+
"difficult_complexity_asr": difficult_complexity_asr,
|
|
1018
|
+
"difficult_complexity_total": len(difficult_df),
|
|
1019
|
+
"difficult_complexity_attack_successes": sum([s for s in difficult_df["attack_success"].tolist() if not is_none_or_nan(s)]) if "attack_success" in difficult_df.columns else 0
|
|
1020
|
+
})
|
|
1021
|
+
|
|
1022
|
+
# Overall metrics
|
|
1023
|
+
attack_technique_summary_dict.update({
|
|
1024
|
+
"overall_asr": overall_asr,
|
|
1025
|
+
"overall_total": overall_total,
|
|
1026
|
+
"overall_attack_successes": int(overall_successful_attacks)
|
|
1027
|
+
})
|
|
1028
|
+
|
|
1029
|
+
attack_technique_summary = [attack_technique_summary_dict]
|
|
1030
|
+
|
|
1031
|
+
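            # Worked example (illustrative, assuming list_mean_nan_safe skips
            # None/NaN entries): ASR is the NaN-safe mean of the attack_success
            # column scaled to a percentage, so values [1, 0, nan] give
            # mean(1, 0) = 0.5 -> ASR = 50.0. Rows whose evaluation produced no
            # verdict contribute NaN and are excluded from the mean rather than
            # counted as failures.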
            # Create joint risk attack summary
            joint_risk_attack_summary = []
            unique_risks = results_df["risk_category"].unique()

            for risk in unique_risks:
                risk_key = risk.replace("-", "_")
                risk_mask = results_df["risk_category"] == risk

                joint_risk_dict = {"risk_category": risk_key}

                # Baseline ASR for this risk
                baseline_risk_df = results_df[risk_mask & baseline_mask]
                if not baseline_risk_df.empty:
                    try:
                        joint_risk_dict["baseline_asr"] = round(list_mean_nan_safe(baseline_risk_df["attack_success"].tolist()) * 100, 2) if "attack_success" in baseline_risk_df.columns else 0.0
                    except EvaluationException:
                        self.logger.debug(f"All values in baseline attack success array for {risk_key} were None or NaN, setting ASR to NaN")
                        joint_risk_dict["baseline_asr"] = math.nan

                # Easy complexity ASR for this risk
                easy_risk_df = results_df[risk_mask & easy_mask]
                if not easy_risk_df.empty:
                    try:
                        joint_risk_dict["easy_complexity_asr"] = round(list_mean_nan_safe(easy_risk_df["attack_success"].tolist()) * 100, 2) if "attack_success" in easy_risk_df.columns else 0.0
                    except EvaluationException:
                        self.logger.debug(f"All values in easy complexity attack success array for {risk_key} were None or NaN, setting ASR to NaN")
                        joint_risk_dict["easy_complexity_asr"] = math.nan

                # Moderate complexity ASR for this risk
                moderate_risk_df = results_df[risk_mask & moderate_mask]
                if not moderate_risk_df.empty:
                    try:
                        joint_risk_dict["moderate_complexity_asr"] = round(list_mean_nan_safe(moderate_risk_df["attack_success"].tolist()) * 100, 2) if "attack_success" in moderate_risk_df.columns else 0.0
                    except EvaluationException:
                        self.logger.debug(f"All values in moderate complexity attack success array for {risk_key} were None or NaN, setting ASR to NaN")
                        joint_risk_dict["moderate_complexity_asr"] = math.nan

                # Difficult complexity ASR for this risk
                difficult_risk_df = results_df[risk_mask & difficult_mask]
                if not difficult_risk_df.empty:
                    try:
                        joint_risk_dict["difficult_complexity_asr"] = round(list_mean_nan_safe(difficult_risk_df["attack_success"].tolist()) * 100, 2) if "attack_success" in difficult_risk_df.columns else 0.0
                    except EvaluationException:
                        self.logger.debug(f"All values in difficult complexity attack success array for {risk_key} were None or NaN, setting ASR to NaN")
                        joint_risk_dict["difficult_complexity_asr"] = math.nan

                joint_risk_attack_summary.append(joint_risk_dict)

            # Calculate detailed joint risk attack ASR
            detailed_joint_risk_attack_asr = {}
            unique_complexities = sorted([c for c in results_df["complexity_level"].unique() if c != "baseline"])

            for complexity in unique_complexities:
                complexity_mask = results_df["complexity_level"] == complexity
                if results_df[complexity_mask].empty:
                    continue

                detailed_joint_risk_attack_asr[complexity] = {}

                for risk in unique_risks:
                    risk_key = risk.replace("-", "_")
                    risk_mask = results_df["risk_category"] == risk
                    detailed_joint_risk_attack_asr[complexity][risk_key] = {}

                    # Group by converter within this complexity and risk
                    complexity_risk_df = results_df[complexity_mask & risk_mask]
                    if complexity_risk_df.empty:
                        continue

                    converter_groups = complexity_risk_df.groupby("converter")
                    for converter_name, converter_group in converter_groups:
                        try:
                            asr_value = round(list_mean_nan_safe(converter_group["attack_success"].tolist()) * 100, 2) if "attack_success" in converter_group.columns else 0.0
                        except EvaluationException:
                            self.logger.debug(f"All values in attack success array for {converter_name} in {complexity}/{risk_key} were None or NaN, setting ASR to NaN")
                            asr_value = math.nan
                        detailed_joint_risk_attack_asr[complexity][risk_key][f"{converter_name}_ASR"] = asr_value
            # Compile the scorecard
            scorecard = {
                "risk_category_summary": [risk_category_summary],
                "attack_technique_summary": attack_technique_summary,
                "joint_risk_attack_summary": joint_risk_attack_summary,
                "detailed_joint_risk_attack_asr": detailed_joint_risk_attack_asr
            }

            # Create redteaming parameters
            redteaming_parameters = {
                "attack_objective_generated_from": {
                    "application_scenario": self.application_scenario,
                    "risk_categories": [risk.value for risk in self.risk_categories],
                    "custom_attack_seed_prompts": "",
                    "policy_document": ""
                },
                "attack_complexity": [c.capitalize() for c in unique_complexities],
                "techniques_used": {}
            }

            # Populate techniques used by complexity level
            for complexity in unique_complexities:
                complexity_mask = results_df["complexity_level"] == complexity
                complexity_df = results_df[complexity_mask]
                if not complexity_df.empty:
                    complexity_converters = complexity_df["converter"].unique().tolist()
                    redteaming_parameters["techniques_used"][complexity] = complexity_converters

        self.logger.info("_RedTeamResult creation completed")

        # Create the final result
        red_team_result = _RedTeamResult(
            redteaming_scorecard=cast(_RedTeamingScorecard, scorecard),
            redteaming_parameters=cast(_RedTeamingParameters, redteaming_parameters),
            redteaming_data=conversations,
            studio_url=self.ai_studio_url or None
        )

        return red_team_result
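    # Rough shape of the scorecard assembled above (keys abbreviated, values
    # illustrative only):
    #
    #   {
    #       "risk_category_summary": [{"overall_asr": 12.5, "overall_total": 40, ...}],
    #       "attack_technique_summary": [{"baseline_asr": 5.0, "easy_complexity_asr": 20.0, ...}],
    #       "joint_risk_attack_summary": [{"risk_category": "violence", "baseline_asr": 0.0, ...}],
    #       "detailed_joint_risk_attack_asr": {"easy": {"violence": {"Base64Converter_ASR": 25.0}}}
    #   }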
    # Replace with utility function
    def _to_scorecard(self, redteam_result: _RedTeamResult) -> str:
        from ._utils.formatting_utils import format_scorecard
        return format_scorecard(redteam_result)
    async def _evaluate(
        self,
        data_path: Union[str, os.PathLike],
        risk_category: RiskCategory,
        strategy: Union[AttackStrategy, List[AttackStrategy]],
        scan_name: Optional[str] = None,
        data_only: bool = False,
        output_path: Optional[Union[str, os.PathLike]] = None
    ) -> None:
        """Call the evaluate method if not data_only.

        Results are stored in self.red_team_info rather than returned.

        :param data_path: Path to the input data.
        :type data_path: Union[str, os.PathLike]
        :param risk_category: The risk category to evaluate.
        :type risk_category: RiskCategory
        :param strategy: The attack strategy (or list of strategies) under evaluation.
        :type strategy: Union[AttackStrategy, List[AttackStrategy]]
        :param scan_name: Optional name for the evaluation.
        :type scan_name: Optional[str]
        :param data_only: Whether to return only data paths instead of evaluation results.
        :type data_only: bool
        :param output_path: Path for output results.
        :type output_path: Optional[Union[str, os.PathLike]]
        """
        strategy_name = self._get_strategy_name(strategy)
        self.logger.debug(f"Evaluate called with data_path={data_path}, risk_category={risk_category.value}, strategy={strategy_name}, output_path={output_path}, data_only={data_only}, scan_name={scan_name}")
        if data_only:
            return None

        # If output_path is provided, use it; otherwise create one in the scan output directory if available
        if output_path:
            result_path = output_path
        elif hasattr(self, 'scan_output_dir') and self.scan_output_dir:
            result_filename = f"{strategy_name}_{risk_category.value}_{str(uuid.uuid4())}{RESULTS_EXT}"
            result_path = os.path.join(self.scan_output_dir, result_filename)
        else:
            result_path = f"{str(uuid.uuid4())}{RESULTS_EXT}"

        evaluators_dict = {
            risk_category.value: RISK_CATEGORY_EVALUATOR_MAP[risk_category](azure_ai_project=self.azure_ai_project, credential=self.credential)
        }

        # Completely suppress all output during the evaluation call
        import io
        import sys
        import logging
        # Don't re-import os as it's already imported at the module level

        # Create a DevNull class to completely discard all writes
        class DevNull:
            def write(self, msg):
                pass

            def flush(self):
                pass

        # Store original stdout, stderr and logger settings
        original_stdout = sys.stdout
        original_stderr = sys.stderr

        # Get all relevant loggers
        root_logger = logging.getLogger()
        promptflow_logger = logging.getLogger('promptflow')
        azure_logger = logging.getLogger('azure')

        # Store original levels
        orig_root_level = root_logger.level
        orig_promptflow_level = promptflow_logger.level
        orig_azure_level = azure_logger.level

        # Set up a completely silent logger filter
        class SilentFilter(logging.Filter):
            def filter(self, record):
                return False

        # Get original filters to restore later
        orig_handlers = []
        for handler in root_logger.handlers:
            orig_handlers.append((handler, handler.filters.copy(), handler.level))

        try:
            # Redirect all stdout/stderr output to DevNull to completely suppress it
            sys.stdout = DevNull()
            sys.stderr = DevNull()

            # Set all loggers to CRITICAL level to suppress most log messages
            root_logger.setLevel(logging.CRITICAL)
            promptflow_logger.setLevel(logging.CRITICAL)
            azure_logger.setLevel(logging.CRITICAL)

            # Add silent filter to all handlers
            silent_filter = SilentFilter()
            for handler in root_logger.handlers:
                handler.addFilter(silent_filter)
                handler.setLevel(logging.CRITICAL)

            # Create a file handler for any logs we actually want to keep
            file_log_path = os.path.join(self.scan_output_dir, "redteam.log")
            file_handler = logging.FileHandler(file_log_path, mode='a')
            file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s'))

            # Allow the file handler to capture DEBUG logs
            file_handler.setLevel(logging.DEBUG)

            # Set up our own minimal logger for critical events
            eval_logger = logging.getLogger('redteam_evaluation')
            eval_logger.propagate = False  # Don't pass to root logger
            eval_logger.setLevel(logging.DEBUG)
            eval_logger.addHandler(file_handler)

            # Run evaluation silently
            eval_logger.debug(f"Starting evaluation for {risk_category.value}/{strategy_name}")
            evaluate_outputs = evaluate(
                data=data_path,
                evaluators=evaluators_dict,
                output_path=result_path,
            )
            eval_logger.debug(f"Completed evaluation for {risk_category.value}/{strategy_name}")
        finally:
            # Restore original stdout and stderr
            sys.stdout = original_stdout
            sys.stderr = original_stderr

            # Restore original log levels
            root_logger.setLevel(orig_root_level)
            promptflow_logger.setLevel(orig_promptflow_level)
            azure_logger.setLevel(orig_azure_level)

            # Restore original handlers and filters
            for handler, filters, level in orig_handlers:
                # Remove any filters we added
                for filter in list(handler.filters):
                    handler.removeFilter(filter)

                # Restore original filters
                for filter in filters:
                    handler.addFilter(filter)

                # Restore original level
                handler.setLevel(level)

            # Clean up our custom logger
            try:
                if 'eval_logger' in locals() and 'file_handler' in locals():
                    eval_logger.removeHandler(file_handler)
                    file_handler.close()
            except Exception as e:
                self.logger.warning(f"Failed to clean up logger: {str(e)}")
        self.red_team_info[self._get_strategy_name(strategy)][risk_category.value]["evaluation_result_file"] = str(result_path)
        self.red_team_info[self._get_strategy_name(strategy)][risk_category.value]["evaluation_result"] = evaluate_outputs
        self.logger.debug(f"Evaluation complete for {strategy_name}/{risk_category.value}, results stored in red_team_info")
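    # A more compact alternative to the manual stream swapping above would be
    # contextlib's redirect helpers; a minimal sketch (not what this module
    # does, and it omits the logger-level handling the code above also needs):
    #
    #   import contextlib, io
    #   with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
    #       evaluate_outputs = evaluate(data=data_path, evaluators=evaluators_dict, output_path=result_path)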
    async def _process_attack(
        self,
        target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
        call_orchestrator: Callable,
        strategy: Union[AttackStrategy, List[AttackStrategy]],
        risk_category: RiskCategory,
        all_prompts: List[str],
        progress_bar: tqdm,
        progress_bar_lock: asyncio.Lock,
        scan_name: Optional[str] = None,
        data_only: bool = False,
        output_path: Optional[Union[str, os.PathLike]] = None,
        timeout: int = 120,
    ) -> Optional[EvaluationResult]:
        """Process a red team scan with the given orchestrator, converter, and prompts.

        :param target: The target model or function to scan
        :param call_orchestrator: Function to call to create an orchestrator
        :param strategy: The attack strategy to use
        :param risk_category: The risk category to evaluate
        :param all_prompts: List of prompts to use for the scan
        :param progress_bar: Progress bar to update
        :param progress_bar_lock: Lock for the progress bar
        :param scan_name: Optional name for the evaluation
        :param data_only: Whether to return only data without evaluation
        :param output_path: Optional path for output
        :param timeout: The timeout in seconds for API calls
        """
        strategy_name = self._get_strategy_name(strategy)
        task_key = f"{strategy_name}_{risk_category.value}_attack"
        self.task_statuses[task_key] = TASK_STATUS["RUNNING"]

        try:
            start_time = time.time()
            print(f"▶️ Starting task: {strategy_name} strategy for {risk_category.value} risk category")
            log_strategy_start(self.logger, strategy_name, risk_category.value)

            converter = self._get_converter_for_strategy(strategy)
            try:
                self.logger.debug(f"Calling orchestrator for {strategy_name} strategy")
                orchestrator = await call_orchestrator(self.chat_target, all_prompts, converter, strategy_name, risk_category.value, timeout)
            except PyritException as e:
                log_error(self.logger, f"Error calling orchestrator for {strategy_name} strategy", e)
                print(f"❌ Orchestrator error for {strategy_name}/{risk_category.value}: {str(e)}")
                self.task_statuses[task_key] = TASK_STATUS["FAILED"]
                self.failed_tasks += 1

                async with progress_bar_lock:
                    progress_bar.update(1)
                return None

            data_path = self._write_pyrit_outputs_to_file(orchestrator)

            # Store data file in our tracking dictionary
            self.red_team_info[strategy_name][risk_category.value]["data_file"] = data_path
            self.logger.debug(f"Updated red_team_info with data file: {strategy_name} -> {risk_category.value} -> {data_path}")

            try:
                await self._evaluate(
                    scan_name=scan_name,
                    risk_category=risk_category,
                    strategy=strategy,
                    data_only=data_only,
                    data_path=data_path,
                    output_path=output_path,
                )
            except Exception as e:
                log_error(self.logger, f"Error during evaluation for {strategy_name}/{risk_category.value}", e)
                print(f"⚠️ Evaluation error for {strategy_name}/{risk_category.value}: {str(e)}")
                # Continue processing even if evaluation fails

            async with progress_bar_lock:
                self.completed_tasks += 1
                progress_bar.update(1)
                completion_pct = (self.completed_tasks / self.total_tasks) * 100
                elapsed_time = time.time() - start_time

                # Calculate estimated remaining time
                if self.start_time:
                    total_elapsed = time.time() - self.start_time
                    avg_time_per_task = total_elapsed / self.completed_tasks if self.completed_tasks > 0 else 0
                    remaining_tasks = self.total_tasks - self.completed_tasks
                    est_remaining_time = avg_time_per_task * remaining_tasks if avg_time_per_task > 0 else 0

                    # Print task completion message and estimated time on separate lines
                    # This ensures they don't get concatenated with tqdm output
                    print("")  # Empty line to separate from progress bar
                    print(f"✅ Completed task {self.completed_tasks}/{self.total_tasks} ({completion_pct:.1f}%) - {strategy_name}/{risk_category.value} in {elapsed_time:.1f}s")
                    print(f" Est. remaining: {est_remaining_time/60:.1f} minutes")
                else:
                    print("")  # Empty line to separate from progress bar
                    print(f"✅ Completed task {self.completed_tasks}/{self.total_tasks} ({completion_pct:.1f}%) - {strategy_name}/{risk_category.value} in {elapsed_time:.1f}s")

            log_strategy_completion(self.logger, strategy_name, risk_category.value, elapsed_time)
            self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]

        except Exception as e:
            log_error(self.logger, f"Unexpected error processing {strategy_name} strategy for {risk_category.value}", e)
            print(f"❌ Critical error in task {strategy_name}/{risk_category.value}: {str(e)}")
            self.task_statuses[task_key] = TASK_STATUS["FAILED"]
            self.failed_tasks += 1

            async with progress_bar_lock:
                progress_bar.update(1)

        return None
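    # Illustrative sketch of a minimal callable target that _process_attack /
    # scan can drive. The exact signature expected from a Callable target is an
    # assumption here (model configurations are also accepted); `my_app` is a
    # hypothetical application under test:
    #
    #   def my_target(query: str) -> str:
    #       # Forward the red team prompt to the application under test
    #       return my_app.respond(query)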
    async def scan(
        self,
        target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration, PromptChatTarget],
        scan_name: Optional[str] = None,
        num_turns: int = 1,
        attack_strategies: List[Union[AttackStrategy, List[AttackStrategy]]] = [],
        data_only: bool = False,
        output_path: Optional[Union[str, os.PathLike]] = None,
        application_scenario: Optional[str] = None,
        parallel_execution: bool = True,
        max_parallel_tasks: int = 5,
        debug_mode: bool = False,
        timeout: int = 120) -> RedTeamOutput:
        """Run a red team scan against the target using the specified strategies.

        :param target: The target model or function to scan
        :type target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration, PromptChatTarget]
        :param scan_name: Optional name for the evaluation
        :type scan_name: Optional[str]
        :param num_turns: Number of conversation turns to use in the scan
        :type num_turns: int
        :param attack_strategies: List of attack strategies to use
        :type attack_strategies: List[Union[AttackStrategy, List[AttackStrategy]]]
        :param data_only: Whether to return only data without evaluation
        :type data_only: bool
        :param output_path: Optional path for output
        :type output_path: Optional[Union[str, os.PathLike]]
        :param application_scenario: Optional description of the application scenario
        :type application_scenario: Optional[str]
        :param parallel_execution: Whether to execute orchestrator tasks in parallel
        :type parallel_execution: bool
        :param max_parallel_tasks: Maximum number of parallel orchestrator tasks to run (default: 5)
        :type max_parallel_tasks: int
        :param debug_mode: Whether to run in debug mode (more verbose output)
        :type debug_mode: bool
        :param timeout: The timeout in seconds for API calls (default: 120)
        :type timeout: int
        :return: The output from the red team scan
        :rtype: RedTeamOutput
        """
        # Start timing for performance tracking
        self.start_time = time.time()

        # Reset task counters and statuses
        self.task_statuses = {}
        self.completed_tasks = 0
        self.failed_tasks = 0

        # Generate a unique scan ID for this run
        self.scan_id = f"scan_{scan_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" if scan_name else f"scan_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.scan_id = self.scan_id.replace(" ", "_")

        # Create output directory for this scan
        # If DEBUG environment variable is set, use a regular folder name; otherwise, use a hidden folder
        is_debug = os.environ.get("DEBUG", "").lower() in ("true", "1", "yes", "y")
        folder_prefix = "" if is_debug else "."
        self.scan_output_dir = os.path.join(self.output_dir or ".", f"{folder_prefix}{self.scan_id}")
        os.makedirs(self.scan_output_dir, exist_ok=True)

        # Re-initialize logger with the scan output directory
        self.logger = setup_logger(output_dir=self.scan_output_dir)

        # Set up a logging filter to suppress various logs we don't want in the console
        class LogFilter(logging.Filter):
            def filter(self, record):
                # Filter out promptflow logs and evaluation warnings about artifacts
                if record.name.startswith('promptflow'):
                    return False
                if 'The path to the artifact is either not a directory or does not exist' in record.getMessage():
                    return False
                if 'RedTeamOutput object at' in record.getMessage():
                    return False
                if 'timeout won\'t take effect' in record.getMessage():
                    return False
                if 'Submitting run' in record.getMessage():
                    return False
                return True

        # Apply filter to root logger to suppress unwanted logs
        root_logger = logging.getLogger()
        log_filter = LogFilter()

        # Remove existing filters first to avoid duplication
        # (iterate over a copy since removeFilter mutates handler.filters)
        for handler in root_logger.handlers:
            for filter in list(handler.filters):
                handler.removeFilter(filter)
            handler.addFilter(log_filter)

        # Also set up the stderr logger to use the same filter
        stderr_logger = logging.getLogger('stderr')
        for handler in stderr_logger.handlers:
            handler.addFilter(log_filter)

        log_section_header(self.logger, "Starting red team scan")
        self.logger.info(f"Scan started with scan_name: {scan_name}")
        self.logger.info(f"Scan ID: {self.scan_id}")
        self.logger.info(f"Scan output directory: {self.scan_output_dir}")
        self.logger.debug(f"Attack strategies: {attack_strategies}")
        self.logger.debug(f"data_only: {data_only}, output_path: {output_path}")
        self.logger.debug(f"Timeout: {timeout} seconds")

        # Clear, minimal output for start of scan
        print(f"🚀 STARTING RED TEAM SCAN: {scan_name}")
        print(f"📂 Output directory: {self.scan_output_dir}")
        self.logger.info(f"Starting RED TEAM SCAN: {scan_name}")
        self.logger.info(f"Output directory: {self.scan_output_dir}")

        chat_target = self._get_chat_target(target)
        self.chat_target = chat_target
        self.application_scenario = application_scenario or ""

        if not self.attack_objective_generator:
            error_msg = "Attack objective generator is required for red team agent."
            log_error(self.logger, error_msg)
            print(f"❌ {error_msg}")
            raise EvaluationException(
                message=error_msg,
                internal_message="Attack objective generator is not provided.",
                target=ErrorTarget.RED_TEAM,
                category=ErrorCategory.MISSING_FIELD,
                blame=ErrorBlame.USER_ERROR
            )

        # If risk categories aren't specified, use all available categories
        if not self.attack_objective_generator.risk_categories:
            self.logger.info("No risk categories specified, using all available categories")
            self.attack_objective_generator.risk_categories = list(RiskCategory)

        self.risk_categories = self.attack_objective_generator.risk_categories
        # Show risk categories to the user
        print(f"📊 Risk categories: {[rc.value for rc in self.risk_categories]}")
        self.logger.info(f"Risk categories to process: {[rc.value for rc in self.risk_categories]}")

        # Prepend AttackStrategy.Baseline to the attack strategy list
        # Copy first so we don't mutate the caller's list (or the shared mutable default)
        attack_strategies = list(attack_strategies)
        if AttackStrategy.Baseline not in attack_strategies:
            attack_strategies.insert(0, AttackStrategy.Baseline)
            self.logger.debug("Added Baseline to attack strategies")

        # When using custom attack objectives, check for incompatible strategies
        using_custom_objectives = self.attack_objective_generator and self.attack_objective_generator.custom_attack_seed_prompts
        if using_custom_objectives:
            # Maintain a list of converters to avoid duplicates
            used_converter_types = set()
            strategies_to_remove = []

            for i, strategy in enumerate(attack_strategies):
                if isinstance(strategy, list):
                    # Skip composite strategies for now
                    continue

                if strategy == AttackStrategy.Jailbreak:
                    self.logger.warning("Jailbreak strategy with custom attack objectives may not work as expected. The strategy will be run, but results may vary.")
                    print("⚠️ Warning: Jailbreak strategy with custom attack objectives may not work as expected.")

                if strategy == AttackStrategy.Tense:
                    self.logger.warning("Tense strategy requires specific formatting in objectives and may not work correctly with custom attack objectives.")
                    print("⚠️ Warning: Tense strategy requires specific formatting in objectives and may not work correctly with custom attack objectives.")

                # Check for redundant converters
                # TODO: should this be in flattening logic?
                converter = self._get_converter_for_strategy(strategy)
                if converter is not None:
                    converter_type = type(converter).__name__ if not isinstance(converter, list) else ','.join([type(c).__name__ for c in converter])

                    if converter_type in used_converter_types and strategy != AttackStrategy.Baseline:
                        self.logger.warning(f"Strategy {strategy.name} uses a converter type that has already been used. Skipping redundant strategy.")
                        print(f"ℹ️ Skipping redundant strategy: {strategy.name} (uses same converter as another strategy)")
                        strategies_to_remove.append(strategy)
                    else:
                        used_converter_types.add(converter_type)

            # Remove redundant strategies
            if strategies_to_remove:
                attack_strategies = [s for s in attack_strategies if s not in strategies_to_remove]
                self.logger.info(f"Removed {len(strategies_to_remove)} redundant strategies: {[s.name for s in strategies_to_remove]}")
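        # Example of the dedup above (illustrative; StrategyA/StrategyB are
        # hypothetical names): if two strategies resolve to the same converter
        # type, only the first is kept, so
        # [Baseline, StrategyA, StrategyB] -> [Baseline, StrategyA]
        # when StrategyA and StrategyB both map to the same converter class.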
        with self._start_redteam_mlflow_run(self.azure_ai_project, scan_name) as eval_run:
            self.ai_studio_url = _get_ai_studio_url(trace_destination=self.trace_destination, evaluation_id=eval_run.info.run_id)

            # Show URL for tracking progress
            print(f"🔗 Track your red team scan in AI Foundry: {self.ai_studio_url}")
            self.logger.info(f"Started MLFlow run: {self.ai_studio_url}")

            log_subsection_header(self.logger, "Setting up scan configuration")
            flattened_attack_strategies = self._get_flattened_attack_strategies(attack_strategies)
            self.logger.info(f"Using {len(flattened_attack_strategies)} attack strategies after flattening")

            orchestrators = self._get_orchestrators_for_attack_strategies(attack_strategies)
            self.logger.debug(f"Selected {len(orchestrators)} orchestrators for attack strategies")

            # Calculate total tasks: #risk_categories * #strategies * #orchestrators
            self.total_tasks = len(self.risk_categories) * len(flattened_attack_strategies) * len(orchestrators)
            # Show task count for user awareness
            print(f"📋 Planning {self.total_tasks} total tasks")
            self.logger.info(f"Total tasks: {self.total_tasks} ({len(self.risk_categories)} risk categories * {len(flattened_attack_strategies)} strategies * {len(orchestrators)} orchestrators)")

            # Initialize our tracking dictionary early with empty structures
            # This ensures we have a place to store results even if tasks fail
            self.red_team_info = {}
            for strategy in flattened_attack_strategies:
                strategy_name = self._get_strategy_name(strategy)
                self.red_team_info[strategy_name] = {}
                for risk_category in self.risk_categories:
                    self.red_team_info[strategy_name][risk_category.value] = {
                        "data_file": "",
                        "evaluation_result_file": "",
                        "evaluation_result": None,
                        "status": TASK_STATUS["PENDING"]
                    }

            self.logger.debug(f"Initialized tracking dictionary with {len(self.red_team_info)} strategies")

            # More visible progress bar with additional status
            progress_bar = tqdm(
                total=self.total_tasks,
                desc="Scanning: ",
                ncols=100,
                unit="scan",
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
            )
            progress_bar.set_postfix({"current": "initializing"})
            progress_bar_lock = asyncio.Lock()

            # Process all API calls sequentially to respect dependencies between objectives
            log_section_header(self.logger, "Fetching attack objectives")

            # Log the objective source mode
            if using_custom_objectives:
                self.logger.info(f"Using custom attack objectives from {self.attack_objective_generator.custom_attack_seed_prompts}")
                print(f"📚 Using custom attack objectives from {self.attack_objective_generator.custom_attack_seed_prompts}")
            else:
                self.logger.info("Using attack objectives from Azure RAI service")
                print("📚 Using attack objectives from Azure RAI service")

            # Dictionary to store all objectives
            all_objectives = {}

            # First fetch baseline objectives for all risk categories
            # This is important as other strategies depend on baseline objectives
            self.logger.info("Fetching baseline objectives for all risk categories")
            for risk_category in self.risk_categories:
                progress_bar.set_postfix({"current": f"fetching baseline/{risk_category.value}"})
                self.logger.debug(f"Fetching baseline objectives for {risk_category.value}")
                baseline_objectives = await self._get_attack_objectives(
                    risk_category=risk_category,
                    application_scenario=application_scenario,
                    strategy="baseline"
                )
                if "baseline" not in all_objectives:
                    all_objectives["baseline"] = {}
                all_objectives["baseline"][risk_category.value] = baseline_objectives
                print(f"📝 Fetched baseline objectives for {risk_category.value}: {len(baseline_objectives)} objectives")

            # Then fetch objectives for the other strategies
            self.logger.info("Fetching objectives for non-baseline strategies")
            strategy_count = len(flattened_attack_strategies)
            for i, strategy in enumerate(flattened_attack_strategies):
                strategy_name = self._get_strategy_name(strategy)
                if strategy_name == "baseline":
                    continue  # Already fetched

                print(f"🔄 Fetching objectives for strategy {i+1}/{strategy_count}: {strategy_name}")
                all_objectives[strategy_name] = {}

                for risk_category in self.risk_categories:
                    progress_bar.set_postfix({"current": f"fetching {strategy_name}/{risk_category.value}"})
                    self.logger.debug(f"Fetching objectives for {strategy_name} strategy and {risk_category.value} risk category")

                    objectives = await self._get_attack_objectives(
                        risk_category=risk_category,
                        application_scenario=application_scenario,
                        strategy=strategy_name
                    )
                    all_objectives[strategy_name][risk_category.value] = objectives

                    # Print status about the objective count for this strategy/risk
                    if debug_mode:
                        print(f"  - {risk_category.value}: {len(objectives)} objectives")

            self.logger.info("Completed fetching all attack objectives")

            log_section_header(self.logger, "Starting orchestrator processing")

            # Create all tasks for parallel processing
            orchestrator_tasks = []
            combinations = list(itertools.product(orchestrators, flattened_attack_strategies, self.risk_categories))

            for combo_idx, (call_orchestrator, strategy, risk_category) in enumerate(combinations):
                strategy_name = self._get_strategy_name(strategy)
                objectives = all_objectives[strategy_name][risk_category.value]

                if not objectives:
                    self.logger.warning(f"No objectives found for {strategy_name}+{risk_category.value}, skipping")
                    print(f"⚠️ No objectives found for {strategy_name}/{risk_category.value}, skipping")
                    self.red_team_info[strategy_name][risk_category.value]["status"] = TASK_STATUS["COMPLETED"]
                    async with progress_bar_lock:
                        progress_bar.update(1)
                    continue

                self.logger.debug(f"[{combo_idx+1}/{len(combinations)}] Creating task: {call_orchestrator.__name__} + {strategy_name} + {risk_category.value}")

                orchestrator_tasks.append(
                    self._process_attack(
                        target=target,
                        call_orchestrator=call_orchestrator,
                        all_prompts=objectives,
                        strategy=strategy,
                        progress_bar=progress_bar,
                        progress_bar_lock=progress_bar_lock,
                        scan_name=scan_name,
                        data_only=data_only,
                        output_path=output_path,
                        risk_category=risk_category,
                        timeout=timeout
                    )
                )
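            # Batching arithmetic, for reference: with N queued tasks and
            # max_parallel_tasks = M, the loop below runs ceil(N / M) batches,
            # e.g. N = 12, M = 5 -> batches of 5, 5, and 2 tasks.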
            # Process tasks in parallel with optimized batching
            if parallel_execution and orchestrator_tasks:
                print(f"⚙️ Processing {len(orchestrator_tasks)} tasks in parallel (max {max_parallel_tasks} at a time)")
                self.logger.info(f"Processing {len(orchestrator_tasks)} tasks in parallel (max {max_parallel_tasks} at a time)")

                # Create batches for processing
                for i in range(0, len(orchestrator_tasks), max_parallel_tasks):
                    end_idx = min(i + max_parallel_tasks, len(orchestrator_tasks))
                    batch = orchestrator_tasks[i:end_idx]
                    progress_bar.set_postfix({"current": f"batch {i//max_parallel_tasks+1}/{math.ceil(len(orchestrator_tasks)/max_parallel_tasks)}"})
                    self.logger.debug(f"Processing batch of {len(batch)} tasks (tasks {i+1} to {end_idx})")

                    try:
                        # Add a timeout to each batch
                        await asyncio.wait_for(
                            asyncio.gather(*batch),
                            timeout=timeout * 2  # Double the timeout for batches
                        )
                    except asyncio.TimeoutError:
                        self.logger.warning(f"Batch {i//max_parallel_tasks+1} timed out after {timeout*2} seconds")
                        print(f"⚠️ Batch {i//max_parallel_tasks+1} timed out, continuing with next batch")
                        # Set task status to TIMEOUT
                        batch_task_key = f"scan_batch_{i//max_parallel_tasks+1}"
                        self.task_statuses[batch_task_key] = TASK_STATUS["TIMEOUT"]
                        continue
                    except Exception as e:
                        log_error(self.logger, f"Error processing batch {i//max_parallel_tasks+1}", e)
                        print(f"❌ Error in batch {i//max_parallel_tasks+1}: {str(e)}")
                        continue
            else:
                # Sequential execution
                self.logger.info("Running orchestrator processing sequentially")
                print("⚙️ Processing tasks sequentially")
                for i, task in enumerate(orchestrator_tasks):
                    progress_bar.set_postfix({"current": f"task {i+1}/{len(orchestrator_tasks)}"})
                    self.logger.debug(f"Processing task {i+1}/{len(orchestrator_tasks)}")

                    try:
                        # Add a timeout to each task
                        await asyncio.wait_for(task, timeout=timeout)
                    except asyncio.TimeoutError:
                        self.logger.warning(f"Task {i+1}/{len(orchestrator_tasks)} timed out after {timeout} seconds")
                        print(f"⚠️ Task {i+1} timed out, continuing with next task")
                        # Set task status to TIMEOUT
                        task_key = f"scan_task_{i+1}"
                        self.task_statuses[task_key] = TASK_STATUS["TIMEOUT"]
                        continue
                    except Exception as e:
                        log_error(self.logger, f"Error processing task {i+1}/{len(orchestrator_tasks)}", e)
                        print(f"❌ Error in task {i+1}: {str(e)}")
                        continue

            progress_bar.close()

            # Print final status
            tasks_completed = sum(1 for status in self.task_statuses.values() if status == TASK_STATUS["COMPLETED"])
            tasks_failed = sum(1 for status in self.task_statuses.values() if status == TASK_STATUS["FAILED"])
            tasks_timeout = sum(1 for status in self.task_statuses.values() if status == TASK_STATUS["TIMEOUT"])

            total_time = time.time() - self.start_time
            # Only log the summary to file, don't print to console
            self.logger.info(f"Scan Summary: Total tasks: {self.total_tasks}, Completed: {tasks_completed}, Failed: {tasks_failed}, Timeouts: {tasks_timeout}, Total time: {total_time/60:.1f} minutes")

            # Process results
            log_section_header(self.logger, "Processing results")

            # Convert results to _RedTeamResult using only red_team_info
            red_team_result = self._to_red_team_result()

            # Create output with either full results or just conversations
            if data_only:
                self.logger.info("Data-only mode, creating output with just conversations")
                output = RedTeamOutput(redteaming_data=red_team_result["redteaming_data"])
            else:
                output = RedTeamOutput(
                    red_team_result=red_team_result,
                    redteaming_data=red_team_result["redteaming_data"]
                )

            # Log results to MLFlow
            self.logger.info("Logging results to MLFlow")
            await self._log_redteam_results_to_mlflow(
                redteam_output=output,
                eval_run=eval_run,
                data_only=data_only
            )

            if data_only:
                self.logger.info("Data-only mode, returning results without evaluation")
                return output

            if output_path and output.red_team_result:
                # Ensure output_path is an absolute path
                abs_output_path = output_path if os.path.isabs(output_path) else os.path.abspath(output_path)
                self.logger.info(f"Writing output to {abs_output_path}")
                _write_output(abs_output_path, output.red_team_result)

                # Also save a copy to the scan output directory if available
                if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
                    final_output = os.path.join(self.scan_output_dir, "final_results.json")
                    _write_output(final_output, output.red_team_result)
                    self.logger.info(f"Also saved a copy to {final_output}")
            elif output.red_team_result and hasattr(self, 'scan_output_dir') and self.scan_output_dir:
                # If no output_path was specified but we have scan_output_dir, save there
                final_output = os.path.join(self.scan_output_dir, "final_results.json")
                _write_output(final_output, output.red_team_result)
                self.logger.info(f"Saved results to {final_output}")

            if output.red_team_result:
                self.logger.debug("Generating scorecard")
                scorecard = self._to_scorecard(output.red_team_result)
                # Store the scorecard in a variable for access later if needed
                self.scorecard = scorecard

                # Print scorecard to console for user visibility (without extra header)
                print(scorecard)

                # Print URL for detailed results (once only)
                studio_url = output.red_team_result.get("studio_url", "")
                if studio_url:
                    print(f"\nDetailed results available at:\n{studio_url}")

            # Print the output directory path so the user can find it easily
            if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
                print(f"\n📂 All scan files saved to: {self.scan_output_dir}")

            print("✅ Scan completed successfully!")
            self.logger.info("Scan completed successfully")
            return output