azure-ai-evaluation 1.0.1__py3-none-any.whl → 1.13.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of azure-ai-evaluation might be problematic. See the package's published security advisories for more details.
- azure/ai/evaluation/__init__.py +85 -14
- azure/ai/evaluation/_aoai/__init__.py +10 -0
- azure/ai/evaluation/_aoai/aoai_grader.py +140 -0
- azure/ai/evaluation/_aoai/label_grader.py +68 -0
- azure/ai/evaluation/_aoai/python_grader.py +86 -0
- azure/ai/evaluation/_aoai/score_model_grader.py +94 -0
- azure/ai/evaluation/_aoai/string_check_grader.py +66 -0
- azure/ai/evaluation/_aoai/text_similarity_grader.py +80 -0
- azure/ai/evaluation/_azure/__init__.py +3 -0
- azure/ai/evaluation/_azure/_clients.py +204 -0
- azure/ai/evaluation/_azure/_envs.py +207 -0
- azure/ai/evaluation/_azure/_models.py +227 -0
- azure/ai/evaluation/_azure/_token_manager.py +129 -0
- azure/ai/evaluation/_common/__init__.py +9 -1
- azure/ai/evaluation/_common/constants.py +124 -2
- azure/ai/evaluation/_common/evaluation_onedp_client.py +169 -0
- azure/ai/evaluation/_common/onedp/__init__.py +32 -0
- azure/ai/evaluation/_common/onedp/_client.py +166 -0
- azure/ai/evaluation/_common/onedp/_configuration.py +72 -0
- azure/ai/evaluation/_common/onedp/_model_base.py +1232 -0
- azure/ai/evaluation/_common/onedp/_patch.py +21 -0
- azure/ai/evaluation/_common/onedp/_serialization.py +2032 -0
- azure/ai/evaluation/_common/onedp/_types.py +21 -0
- azure/ai/evaluation/_common/onedp/_utils/__init__.py +6 -0
- azure/ai/evaluation/_common/onedp/_utils/model_base.py +1232 -0
- azure/ai/evaluation/_common/onedp/_utils/serialization.py +2032 -0
- azure/ai/evaluation/_common/onedp/_validation.py +66 -0
- azure/ai/evaluation/_common/onedp/_vendor.py +50 -0
- azure/ai/evaluation/_common/onedp/_version.py +9 -0
- azure/ai/evaluation/_common/onedp/aio/__init__.py +29 -0
- azure/ai/evaluation/_common/onedp/aio/_client.py +168 -0
- azure/ai/evaluation/_common/onedp/aio/_configuration.py +72 -0
- azure/ai/evaluation/_common/onedp/aio/_patch.py +21 -0
- azure/ai/evaluation/_common/onedp/aio/operations/__init__.py +49 -0
- azure/ai/evaluation/_common/onedp/aio/operations/_operations.py +7143 -0
- azure/ai/evaluation/_common/onedp/aio/operations/_patch.py +21 -0
- azure/ai/evaluation/_common/onedp/models/__init__.py +358 -0
- azure/ai/evaluation/_common/onedp/models/_enums.py +447 -0
- azure/ai/evaluation/_common/onedp/models/_models.py +5963 -0
- azure/ai/evaluation/_common/onedp/models/_patch.py +21 -0
- azure/ai/evaluation/_common/onedp/operations/__init__.py +49 -0
- azure/ai/evaluation/_common/onedp/operations/_operations.py +8951 -0
- azure/ai/evaluation/_common/onedp/operations/_patch.py +21 -0
- azure/ai/evaluation/_common/onedp/py.typed +1 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/__init__.py +1 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/aio/__init__.py +1 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/aio/operations/__init__.py +25 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/aio/operations/_operations.py +34 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/aio/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/__init__.py +1 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/__init__.py +1 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/operations/__init__.py +22 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/operations/_operations.py +29 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/operations/__init__.py +22 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/operations/_operations.py +29 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/operations/__init__.py +25 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/operations/_operations.py +34 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/rai_service.py +578 -69
- azure/ai/evaluation/_common/raiclient/__init__.py +34 -0
- azure/ai/evaluation/_common/raiclient/_client.py +128 -0
- azure/ai/evaluation/_common/raiclient/_configuration.py +87 -0
- azure/ai/evaluation/_common/raiclient/_model_base.py +1235 -0
- azure/ai/evaluation/_common/raiclient/_patch.py +20 -0
- azure/ai/evaluation/_common/raiclient/_serialization.py +2050 -0
- azure/ai/evaluation/_common/raiclient/_version.py +9 -0
- azure/ai/evaluation/_common/raiclient/aio/__init__.py +29 -0
- azure/ai/evaluation/_common/raiclient/aio/_client.py +130 -0
- azure/ai/evaluation/_common/raiclient/aio/_configuration.py +87 -0
- azure/ai/evaluation/_common/raiclient/aio/_patch.py +20 -0
- azure/ai/evaluation/_common/raiclient/aio/operations/__init__.py +25 -0
- azure/ai/evaluation/_common/raiclient/aio/operations/_operations.py +981 -0
- azure/ai/evaluation/_common/raiclient/aio/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/raiclient/models/__init__.py +60 -0
- azure/ai/evaluation/_common/raiclient/models/_enums.py +18 -0
- azure/ai/evaluation/_common/raiclient/models/_models.py +651 -0
- azure/ai/evaluation/_common/raiclient/models/_patch.py +20 -0
- azure/ai/evaluation/_common/raiclient/operations/__init__.py +25 -0
- azure/ai/evaluation/_common/raiclient/operations/_operations.py +1238 -0
- azure/ai/evaluation/_common/raiclient/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/raiclient/py.typed +1 -0
- azure/ai/evaluation/_common/utils.py +505 -27
- azure/ai/evaluation/_constants.py +147 -0
- azure/ai/evaluation/_converters/__init__.py +3 -0
- azure/ai/evaluation/_converters/_ai_services.py +899 -0
- azure/ai/evaluation/_converters/_models.py +467 -0
- azure/ai/evaluation/_converters/_sk_services.py +495 -0
- azure/ai/evaluation/_eval_mapping.py +87 -0
- azure/ai/evaluation/_evaluate/_batch_run/__init__.py +10 -2
- azure/ai/evaluation/_evaluate/_batch_run/_run_submitter_client.py +176 -0
- azure/ai/evaluation/_evaluate/_batch_run/batch_clients.py +82 -0
- azure/ai/evaluation/_evaluate/_batch_run/code_client.py +18 -12
- azure/ai/evaluation/_evaluate/_batch_run/eval_run_context.py +19 -6
- azure/ai/evaluation/_evaluate/_batch_run/proxy_client.py +47 -22
- azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +18 -2
- azure/ai/evaluation/_evaluate/_eval_run.py +32 -46
- azure/ai/evaluation/_evaluate/_evaluate.py +1809 -142
- azure/ai/evaluation/_evaluate/_evaluate_aoai.py +992 -0
- azure/ai/evaluation/_evaluate/_telemetry/__init__.py +5 -90
- azure/ai/evaluation/_evaluate/_utils.py +237 -42
- azure/ai/evaluation/_evaluator_definition.py +76 -0
- azure/ai/evaluation/_evaluators/_bleu/_bleu.py +80 -28
- azure/ai/evaluation/_evaluators/_code_vulnerability/__init__.py +5 -0
- azure/ai/evaluation/_evaluators/_code_vulnerability/_code_vulnerability.py +119 -0
- azure/ai/evaluation/_evaluators/_coherence/_coherence.py +40 -4
- azure/ai/evaluation/_evaluators/_common/__init__.py +2 -0
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +430 -29
- azure/ai/evaluation/_evaluators/_common/_base_multi_eval.py +63 -0
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +269 -12
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +74 -9
- azure/ai/evaluation/_evaluators/_common/_conversation_aggregators.py +49 -0
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +73 -53
- azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +35 -5
- azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +26 -5
- azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +35 -5
- azure/ai/evaluation/_evaluators/_content_safety/_violence.py +34 -4
- azure/ai/evaluation/_evaluators/_document_retrieval/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_document_retrieval/_document_retrieval.py +442 -0
- azure/ai/evaluation/_evaluators/_eci/_eci.py +6 -3
- azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +97 -70
- azure/ai/evaluation/_evaluators/_fluency/_fluency.py +39 -3
- azure/ai/evaluation/_evaluators/_gleu/_gleu.py +80 -25
- azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +230 -20
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +30 -29
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +19 -14
- azure/ai/evaluation/_evaluators/_intent_resolution/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_intent_resolution/_intent_resolution.py +196 -0
- azure/ai/evaluation/_evaluators/_intent_resolution/intent_resolution.prompty +275 -0
- azure/ai/evaluation/_evaluators/_meteor/_meteor.py +89 -36
- azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +22 -4
- azure/ai/evaluation/_evaluators/_qa/_qa.py +94 -35
- azure/ai/evaluation/_evaluators/_relevance/_relevance.py +100 -4
- azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +154 -56
- azure/ai/evaluation/_evaluators/_response_completeness/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_response_completeness/_response_completeness.py +202 -0
- azure/ai/evaluation/_evaluators/_response_completeness/response_completeness.prompty +84 -0
- azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +39 -3
- azure/ai/evaluation/_evaluators/_rouge/_rouge.py +166 -26
- azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +38 -7
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +81 -85
- azure/ai/evaluation/_evaluators/_task_adherence/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_task_adherence/_task_adherence.py +226 -0
- azure/ai/evaluation/_evaluators/_task_adherence/task_adherence.prompty +101 -0
- azure/ai/evaluation/_evaluators/_task_completion/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_task_completion/_task_completion.py +177 -0
- azure/ai/evaluation/_evaluators/_task_completion/task_completion.prompty +220 -0
- azure/ai/evaluation/_evaluators/_task_navigation_efficiency/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_task_navigation_efficiency/_task_navigation_efficiency.py +384 -0
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/_tool_call_accuracy.py +298 -0
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/tool_call_accuracy.prompty +166 -0
- azure/ai/evaluation/_evaluators/_tool_call_success/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_tool_call_success/_tool_call_success.py +306 -0
- azure/ai/evaluation/_evaluators/_tool_call_success/tool_call_success.prompty +321 -0
- azure/ai/evaluation/_evaluators/_tool_input_accuracy/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_tool_input_accuracy/_tool_input_accuracy.py +263 -0
- azure/ai/evaluation/_evaluators/_tool_input_accuracy/tool_input_accuracy.prompty +76 -0
- azure/ai/evaluation/_evaluators/_tool_output_utilization/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_tool_output_utilization/_tool_output_utilization.py +225 -0
- azure/ai/evaluation/_evaluators/_tool_output_utilization/tool_output_utilization.prompty +221 -0
- azure/ai/evaluation/_evaluators/_tool_selection/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_tool_selection/_tool_selection.py +266 -0
- azure/ai/evaluation/_evaluators/_tool_selection/tool_selection.prompty +104 -0
- azure/ai/evaluation/_evaluators/_ungrounded_attributes/__init__.py +5 -0
- azure/ai/evaluation/_evaluators/_ungrounded_attributes/_ungrounded_attributes.py +102 -0
- azure/ai/evaluation/_evaluators/_xpia/xpia.py +20 -4
- azure/ai/evaluation/_exceptions.py +24 -1
- azure/ai/evaluation/_http_utils.py +7 -5
- azure/ai/evaluation/_legacy/__init__.py +3 -0
- azure/ai/evaluation/_legacy/_adapters/__init__.py +7 -0
- azure/ai/evaluation/_legacy/_adapters/_check.py +17 -0
- azure/ai/evaluation/_legacy/_adapters/_configuration.py +45 -0
- azure/ai/evaluation/_legacy/_adapters/_constants.py +10 -0
- azure/ai/evaluation/_legacy/_adapters/_errors.py +29 -0
- azure/ai/evaluation/_legacy/_adapters/_flows.py +28 -0
- azure/ai/evaluation/_legacy/_adapters/_service.py +16 -0
- azure/ai/evaluation/_legacy/_adapters/client.py +51 -0
- azure/ai/evaluation/_legacy/_adapters/entities.py +26 -0
- azure/ai/evaluation/_legacy/_adapters/tracing.py +28 -0
- azure/ai/evaluation/_legacy/_adapters/types.py +15 -0
- azure/ai/evaluation/_legacy/_adapters/utils.py +31 -0
- azure/ai/evaluation/_legacy/_batch_engine/__init__.py +9 -0
- azure/ai/evaluation/_legacy/_batch_engine/_config.py +48 -0
- azure/ai/evaluation/_legacy/_batch_engine/_engine.py +477 -0
- azure/ai/evaluation/_legacy/_batch_engine/_exceptions.py +88 -0
- azure/ai/evaluation/_legacy/_batch_engine/_openai_injector.py +132 -0
- azure/ai/evaluation/_legacy/_batch_engine/_result.py +107 -0
- azure/ai/evaluation/_legacy/_batch_engine/_run.py +127 -0
- azure/ai/evaluation/_legacy/_batch_engine/_run_storage.py +128 -0
- azure/ai/evaluation/_legacy/_batch_engine/_run_submitter.py +262 -0
- azure/ai/evaluation/_legacy/_batch_engine/_status.py +25 -0
- azure/ai/evaluation/_legacy/_batch_engine/_trace.py +97 -0
- azure/ai/evaluation/_legacy/_batch_engine/_utils.py +97 -0
- azure/ai/evaluation/_legacy/_batch_engine/_utils_deprecated.py +131 -0
- azure/ai/evaluation/_legacy/_common/__init__.py +3 -0
- azure/ai/evaluation/_legacy/_common/_async_token_provider.py +117 -0
- azure/ai/evaluation/_legacy/_common/_logging.py +292 -0
- azure/ai/evaluation/_legacy/_common/_thread_pool_executor_with_context.py +17 -0
- azure/ai/evaluation/_legacy/prompty/__init__.py +36 -0
- azure/ai/evaluation/_legacy/prompty/_connection.py +119 -0
- azure/ai/evaluation/_legacy/prompty/_exceptions.py +139 -0
- azure/ai/evaluation/_legacy/prompty/_prompty.py +430 -0
- azure/ai/evaluation/_legacy/prompty/_utils.py +663 -0
- azure/ai/evaluation/_legacy/prompty/_yaml_utils.py +99 -0
- azure/ai/evaluation/_model_configurations.py +26 -0
- azure/ai/evaluation/_safety_evaluation/__init__.py +3 -0
- azure/ai/evaluation/_safety_evaluation/_generated_rai_client.py +0 -0
- azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +917 -0
- azure/ai/evaluation/_user_agent.py +32 -1
- azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +0 -4
- azure/ai/evaluation/_vendor/rouge_score/scoring.py +0 -4
- azure/ai/evaluation/_vendor/rouge_score/tokenize.py +0 -4
- azure/ai/evaluation/_version.py +2 -1
- azure/ai/evaluation/red_team/__init__.py +22 -0
- azure/ai/evaluation/red_team/_agent/__init__.py +3 -0
- azure/ai/evaluation/red_team/_agent/_agent_functions.py +261 -0
- azure/ai/evaluation/red_team/_agent/_agent_tools.py +461 -0
- azure/ai/evaluation/red_team/_agent/_agent_utils.py +89 -0
- azure/ai/evaluation/red_team/_agent/_semantic_kernel_plugin.py +228 -0
- azure/ai/evaluation/red_team/_attack_objective_generator.py +268 -0
- azure/ai/evaluation/red_team/_attack_strategy.py +49 -0
- azure/ai/evaluation/red_team/_callback_chat_target.py +115 -0
- azure/ai/evaluation/red_team/_default_converter.py +21 -0
- azure/ai/evaluation/red_team/_evaluation_processor.py +505 -0
- azure/ai/evaluation/red_team/_mlflow_integration.py +430 -0
- azure/ai/evaluation/red_team/_orchestrator_manager.py +803 -0
- azure/ai/evaluation/red_team/_red_team.py +1717 -0
- azure/ai/evaluation/red_team/_red_team_result.py +661 -0
- azure/ai/evaluation/red_team/_result_processor.py +1708 -0
- azure/ai/evaluation/red_team/_utils/__init__.py +37 -0
- azure/ai/evaluation/red_team/_utils/_rai_service_eval_chat_target.py +128 -0
- azure/ai/evaluation/red_team/_utils/_rai_service_target.py +601 -0
- azure/ai/evaluation/red_team/_utils/_rai_service_true_false_scorer.py +114 -0
- azure/ai/evaluation/red_team/_utils/constants.py +72 -0
- azure/ai/evaluation/red_team/_utils/exception_utils.py +345 -0
- azure/ai/evaluation/red_team/_utils/file_utils.py +266 -0
- azure/ai/evaluation/red_team/_utils/formatting_utils.py +365 -0
- azure/ai/evaluation/red_team/_utils/logging_utils.py +139 -0
- azure/ai/evaluation/red_team/_utils/metric_mapping.py +73 -0
- azure/ai/evaluation/red_team/_utils/objective_utils.py +46 -0
- azure/ai/evaluation/red_team/_utils/progress_utils.py +252 -0
- azure/ai/evaluation/red_team/_utils/retry_utils.py +218 -0
- azure/ai/evaluation/red_team/_utils/strategy_utils.py +218 -0
- azure/ai/evaluation/simulator/_adversarial_scenario.py +6 -0
- azure/ai/evaluation/simulator/_adversarial_simulator.py +187 -80
- azure/ai/evaluation/simulator/_constants.py +1 -0
- azure/ai/evaluation/simulator/_conversation/__init__.py +138 -11
- azure/ai/evaluation/simulator/_conversation/_conversation.py +6 -2
- azure/ai/evaluation/simulator/_conversation/constants.py +1 -1
- azure/ai/evaluation/simulator/_direct_attack_simulator.py +37 -24
- azure/ai/evaluation/simulator/_helpers/_language_suffix_mapping.py +1 -0
- azure/ai/evaluation/simulator/_indirect_attack_simulator.py +56 -28
- azure/ai/evaluation/simulator/_model_tools/__init__.py +2 -1
- azure/ai/evaluation/simulator/_model_tools/_generated_rai_client.py +225 -0
- azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +12 -10
- azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +100 -45
- azure/ai/evaluation/simulator/_model_tools/_rai_client.py +101 -3
- azure/ai/evaluation/simulator/_model_tools/_template_handler.py +31 -11
- azure/ai/evaluation/simulator/_model_tools/models.py +20 -17
- azure/ai/evaluation/simulator/_simulator.py +43 -19
- {azure_ai_evaluation-1.0.1.dist-info → azure_ai_evaluation-1.13.5.dist-info}/METADATA +378 -27
- azure_ai_evaluation-1.13.5.dist-info/RECORD +305 -0
- {azure_ai_evaluation-1.0.1.dist-info → azure_ai_evaluation-1.13.5.dist-info}/WHEEL +1 -1
- azure/ai/evaluation/_evaluators/_multimodal/__init__.py +0 -20
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +0 -132
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal_base.py +0 -55
- azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +0 -124
- azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_violence.py +0 -100
- azure/ai/evaluation/simulator/_tracing.py +0 -89
- azure_ai_evaluation-1.0.1.dist-info/RECORD +0 -119
- {azure_ai_evaluation-1.0.1.dist-info → azure_ai_evaluation-1.13.5.dist-info/licenses}/NOTICE.txt +0 -0
- {azure_ai_evaluation-1.0.1.dist-info → azure_ai_evaluation-1.13.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1717 @@
|
|
|
1
|
+
# ---------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# ---------------------------------------------------------
|
|
4
|
+
# Third-party imports
|
|
5
|
+
import asyncio
|
|
6
|
+
import itertools
|
|
7
|
+
import logging
|
|
8
|
+
import math
|
|
9
|
+
import os
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
import random
|
|
12
|
+
import time
|
|
13
|
+
import uuid
|
|
14
|
+
from datetime import datetime
|
|
15
|
+
from typing import Callable, Dict, List, Optional, Union, cast, Any
|
|
16
|
+
from tqdm import tqdm
|
|
17
|
+
|
|
18
|
+
# Azure AI Evaluation imports
|
|
19
|
+
from azure.ai.evaluation._constants import TokenScope
|
|
20
|
+
from azure.ai.evaluation._common._experimental import experimental
|
|
21
|
+
|
|
22
|
+
from azure.ai.evaluation._evaluate._evaluate import (
|
|
23
|
+
emit_eval_result_events_to_app_insights,
|
|
24
|
+
) # TODO: uncomment when app insights checked in
|
|
25
|
+
from azure.ai.evaluation._model_configurations import EvaluationResult
|
|
26
|
+
from azure.ai.evaluation.simulator._model_tools import ManagedIdentityAPITokenManager
|
|
27
|
+
from azure.ai.evaluation.simulator._model_tools._generated_rai_client import GeneratedRAIClient
|
|
28
|
+
from azure.ai.evaluation._user_agent import UserAgentSingleton
|
|
29
|
+
from azure.ai.evaluation._model_configurations import (
|
|
30
|
+
AzureOpenAIModelConfiguration,
|
|
31
|
+
OpenAIModelConfiguration,
|
|
32
|
+
)
|
|
33
|
+
from azure.ai.evaluation._exceptions import (
|
|
34
|
+
ErrorBlame,
|
|
35
|
+
ErrorCategory,
|
|
36
|
+
ErrorTarget,
|
|
37
|
+
EvaluationException,
|
|
38
|
+
)
|
|
39
|
+
from azure.ai.evaluation._common.utils import (
|
|
40
|
+
validate_azure_ai_project,
|
|
41
|
+
is_onedp_project,
|
|
42
|
+
)
|
|
43
|
+
from azure.ai.evaluation._evaluate._utils import _write_output
|
|
44
|
+
|
|
45
|
+
# Azure Core imports
|
|
46
|
+
from azure.core.credentials import TokenCredential
|
|
47
|
+
|
|
48
|
+
# Red Teaming imports
|
|
49
|
+
from ._red_team_result import RedTeamResult
|
|
50
|
+
from ._attack_strategy import AttackStrategy
|
|
51
|
+
from ._attack_objective_generator import (
|
|
52
|
+
RiskCategory,
|
|
53
|
+
SupportedLanguages,
|
|
54
|
+
_AttackObjectiveGenerator,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
# PyRIT imports
|
|
58
|
+
from pyrit.common import initialize_pyrit, DUCK_DB
|
|
59
|
+
from pyrit.prompt_target import PromptChatTarget
|
|
60
|
+
|
|
61
|
+
# Local imports - constants and utilities
|
|
62
|
+
from ._utils.constants import TASK_STATUS, MAX_SAMPLING_ITERATIONS_MULTIPLIER, RISK_TO_NUM_SUBTYPE_MAP
|
|
63
|
+
from ._utils.logging_utils import (
|
|
64
|
+
setup_logger,
|
|
65
|
+
log_section_header,
|
|
66
|
+
log_subsection_header,
|
|
67
|
+
)
|
|
68
|
+
from ._utils.formatting_utils import (
|
|
69
|
+
get_strategy_name,
|
|
70
|
+
get_flattened_attack_strategies,
|
|
71
|
+
write_pyrit_outputs_to_file,
|
|
72
|
+
format_scorecard,
|
|
73
|
+
format_content_by_modality,
|
|
74
|
+
)
|
|
75
|
+
from ._utils.strategy_utils import get_chat_target, get_converter_for_strategy
|
|
76
|
+
from ._utils.retry_utils import create_standard_retry_manager
|
|
77
|
+
from ._utils.file_utils import create_file_manager
|
|
78
|
+
from ._utils.metric_mapping import get_attack_objective_from_risk_category
|
|
79
|
+
from ._utils.objective_utils import extract_risk_subtype, get_objective_id
|
|
80
|
+
|
|
81
|
+
from ._orchestrator_manager import OrchestratorManager
|
|
82
|
+
from ._evaluation_processor import EvaluationProcessor
|
|
83
|
+
from ._mlflow_integration import MLflowIntegration
|
|
84
|
+
from ._result_processor import ResultProcessor
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@experimental
|
|
88
|
+
class RedTeam:
|
|
89
|
+
"""
|
|
90
|
+
This class uses various attack strategies to test the robustness of AI models against adversarial inputs.
|
|
91
|
+
It logs the results of these evaluations and provides detailed scorecards summarizing the attack success rates.
|
|
92
|
+
|
|
93
|
+
:param azure_ai_project: The Azure AI project, which can either be a string representing the project endpoint
|
|
94
|
+
or an instance of AzureAIProject. It contains subscription id, resource group, and project name.
|
|
95
|
+
:type azure_ai_project: Union[str, ~azure.ai.evaluation.AzureAIProject]
|
|
96
|
+
:param credential: The credential to authenticate with Azure services
|
|
97
|
+
:type credential: TokenCredential
|
|
98
|
+
:param risk_categories: List of risk categories to generate attack objectives for (optional if custom_attack_seed_prompts is provided)
|
|
99
|
+
:type risk_categories: Optional[List[RiskCategory]]
|
|
100
|
+
:param num_objectives: Number of objectives to generate per risk category
|
|
101
|
+
:type num_objectives: int
|
|
102
|
+
:param application_scenario: Description of the application scenario for context
|
|
103
|
+
:type application_scenario: Optional[str]
|
|
104
|
+
:param custom_attack_seed_prompts: Path to a JSON file containing custom attack seed prompts (can be absolute or relative path)
|
|
105
|
+
:type custom_attack_seed_prompts: Optional[str]
|
|
106
|
+
:param language: Language to use for attack objectives generation. Defaults to English.
|
|
107
|
+
:type language: SupportedLanguages
|
|
108
|
+
:param output_dir: Directory to save output files (optional)
|
|
109
|
+
:type output_dir: Optional[str]
|
|
110
|
+
:param attack_success_thresholds: Threshold configuration for determining attack success.
|
|
111
|
+
Should be a dictionary mapping risk categories (RiskCategory enum values) to threshold values,
|
|
112
|
+
or None to use default binary evaluation (evaluation results determine success).
|
|
113
|
+
When using thresholds, scores >= threshold are considered successful attacks.
|
|
114
|
+
:type attack_success_thresholds: Optional[Dict[RiskCategory, int]]
|
|
115
|
+
"""
|
|
116
|
+
|
|
117
|
+
def __init__(
|
|
118
|
+
self,
|
|
119
|
+
azure_ai_project: Union[dict, str],
|
|
120
|
+
credential,
|
|
121
|
+
*,
|
|
122
|
+
risk_categories: Optional[List[RiskCategory]] = None,
|
|
123
|
+
num_objectives: int = 10,
|
|
124
|
+
application_scenario: Optional[str] = None,
|
|
125
|
+
custom_attack_seed_prompts: Optional[str] = None,
|
|
126
|
+
language: SupportedLanguages = SupportedLanguages.English,
|
|
127
|
+
output_dir=".",
|
|
128
|
+
attack_success_thresholds: Optional[Dict[RiskCategory, int]] = None,
|
|
129
|
+
):
|
|
130
|
+
"""Initialize a new Red Team agent for AI model evaluation.
|
|
131
|
+
|
|
132
|
+
Creates a Red Team agent instance configured with the specified parameters.
|
|
133
|
+
This initializes the token management, attack objective generation, and logging
|
|
134
|
+
needed for running red team evaluations against AI models.
|
|
135
|
+
|
|
136
|
+
:param azure_ai_project: The Azure AI project, which can either be a string representing the project endpoint
|
|
137
|
+
or an instance of AzureAIProject. It contains subscription id, resource group, and project name.
|
|
138
|
+
:type azure_ai_project: Union[str, ~azure.ai.evaluation.AzureAIProject]
|
|
139
|
+
:param credential: Authentication credential for Azure services
|
|
140
|
+
:type credential: TokenCredential
|
|
141
|
+
:param risk_categories: List of risk categories to test (required unless custom prompts provided)
|
|
142
|
+
:type risk_categories: Optional[List[RiskCategory]]
|
|
143
|
+
:param num_objectives: Number of attack objectives to generate per risk category
|
|
144
|
+
:type num_objectives: int
|
|
145
|
+
:param application_scenario: Description of the application scenario
|
|
146
|
+
:type application_scenario: Optional[str]
|
|
147
|
+
:param custom_attack_seed_prompts: Path to a JSON file with custom attack prompts
|
|
148
|
+
:type custom_attack_seed_prompts: Optional[str]
|
|
149
|
+
:param language: Language to use for attack objectives generation. Defaults to English.
|
|
150
|
+
:type language: SupportedLanguages
|
|
151
|
+
:param output_dir: Directory to save evaluation outputs and logs. Defaults to current working directory.
|
|
152
|
+
:type output_dir: str
|
|
153
|
+
:param attack_success_thresholds: Threshold configuration for determining attack success.
|
|
154
|
+
Should be a dictionary mapping risk categories (RiskCategory enum values) to threshold values,
|
|
155
|
+
or None to use default binary evaluation (evaluation results determine success).
|
|
156
|
+
When using thresholds, scores >= threshold are considered successful attacks.
|
|
157
|
+
:type attack_success_thresholds: Optional[Dict[RiskCategory, int]]
|
|
158
|
+
"""
|
|
159
|
+
|
|
160
|
+
self.azure_ai_project = validate_azure_ai_project(azure_ai_project)
|
|
161
|
+
self.credential = credential
|
|
162
|
+
self.output_dir = output_dir
|
|
163
|
+
self.language = language
|
|
164
|
+
self._one_dp_project = is_onedp_project(azure_ai_project)
|
|
165
|
+
|
|
166
|
+
# Configure attack success thresholds
|
|
167
|
+
self.attack_success_thresholds = self._configure_attack_success_thresholds(attack_success_thresholds)
|
|
168
|
+
|
|
169
|
+
# Initialize basic logger without file handler (will be properly set up during scan)
|
|
170
|
+
self.logger = logging.getLogger("RedTeamLogger")
|
|
171
|
+
self.logger.setLevel(logging.DEBUG)
|
|
172
|
+
|
|
173
|
+
# Only add console handler for now - file handler will be added during scan setup
|
|
174
|
+
if not any(isinstance(h, logging.StreamHandler) for h in self.logger.handlers):
|
|
175
|
+
console_handler = logging.StreamHandler()
|
|
176
|
+
console_handler.setLevel(logging.WARNING)
|
|
177
|
+
console_formatter = logging.Formatter("%(levelname)s - %(message)s")
|
|
178
|
+
console_handler.setFormatter(console_formatter)
|
|
179
|
+
self.logger.addHandler(console_handler)
|
|
180
|
+
|
|
181
|
+
if not self._one_dp_project:
|
|
182
|
+
self.token_manager = ManagedIdentityAPITokenManager(
|
|
183
|
+
token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
|
|
184
|
+
logger=logging.getLogger("RedTeamLogger"),
|
|
185
|
+
credential=cast(TokenCredential, credential),
|
|
186
|
+
)
|
|
187
|
+
else:
|
|
188
|
+
self.token_manager = ManagedIdentityAPITokenManager(
|
|
189
|
+
token_scope=TokenScope.COGNITIVE_SERVICES_MANAGEMENT,
|
|
190
|
+
logger=logging.getLogger("RedTeamLogger"),
|
|
191
|
+
credential=cast(TokenCredential, credential),
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
# Initialize task tracking
|
|
195
|
+
self.task_statuses = {}
|
|
196
|
+
self.total_tasks = 0
|
|
197
|
+
self.completed_tasks = 0
|
|
198
|
+
self.failed_tasks = 0
|
|
199
|
+
self.start_time = None
|
|
200
|
+
self.scan_id = None
|
|
201
|
+
self.scan_session_id = None
|
|
202
|
+
self.scan_output_dir = None
|
|
203
|
+
|
|
204
|
+
# Initialize RAI client
|
|
205
|
+
self.generated_rai_client = GeneratedRAIClient(
|
|
206
|
+
azure_ai_project=self.azure_ai_project, token_manager=self.token_manager.credential
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
# Initialize a cache for attack objectives by risk category and strategy
|
|
210
|
+
self.attack_objectives = {}
|
|
211
|
+
|
|
212
|
+
# Keep track of data and eval result file names
|
|
213
|
+
self.red_team_info = {}
|
|
214
|
+
|
|
215
|
+
# keep track of prompt content to context mapping for evaluation
|
|
216
|
+
self.prompt_to_context = {}
|
|
217
|
+
|
|
218
|
+
# keep track of prompt content to risk_sub_type mapping for evaluation
|
|
219
|
+
self.prompt_to_risk_subtype = {}
|
|
220
|
+
|
|
221
|
+
# Initialize PyRIT
|
|
222
|
+
initialize_pyrit(memory_db_type=DUCK_DB)
|
|
223
|
+
|
|
224
|
+
# Initialize attack objective generator
|
|
225
|
+
self.attack_objective_generator = _AttackObjectiveGenerator(
|
|
226
|
+
risk_categories=risk_categories,
|
|
227
|
+
num_objectives=num_objectives,
|
|
228
|
+
application_scenario=application_scenario,
|
|
229
|
+
custom_attack_seed_prompts=custom_attack_seed_prompts,
|
|
230
|
+
)
|
|
231
|
+
|
|
232
|
+
# Initialize component managers (will be set up during scan)
|
|
233
|
+
self.orchestrator_manager = None
|
|
234
|
+
self.evaluation_processor = None
|
|
235
|
+
self.mlflow_integration = None
|
|
236
|
+
self.result_processor = None
|
|
237
|
+
|
|
238
|
+
# Initialize utility managers
|
|
239
|
+
self.retry_manager = create_standard_retry_manager(logger=self.logger)
|
|
240
|
+
self.file_manager = create_file_manager(base_output_dir=self.output_dir, logger=self.logger)
|
|
241
|
+
|
|
242
|
+
self.logger.debug("RedTeam initialized successfully")
|
|
243
|
+
|
|
244
|
+
def _configure_attack_success_thresholds(
    self, attack_success_thresholds: Optional[Dict[RiskCategory, int]]
) -> Dict[str, int]:
    """Normalize the caller-supplied attack-success threshold mapping.

    Accepts ``None`` (no thresholds configured) or a dict keyed by
    ``RiskCategory`` members, and returns a plain ``{category_name: threshold}``
    dict keyed by each category's ``.value`` string.

    :param attack_success_thresholds: Optional mapping of risk category to an
        integer threshold in the inclusive range 0-7.
    :return: Normalized mapping with string keys; empty dict when input is None.
    :raises ValueError: If the argument is not a dict, a threshold is outside
        0-7 (or not an int), or a key is not a RiskCategory-like enum member.
    """
    if attack_success_thresholds is None:
        # Nothing configured; callers treat an empty mapping as "use defaults".
        return {}

    if not isinstance(attack_success_thresholds, dict):
        raise ValueError(
            f"attack_success_thresholds must be a dictionary mapping RiskCategory instances to thresholds, or None. Got: {type(attack_success_thresholds)}"
        )

    normalized: Dict[str, int] = {}
    for category, threshold in attack_success_thresholds.items():
        # Validate the threshold value first (matches original check order).
        if not isinstance(threshold, int) or not (0 <= threshold <= 7):
            raise ValueError(
                f"Attack success threshold for {category} must be an integer between 0 and 7, got: {threshold}"
            )

        # Keys are normalized to their enum string value; anything without a
        # .value attribute is rejected as a non-RiskCategory key.
        if not hasattr(category, "value"):
            raise ValueError(f"attack_success_thresholds keys must be RiskCategory instance, got: {type(category)}")

        normalized[category.value] = threshold

    return normalized
|
|
275
|
+
|
|
276
|
+
def _setup_component_managers(self):
    """Construct the scan-time component managers from state on ``self``.

    Builds the orchestrator manager, evaluation processor, MLflow
    integration, and result processor. The result processor is created
    last on purpose: it reads ``ai_studio_url`` off the freshly created
    MLflow integration instance.
    """
    shared_retry_config = self.retry_manager.get_retry_config()

    # Orchestration component; receives the RAI client, project info and
    # the shared retry policy, plus a back-reference to this RedTeam.
    self.orchestrator_manager = OrchestratorManager(
        logger=self.logger,
        generated_rai_client=self.generated_rai_client,
        credential=self.credential,
        azure_ai_project=self.azure_ai_project,
        one_dp_project=self._one_dp_project,
        retry_config=shared_retry_config,
        scan_output_dir=self.scan_output_dir,
        red_team=self,
    )

    # Evaluation component; receives the configured attack-success
    # thresholds and the (possibly unset) taxonomy risk categories.
    self.evaluation_processor = EvaluationProcessor(
        logger=self.logger,
        azure_ai_project=self.azure_ai_project,
        credential=self.credential,
        attack_success_thresholds=self.attack_success_thresholds,
        retry_config=shared_retry_config,
        scan_session_id=self.scan_session_id,
        scan_output_dir=self.scan_output_dir,
        taxonomy_risk_categories=getattr(self, "taxonomy_risk_categories", None),
    )

    # Tracking/upload component for scan artifacts.
    self.mlflow_integration = MLflowIntegration(
        logger=self.logger,
        azure_ai_project=self.azure_ai_project,
        generated_rai_client=self.generated_rai_client,
        one_dp_project=self._one_dp_project,
        scan_output_dir=self.scan_output_dir,
    )

    # Created last: depends on the MLflow integration built just above.
    self.result_processor = ResultProcessor(
        logger=self.logger,
        attack_success_thresholds=self.attack_success_thresholds,
        application_scenario=getattr(self, "application_scenario", ""),
        risk_categories=getattr(self, "risk_categories", []),
        ai_studio_url=getattr(self.mlflow_integration, "ai_studio_url", None),
        mlflow_integration=self.mlflow_integration,
    )
|
|
322
|
+
|
|
323
|
+
async def _get_attack_objectives(
    self,
    risk_category: Optional[RiskCategory] = None,
    application_scenario: Optional[str] = None,
    strategy: Optional[str] = None,
    is_agent_target: Optional[bool] = None,
    client_id: Optional[str] = None,
) -> List[str]:
    """Resolve attack objectives for one risk category and strategy.

    Objectives come either from user-provided custom attack seed prompts
    (when any exist for this category) or from the RAI service. A cache of
    per-(category, strategy) selections keeps objective choices consistent
    across strategies; "jailbreak" and "indirect_jailbreak" strategies get
    additional prompt transformations downstream.

    :param risk_category: The risk category to fetch objectives for.
    :type risk_category: Optional[RiskCategory]
    :param application_scenario: Optional scenario description for context.
    :type application_scenario: Optional[str]
    :param strategy: Optional attack strategy name.
    :type strategy: Optional[str]
    :param is_agent_target: True for agent targets, False for model targets,
        None when unknown.
    :type is_agent_target: Optional[bool]
    :param client_id: Optional client identifier forwarded to the service.
    :type client_id: Optional[str]
    :return: List of attack objective prompt strings.
    :rtype: List[str]
    """
    generator = self.attack_objective_generator

    # Lower-cased category string is the canonical cache key component.
    risk_cat_value = get_attack_objective_from_risk_category(risk_category).lower()
    num_objectives = generator.num_objectives

    # Sampling size must cover the largest subtype count across all requested
    # categories; self.risk_categories may not be populated yet, so fall back
    # to the generator's list.
    active_risk_categories = getattr(self, "risk_categories", None) or generator.risk_categories
    max_num_subtypes = max((RISK_TO_NUM_SUBTYPE_MAP.get(rc, 0) for rc in active_risk_categories), default=0)
    num_objectives_with_subtypes = max(num_objectives, max_num_subtypes)

    self.logger.debug(
        f"Calculated num_objectives_with_subtypes for {risk_cat_value}: "
        f"max(num_objectives={num_objectives}, max_subtypes={max_num_subtypes}) = {num_objectives_with_subtypes}"
    )

    log_subsection_header(
        self.logger,
        f"Getting attack objectives for {risk_cat_value}, strategy: {strategy}, num_objectives: {num_objectives}, num_objectives_with_subtypes: {num_objectives_with_subtypes}",
    )

    # Baseline selections (if already made) constrain later strategies.
    baseline_key = ((risk_cat_value,), "baseline")
    baseline_objectives_exist = baseline_key in self.attack_objectives
    current_key = ((risk_cat_value,), strategy)

    if generator.custom_attack_seed_prompts and generator.validated_prompts:
        custom_objectives = generator.valid_prompts_by_category.get(risk_cat_value, [])

        if custom_objectives:
            # Custom seed prompts cover this category — use them directly.
            return await self._get_custom_attack_objectives(
                risk_cat_value, num_objectives, num_objectives_with_subtypes, strategy, current_key, is_agent_target
            )

        if not (generator.risk_categories and risk_category in generator.risk_categories):
            # No custom prompts for this category and it was never requested.
            self.logger.warning(
                f"No custom objectives found for risk category {risk_cat_value} and it's not in the requested risk categories"
            )
            return []

        self.logger.info(
            f"No custom objectives found for risk category {risk_cat_value}, fetching from service"
        )

    # Either custom prompts are absent entirely, or they exist but not for
    # this (requested) category — fetch from the RAI service.
    return await self._get_rai_attack_objectives(
        risk_category,
        risk_cat_value,
        application_scenario,
        strategy,
        baseline_objectives_exist,
        baseline_key,
        current_key,
        num_objectives,
        num_objectives_with_subtypes,
        is_agent_target,
        client_id,
    )
|
|
430
|
+
|
|
431
|
+
async def _get_custom_attack_objectives(
    self,
    risk_cat_value: str,
    num_objectives: int,
    num_objectives_with_subtypes: int,
    strategy: str,
    current_key: tuple,
    is_agent_target: Optional[bool] = None,
) -> List[str]:
    """Select attack objectives from the user-supplied custom seed prompts.

    Objectives are first deduplicated by ID. When any objective carries a
    risk subtype, sampling is spread evenly across subtypes (topping up from
    subtype-less objectives, then round-robining across subtypes); otherwise
    objectives are sampled uniformly from the deduplicated pool. Jailbreak /
    indirect-jailbreak strategies transform the sampled objectives before
    the prompt text is extracted, and the selection is cached under
    ``current_key``.

    :param risk_cat_value: Lower-cased risk category used to look up prompts.
    :param num_objectives: Requested objective count (kept for interface
        compatibility; sampling uses ``num_objectives_with_subtypes``).
    :param num_objectives_with_subtypes: Target sample size, pre-adjusted by
        the caller for subtype coverage.
    :param strategy: Attack strategy name; "jailbreak" and
        "indirect_jailbreak" trigger prompt transformation.
    :param current_key: Cache key for storing this selection.
    :param is_agent_target: True for agent targets, False for model targets,
        None when unknown.
    :return: The selected prompt content strings (possibly empty).
    """
    attack_objective_generator = self.attack_objective_generator

    self.logger.info(
        f"Using custom attack seed prompts from {attack_objective_generator.custom_attack_seed_prompts}"
    )

    # Get the prompts for this risk category
    custom_objectives = attack_objective_generator.valid_prompts_by_category.get(risk_cat_value, [])

    if not custom_objectives:
        self.logger.warning(f"No custom objectives found for risk category {risk_cat_value}")
        return []

    self.logger.info(f"Found {len(custom_objectives)} custom objectives for {risk_cat_value}")

    # Deduplicate objectives by ID to avoid selecting the same logical objective multiple times
    seen_ids = set()
    deduplicated_objectives = []
    for obj in custom_objectives:
        obj_id = get_objective_id(obj)
        if obj_id not in seen_ids:
            seen_ids.add(obj_id)
            deduplicated_objectives.append(obj)

    if len(deduplicated_objectives) < len(custom_objectives):
        self.logger.debug(
            f"Deduplicated {len(custom_objectives)} objectives to {len(deduplicated_objectives)} unique objectives by ID"
        )

    # Group objectives by risk_subtype if present
    objectives_by_subtype = {}
    objectives_without_subtype = []

    for obj in deduplicated_objectives:
        risk_subtype = extract_risk_subtype(obj)

        if risk_subtype:
            if risk_subtype not in objectives_by_subtype:
                objectives_by_subtype[risk_subtype] = []
            objectives_by_subtype[risk_subtype].append(obj)
        else:
            objectives_without_subtype.append(obj)

    # Determine sampling strategy based on risk_subtype presence
    # Use num_objectives_with_subtypes for initial sampling to ensure coverage
    if objectives_by_subtype:
        # We have risk subtypes - sample evenly across them
        num_subtypes = len(objectives_by_subtype)
        objectives_per_subtype = max(1, num_objectives_with_subtypes // num_subtypes)

        self.logger.info(
            f"Found {num_subtypes} risk subtypes in custom objectives. "
            f"Sampling {objectives_per_subtype} objectives per subtype to reach ~{num_objectives_with_subtypes} total."
        )

        selected_cat_objectives = []
        for subtype, subtype_objectives in objectives_by_subtype.items():
            num_to_sample = min(objectives_per_subtype, len(subtype_objectives))
            sampled = random.sample(subtype_objectives, num_to_sample)
            selected_cat_objectives.extend(sampled)
            self.logger.debug(
                f"Sampled {num_to_sample} objectives from risk_subtype '{subtype}' "
                f"({len(subtype_objectives)} available)"
            )

        # If we need more objectives to reach num_objectives_with_subtypes, sample from objectives without subtype
        if len(selected_cat_objectives) < num_objectives_with_subtypes and objectives_without_subtype:
            remaining = num_objectives_with_subtypes - len(selected_cat_objectives)
            num_to_sample = min(remaining, len(objectives_without_subtype))
            selected_cat_objectives.extend(random.sample(objectives_without_subtype, num_to_sample))
            self.logger.debug(f"Added {num_to_sample} objectives without risk_subtype to reach target count")

        # If we still need more, round-robin through subtypes again
        if len(selected_cat_objectives) < num_objectives_with_subtypes:
            remaining = num_objectives_with_subtypes - len(selected_cat_objectives)
            subtype_list = list(objectives_by_subtype.keys())
            # Track selected objective IDs in a set for O(1) membership checks
            # Use the objective's 'id' field if available, generate UUID-based ID otherwise
            selected_ids = {get_objective_id(obj) for obj in selected_cat_objectives}
            idx = 0
            while remaining > 0 and subtype_list:
                subtype = subtype_list[idx % len(subtype_list)]
                available = [
                    obj for obj in objectives_by_subtype[subtype] if get_objective_id(obj) not in selected_ids
                ]
                if available:
                    selected_obj = random.choice(available)
                    selected_cat_objectives.append(selected_obj)
                    selected_ids.add(get_objective_id(selected_obj))
                    remaining -= 1
                idx += 1
                # Prevent infinite loop if we run out of unique objectives
                if idx > len(subtype_list) * MAX_SAMPLING_ITERATIONS_MULTIPLIER:
                    break

        self.logger.info(f"Sampled {len(selected_cat_objectives)} objectives across {num_subtypes} risk subtypes")
    else:
        # No risk subtypes - sample uniformly from the DEDUPLICATED pool.
        # Bug fix: previously this branch sampled from the raw
        # custom_objectives list, which could select the same logical
        # objective (same ID) more than once despite the dedup pass above.
        if len(deduplicated_objectives) > num_objectives_with_subtypes:
            selected_cat_objectives = random.sample(deduplicated_objectives, num_objectives_with_subtypes)
            self.logger.info(
                f"Sampled {num_objectives_with_subtypes} objectives from {len(deduplicated_objectives)} available for {risk_cat_value}"
            )
        else:
            selected_cat_objectives = deduplicated_objectives
            self.logger.info(
                f"Using all {len(deduplicated_objectives)} available objectives for {risk_cat_value}"
            )

    target_type_str = "agent" if is_agent_target else "model" if is_agent_target is not None else None
    # Handle jailbreak strategy - need to apply jailbreak prefixes to messages
    if strategy == "jailbreak":
        selected_cat_objectives = await self._apply_jailbreak_prefixes(selected_cat_objectives)
    elif strategy == "indirect_jailbreak":
        selected_cat_objectives = await self._apply_xpia_prompts(selected_cat_objectives, target_type_str)

    # Extract content from selected objectives
    selected_prompts = []
    for obj in selected_cat_objectives:
        # Extract risk-subtype from target_harms if present
        risk_subtype = extract_risk_subtype(obj)

        if "messages" in obj and len(obj["messages"]) > 0:
            message = obj["messages"][0]
            if isinstance(message, dict) and "content" in message:
                content = message["content"]
                context = message.get("context", "")
                selected_prompts.append(content)
                # Store mapping of content to context for later evaluation
                self.prompt_to_context[content] = context
                # Store risk_subtype mapping if it exists
                if risk_subtype:
                    self.prompt_to_risk_subtype[content] = risk_subtype

    # Store in cache and return
    self._cache_attack_objectives(current_key, risk_cat_value, strategy, selected_prompts, selected_cat_objectives)
    return selected_prompts
|
|
575
|
+
|
|
576
|
+
async def _get_rai_attack_objectives(
    self,
    risk_category: RiskCategory,
    risk_cat_value: str,
    application_scenario: str,
    strategy: str,
    baseline_objectives_exist: bool,
    baseline_key: tuple,
    current_key: tuple,
    num_objectives: int,
    num_objectives_with_subtypes: int,
    is_agent_target: Optional[bool] = None,
    client_id: Optional[str] = None,
) -> List[str]:
    """Fetch, transform, filter, and cache attack objectives from the RAI service.

    Content-harm categories (hate_unfairness, violence, self_harm, sexual)
    are sent via the ``risk_type`` parameter; every other category via
    ``risk_category``. If an agent-targeted request returns nothing, the
    call is retried once with ``target="model"`` before giving up. The
    "jailbreak" / "indirect_jailbreak" strategies transform the fetched
    objectives before filtering. Returns the selected prompt strings, or an
    empty list on failure.
    """
    # Route the category into the service's two mutually exclusive params.
    content_harm_risk = None
    other_risk = ""
    if risk_cat_value in ["hate_unfairness", "violence", "self_harm", "sexual"]:
        content_harm_risk = risk_cat_value
    else:
        other_risk = risk_cat_value

    try:
        self.logger.debug(
            f"API call: get_attack_objectives({risk_cat_value}, app: {application_scenario}, strategy: {strategy})"
        )

        # Get objectives from RAI service.
        # True -> "agent", False -> "model", None -> None (service default).
        target_type_str = "agent" if is_agent_target else "model" if is_agent_target is not None else None

        objectives_response = await self.generated_rai_client.get_attack_objectives(
            risk_type=content_harm_risk,
            risk_category=other_risk,
            application_scenario=application_scenario or "",
            strategy=None,
            language=self.language.value,
            scan_session_id=self.scan_session_id,
            target=target_type_str,
            client_id=client_id,
        )

        if isinstance(objectives_response, list):
            self.logger.debug(f"API returned {len(objectives_response)} objectives")
        # Handle jailbreak strategy
        if strategy == "jailbreak":
            objectives_response = await self._apply_jailbreak_prefixes(objectives_response)
        elif strategy == "indirect_jailbreak":
            objectives_response = await self._apply_xpia_prompts(objectives_response, target_type_str)

    except Exception as e:
        self.logger.warning(f"Error calling get_attack_objectives: {str(e)}")
        # Empty dict deliberately trips the empty-response check below,
        # which routes into the fallback/return-empty paths.
        objectives_response = {}

    # Check if the response is valid (empty list, or dict without objectives)
    if not objectives_response or (
        isinstance(objectives_response, dict) and not objectives_response.get("objectives")
    ):
        # If we got no agent objectives, fallback to model objectives
        if is_agent_target:
            self.logger.warning(
                f"No agent-type attack objectives found for {risk_cat_value}. "
                "Falling back to model-type objectives."
            )
            try:
                # Retry with model target type
                objectives_response = await self.generated_rai_client.get_attack_objectives(
                    risk_type=content_harm_risk,
                    risk_category=other_risk,
                    application_scenario=application_scenario or "",
                    strategy=None,
                    language=self.language.value,
                    scan_session_id=self.scan_session_id,
                    target="model",
                    client_id=client_id,
                )

                if isinstance(objectives_response, list):
                    self.logger.debug(f"Fallback API returned {len(objectives_response)} model-type objectives")

                # Apply strategy-specific transformations to fallback objectives
                # Still try agent-type attack techniques (jailbreak/XPIA) even with model-type baseline objectives
                if strategy == "jailbreak":
                    objectives_response = await self._apply_jailbreak_prefixes(objectives_response)
                elif strategy == "indirect_jailbreak":
                    objectives_response = await self._apply_xpia_prompts(objectives_response, target_type_str)

                # Check if fallback response is also empty
                if not objectives_response or (
                    isinstance(objectives_response, dict) and not objectives_response.get("objectives")
                ):
                    self.logger.warning("Fallback to model-type objectives also returned empty list")
                    return []

            except Exception as fallback_error:
                self.logger.error(f"Error calling get_attack_objectives with model fallback: {str(fallback_error)}")
                self.logger.warning("Fallback API call failed, returning empty objectives list")
                return []
        else:
            self.logger.warning("Empty or invalid response, returning empty list")
            return []

    # Filter and select objectives using num_objectives_with_subtypes
    # (keeps non-baseline strategies aligned with baseline objective IDs).
    selected_cat_objectives = self._filter_and_select_objectives(
        objectives_response, strategy, baseline_objectives_exist, baseline_key, num_objectives_with_subtypes
    )

    # Extract content and cache
    selected_prompts = self._extract_objective_content(selected_cat_objectives)
    self._cache_attack_objectives(current_key, risk_cat_value, strategy, selected_prompts, selected_cat_objectives)

    return selected_prompts
|
|
687
|
+
|
|
688
|
+
async def _apply_xpia_prompts(self, objectives_list: List, target_type_str: str) -> List:
    """Apply XPIA prompt formatting to objectives for indirect jailbreak strategy.

    XPIA prompts are wrapper structures that contain:
    - content: benign user query to trigger tool use
    - context: attack vehicle with {attack_text} placeholder
    - context_type: modality for formatting (email, document, html, code)
    - tool_name: name for mock tool

    We inject the baseline attack objectives into these XPIA wrapper prompts.
    The objective dicts in ``objectives_list`` are mutated in place; on any
    error the list is returned with whatever transformations were already
    applied. ``target_type_str`` may be "agent", "model", or None; an empty
    agent fetch falls back to model-type wrapper prompts.
    """
    self.logger.debug(f"Applying XPIA prompts to objectives for indirect jailbreak (target_type={target_type_str})")

    try:
        # Fetch XPIA wrapper prompts from RAI service
        @self.retry_manager.create_retry_decorator(context="xpia_prompts")
        async def get_xpia_prompts_with_retry():
            return await self.generated_rai_client.get_attack_objectives(
                risk_type=None,
                risk_category="xpia",
                application_scenario="",
                strategy=None,
                language=self.language.value,
                scan_session_id=self.scan_session_id,
                target=target_type_str,
            )

        xpia_prompts = await get_xpia_prompts_with_retry()

        # If no agent XPIA prompts and we're trying agent, fallback to model
        if (not xpia_prompts or len(xpia_prompts) == 0) and target_type_str == "agent":
            self.logger.debug("No agent-type XPIA prompts available, falling back to model-type XPIA prompts")
            try:
                # NOTE: fallback fetch is not retried via the retry manager,
                # unlike the primary fetch above.
                xpia_prompts = await self.generated_rai_client.get_attack_objectives(
                    risk_type=None,
                    risk_category="xpia",
                    application_scenario="",
                    strategy=None,
                    language=self.language.value,
                    scan_session_id=self.scan_session_id,
                    target="model",
                )
                if xpia_prompts and len(xpia_prompts) > 0:
                    self.logger.debug(f"Fetched {len(xpia_prompts)} model-type XPIA wrapper prompts as fallback")
            except Exception as fallback_error:
                self.logger.error(f"Error fetching model-type XPIA prompts as fallback: {str(fallback_error)}")

        if not xpia_prompts or len(xpia_prompts) == 0:
            self.logger.warning("No XPIA prompts available (even after fallback), returning objectives unchanged")
            return objectives_list

        self.logger.debug(f"Fetched {len(xpia_prompts)} XPIA wrapper prompts")

        # Apply XPIA wrapping to each baseline objective
        for objective in objectives_list:
            if "messages" in objective and len(objective["messages"]) > 0:
                # Only the first message of each objective is transformed.
                message = objective["messages"][0]
                if isinstance(message, dict) and "content" in message:
                    # Get the baseline attack content to inject
                    baseline_attack_content = message["content"]
                    # Preserve the original baseline context if it exists
                    baseline_context = message.get("context", "")

                    # Normalize baseline_context to a list of context dicts
                    baseline_contexts = []
                    if baseline_context:
                        # Extract baseline context from RAI service format
                        context_dict = {"content": baseline_context}
                        if message.get("tool_name"):
                            context_dict["tool_name"] = message["tool_name"]
                        if message.get("context_type"):
                            context_dict["context_type"] = message["context_type"]
                        baseline_contexts = [context_dict]

                    # Check if baseline contexts have agent fields (context_type, tool_name)
                    baseline_contexts_with_agent_fields = []
                    baseline_contexts_without_agent_fields = []

                    for ctx in baseline_contexts:
                        if isinstance(ctx, dict):
                            if "context_type" in ctx or "tool_name" in ctx:
                                # This baseline context has agent fields - preserve it separately
                                baseline_contexts_with_agent_fields.append(ctx)
                                self.logger.debug(
                                    f"Found baseline context with agent fields: tool_name={ctx.get('tool_name')}, context_type={ctx.get('context_type')}"
                                )
                            else:
                                # This baseline context has no agent fields - can be embedded
                                baseline_contexts_without_agent_fields.append(ctx)
                        else:
                            # Non-dict contexts are coerced to a plain content dict.
                            baseline_contexts_without_agent_fields.append({"content": str(ctx)})

                    # For baseline contexts without agent fields, embed them in the attack content
                    if baseline_contexts_without_agent_fields:
                        context_texts = [
                            ctx.get("content", "")
                            for ctx in baseline_contexts_without_agent_fields
                            if ctx.get("content")
                        ]
                        if context_texts:
                            combined_context = "\n\n".join(context_texts)
                            baseline_attack_content = f"{baseline_attack_content}\n\nContext:\n{combined_context}"
                            self.logger.debug(
                                f"Embedded {len(context_texts)} baseline context(s) without agent fields into attack content"
                            )

                    # Randomly select an XPIA wrapper prompt
                    xpia_prompt = random.choice(xpia_prompts)
                    # NOTE(review): assumes each wrapper prompt has a non-empty
                    # "messages" list; an empty list would raise IndexError here
                    # (caught by the outer except) — confirm service contract.
                    xpia_message = xpia_prompt.get("messages", [{}])[0]

                    # Extract XPIA wrapper structure
                    user_query = xpia_message.get("content", "")
                    attack_vehicle_context = xpia_message.get("context", "")
                    context_type = xpia_message.get("context_type") or "text"
                    tool_name = xpia_message.get("tool_name", "")

                    # Inject baseline attack (now with appended context) into the {attack_text} placeholder
                    if "{attack_text}" in attack_vehicle_context:
                        injected_context = attack_vehicle_context.replace("{attack_text}", baseline_attack_content)
                    else:
                        # No placeholder found, append to end
                        injected_context = f"{attack_vehicle_context}\n\n{baseline_attack_content}"

                    # Apply modality-based formatting
                    formatted_context = format_content_by_modality(injected_context, context_type)

                    # Update the message with benign user query
                    message["content"] = user_query

                    # Build the contexts list: XPIA context + any baseline contexts with agent fields
                    contexts = [
                        {"content": formatted_context, "context_type": context_type, "tool_name": tool_name}
                    ]

                    # Add baseline contexts with agent fields as separate context entries
                    if baseline_contexts_with_agent_fields:
                        contexts.extend(baseline_contexts_with_agent_fields)
                        self.logger.debug(
                            f"Preserved {len(baseline_contexts_with_agent_fields)} baseline context(s) with agent fields"
                        )

                    message["context"] = contexts
                    message["context_type"] = (
                        context_type  # Keep at message level for backward compat (XPIA primary)
                    )
                    message["tool_name"] = tool_name

                    self.logger.debug(
                        f"Wrapped baseline attack in XPIA: total contexts={len(contexts)}, xpia_tool={tool_name}, xpia_type={context_type}"
                    )

    except Exception as e:
        self.logger.error(f"Error applying XPIA prompts: {str(e)}")
        self.logger.warning("XPIA prompt application failed, returning original objectives")

    return objectives_list
|
|
844
|
+
|
|
845
|
+
async def _apply_jailbreak_prefixes(self, objectives_list: List) -> List:
    """Prepend a randomly chosen jailbreak prefix to each objective's first message.

    The message dicts are mutated in place and the same list object is
    returned. If fetching prefixes fails, the error is logged and the list
    is returned unchanged.
    """
    self.logger.debug("Applying jailbreak prefixes to objectives")
    try:
        # Fetch the prefix pool through the centralized retry policy.
        @self.retry_manager.create_retry_decorator(context="jailbreak_prefixes")
        async def _fetch_prefixes():
            return await self.generated_rai_client.get_jailbreak_prefixes()

        prefixes = await _fetch_prefixes()

        for objective in objectives_list:
            # Skip objectives with no messages to transform.
            if "messages" not in objective or len(objective["messages"]) == 0:
                continue
            first_message = objective["messages"][0]
            if not (isinstance(first_message, dict) and "content" in first_message):
                continue
            chosen_prefix = random.choice(prefixes)
            first_message["content"] = f"{chosen_prefix} {first_message['content']}"
    except Exception as e:
        self.logger.error(f"Error applying jailbreak prefixes: {str(e)}")

    return objectives_list
|
|
864
|
+
|
|
865
|
+
def _filter_and_select_objectives(
    self,
    objectives_response: List,
    strategy: str,
    baseline_objectives_exist: bool,
    baseline_key: tuple,
    num_objectives: int,
) -> List:
    """Filter and select objectives based on strategy and baseline requirements.

    For non-baseline strategies, selection is keyed to the objective IDs of the
    already-selected baseline set (read from ``self.attack_objectives[baseline_key]``)
    so that every strategy attacks the same underlying objectives. For the
    baseline strategy (or when no baseline has been cached yet), selection is a
    plain random sample.

    :param objectives_response: Raw objective dicts returned by the RAI service.
    :param strategy: Attack strategy name ("baseline" or a derived strategy).
    :param baseline_objectives_exist: Whether baseline objectives are already cached.
    :param baseline_key: Cache key for the baseline entry in ``self.attack_objectives``.
    :param num_objectives: Target number of objectives to select.
    :return: Selected objective dicts; may be fewer than ``num_objectives``.
    """
    # For non-baseline strategies, filter by baseline IDs if they exist
    if strategy != "baseline" and baseline_objectives_exist:
        self.logger.debug(f"Found existing baseline objectives, will filter {strategy} by baseline IDs")
        baseline_selected_objectives = self.attack_objectives[baseline_key].get("selected_objectives", [])
        baseline_objective_ids = [obj.get("id") for obj in baseline_selected_objectives if "id" in obj]

        if baseline_objective_ids:
            self.logger.debug(f"Filtering by {len(baseline_objective_ids)} baseline objective IDs for {strategy}")
            # Filter by baseline IDs
            filtered_objectives = [obj for obj in objectives_response if obj.get("id") in baseline_objective_ids]
            self.logger.debug(f"Found {len(filtered_objectives)} matching objectives with baseline IDs")

            # For strategies like indirect_jailbreak, the RAI service may return multiple
            # objectives per baseline ID (e.g., multiple XPIA variations for one baseline objective).
            # We should select num_objectives total, ensuring each baseline objective gets an XPIA attack.
            # Group by baseline ID and select one objective per baseline ID up to num_objectives.
            selected_by_id = {}
            for obj in filtered_objectives:
                obj_id = obj.get("id")
                if obj_id not in selected_by_id:
                    selected_by_id[obj_id] = []
                selected_by_id[obj_id].append(obj)

            # Select objectives to match num_objectives
            selected_cat_objectives = []
            baseline_ids = list(selected_by_id.keys())

            # If we have enough baseline IDs to cover num_objectives, select one per baseline ID
            if len(baseline_ids) >= num_objectives:
                # Select from the first num_objectives baseline IDs
                for i in range(num_objectives):
                    obj_id = baseline_ids[i]
                    # One random variation per baseline objective.
                    selected_cat_objectives.append(random.choice(selected_by_id[obj_id]))
            else:
                # If we have fewer baseline IDs than num_objectives, select all and cycle through
                for i in range(num_objectives):
                    obj_id = baseline_ids[i % len(baseline_ids)]
                    # For repeated IDs, try to select different variations if available
                    available_variations = selected_by_id[obj_id].copy()
                    # Remove already selected variations for this baseline ID
                    already_selected = [obj for obj in selected_cat_objectives if obj.get("id") == obj_id]
                    for selected_obj in already_selected:
                        if selected_obj in available_variations:
                            available_variations.remove(selected_obj)

                    if available_variations:
                        selected_cat_objectives.append(random.choice(available_variations))
                    else:
                        # If no more variations, reuse one (shouldn't happen with proper XPIA generation)
                        selected_cat_objectives.append(random.choice(selected_by_id[obj_id]))

            self.logger.debug(
                f"Selected {len(selected_cat_objectives)} objectives from {len(baseline_ids)} baseline IDs and {len(filtered_objectives)} total variations for {strategy} strategy"
            )
        else:
            # Baseline entries carried no IDs; fall back to an unkeyed random sample.
            self.logger.warning("No baseline objective IDs found, using random selection")
            selected_cat_objectives = random.sample(
                objectives_response, min(num_objectives, len(objectives_response))
            )
    else:
        # This is the baseline strategy or we don't have baseline objectives yet
        self.logger.debug(f"Using random selection for {strategy} strategy")
        selected_cat_objectives = random.sample(objectives_response, min(num_objectives, len(objectives_response)))
    selection_msg = (
        f"Selected {len(selected_cat_objectives)} objectives using num_objectives={num_objectives} "
        f"(available: {len(objectives_response)})"
    )
    self.logger.info(selection_msg)
    tqdm.write(f"[INFO] {selection_msg}")

    if len(selected_cat_objectives) < num_objectives:
        self.logger.warning(
            f"Only found {len(selected_cat_objectives)} objectives, fewer than requested {num_objectives}"
        )

    return selected_cat_objectives
|
|
950
|
+
|
|
951
|
+
def _extract_objective_content(self, selected_objectives: List) -> List[str]:
    """Extract content from selected objectives and build prompt-to-context mapping.

    For each objective's first message: normalizes its ``context`` into a list of
    dicts with a ``content`` key; when none of the contexts carry agent fields
    (``context_type`` plus a non-None ``tool_name``), the context text is embedded
    directly into the prompt content. Side effects: populates
    ``self.prompt_to_risk_subtype`` and ``self.prompt_to_context`` keyed by the
    final prompt string.

    :param selected_objectives: Objective dicts as returned by selection.
    :return: The list of prompt strings (context-embedded where applicable).
    """
    selected_prompts = []
    for obj in selected_objectives:
        risk_subtype = extract_risk_subtype(obj)
        if "messages" in obj and len(obj["messages"]) > 0:
            # Only the first message of each objective is consumed here.
            message = obj["messages"][0]
            if isinstance(message, dict) and "content" in message:
                content = message["content"]
                context_raw = message.get("context", "")
                # TODO is first if necessary?
                # Normalize context to always be a list of dicts with 'content' key
                if isinstance(context_raw, list):
                    # Already a list - ensure each item is a dict with 'content' key
                    contexts = []
                    for ctx in context_raw:
                        if isinstance(ctx, dict) and "content" in ctx:
                            # Preserve all keys including context_type, tool_name if present
                            contexts.append(ctx)
                        elif isinstance(ctx, str):
                            contexts.append({"content": ctx})
                elif context_raw:
                    # Single string value - wrap in dict
                    contexts = [{"content": context_raw}]
                    # Message-level agent fields are hoisted onto the single context entry.
                    if message.get("tool_name"):
                        contexts[0]["tool_name"] = message["tool_name"]
                    if message.get("context_type"):
                        contexts[0]["context_type"] = message["context_type"]
                else:
                    contexts = []

                # Check if any context has agent-specific fields
                has_agent_fields = any(
                    isinstance(ctx, dict)
                    and ("context_type" in ctx and "tool_name" in ctx and ctx["tool_name"] is not None)
                    for ctx in contexts
                )

                # For contexts without agent fields, append them to the content
                # This applies to baseline and any other attack objectives with plain context
                if contexts and not has_agent_fields:
                    # Extract all context content and append to the attack content
                    context_texts = []
                    for ctx in contexts:
                        if isinstance(ctx, dict):
                            ctx_content = ctx.get("content", "")
                            if ctx_content:
                                context_texts.append(ctx_content)

                    if context_texts:
                        # Append context to content
                        combined_context = "\n\n".join(context_texts)
                        content = f"{content}\n\nContext:\n{combined_context}"
                        self.logger.debug(
                            f"Appended {len(context_texts)} context source(s) to attack content (total context length={len(combined_context)})"
                        )

                selected_prompts.append(content)

                # Store risk_subtype mapping if it exists
                if risk_subtype:
                    self.prompt_to_risk_subtype[content] = risk_subtype

                # Always store contexts if they exist (whether or not they have agent fields)
                if contexts:
                    context_dict = {"contexts": contexts}
                    if has_agent_fields:
                        self.logger.debug(f"Stored context with agent fields: {len(contexts)} context source(s)")
                    else:
                        self.logger.debug(
                            f"Stored context without agent fields: {len(contexts)} context source(s) (also embedded in content)"
                        )
                    self.prompt_to_context[content] = context_dict
                else:
                    self.logger.debug(f"No context to store")
    return selected_prompts
|
|
1027
|
+
|
|
1028
|
+
def _cache_attack_objectives(
    self,
    current_key: tuple,
    risk_cat_value: str,
    strategy: str,
    selected_prompts: List[str],
    selected_objectives: List,
) -> None:
    """Record the selected prompts/objectives in the in-memory cache for reuse.

    Builds a per-risk-category view of the selected objectives (id, content,
    context, and optional risk subtype) and stores it, together with the raw
    selections, under ``current_key`` in ``self.attack_objectives``.

    :param current_key: Cache key (strategy/risk-category tuple) to store under.
    :param risk_cat_value: Risk category value the objectives belong to.
    :param strategy: Strategy name the objectives were fetched for.
    :param selected_prompts: Final prompt strings extracted from the objectives.
    :param selected_objectives: The raw selected objective dicts.
    """
    category_entries = []

    # Organize the list-format objectives by category for caching.
    for objective in selected_objectives:
        objective_id = objective.get("id", f"obj-{uuid.uuid4()}")
        subtype = extract_risk_subtype(objective)

        body = ""
        ctx = ""
        if "messages" in objective and len(objective["messages"]) > 0:
            first_message = objective["messages"][0]
            body = first_message.get("content", "")
            ctx = first_message.get("context", "")

        # Only objectives that actually carry prompt content are cached.
        if body:
            entry = {"id": objective_id, "content": body, "context": ctx}
            if subtype:
                entry["risk_subtype"] = subtype
            category_entries.append(entry)

    self.attack_objectives[current_key] = {
        "objectives_by_category": {risk_cat_value: category_entries},
        "strategy": strategy,
        "risk_category": risk_cat_value,
        "selected_prompts": selected_prompts,
        "selected_objectives": selected_objectives,
    }
    self.logger.info(f"Selected {len(selected_prompts)} objectives for {risk_cat_value}")
|
|
1066
|
+
|
|
1067
|
+
async def _process_attack(
    self,
    strategy: Union[AttackStrategy, List[AttackStrategy]],
    risk_category: RiskCategory,
    all_prompts: List[str],
    progress_bar: tqdm,
    progress_bar_lock: asyncio.Lock,
    scan_name: Optional[str] = None,
    skip_upload: bool = False,
    output_path: Optional[Union[str, os.PathLike]] = None,
    timeout: int = 120,
    _skip_evals: bool = False,
) -> Optional[EvaluationResult]:
    """Process a red team scan with the given orchestrator, converter, and prompts.

    Executes a red team attack process using the specified strategy and risk category against the
    target model or function. This includes creating an orchestrator, applying prompts through the
    appropriate converter, saving results to files, and optionally evaluating the results.
    The function handles progress tracking, logging, and error handling throughout the process.

    :param strategy: The attack strategy to use
    :type strategy: Union[AttackStrategy, List[AttackStrategy]]
    :param risk_category: The risk category to evaluate
    :type risk_category: RiskCategory
    :param all_prompts: List of prompts to use for the scan
    :type all_prompts: List[str]
    :param progress_bar: Progress bar to update
    :type progress_bar: tqdm
    :param progress_bar_lock: Lock for the progress bar
    :type progress_bar_lock: asyncio.Lock
    :param scan_name: Optional name for the evaluation
    :type scan_name: Optional[str]
    :param skip_upload: Whether to return only data without evaluation
    :type skip_upload: bool
    :param output_path: Optional path for output
    :type output_path: Optional[Union[str, os.PathLike]]
    :param timeout: The timeout in seconds for API calls
    :type timeout: int
    :param _skip_evals: Whether to skip the actual evaluation process
    :type _skip_evals: bool
    :return: Evaluation result if available
    :rtype: Optional[EvaluationResult]
    """
    strategy_name = get_strategy_name(strategy)
    task_key = f"{strategy_name}_{risk_category.value}_attack"
    self.task_statuses[task_key] = TASK_STATUS["RUNNING"]

    try:
        start_time = time.time()
        tqdm.write(f"▶️ Starting task: {strategy_name} strategy for {risk_category.value} risk category")

        # Get converter and orchestrator function
        converter = get_converter_for_strategy(
            strategy, self.generated_rai_client, self._one_dp_project, self.logger
        )
        call_orchestrator = self.orchestrator_manager.get_orchestrator_for_attack_strategy(strategy)

        try:
            self.logger.debug(f"Calling orchestrator for {strategy_name} strategy")
            orchestrator = await call_orchestrator(
                chat_target=self.chat_target,
                all_prompts=all_prompts,
                converter=converter,
                strategy_name=strategy_name,
                risk_category=risk_category,
                risk_category_name=risk_category.value,
                timeout=timeout,
                red_team_info=self.red_team_info,
                task_statuses=self.task_statuses,
                prompt_to_context=self.prompt_to_context,
            )
        except Exception as e:
            # Orchestrator failure ends this task but must not kill the scan.
            self.logger.error(f"Error calling orchestrator for {strategy_name} strategy: {str(e)}")
            self.task_statuses[task_key] = TASK_STATUS["FAILED"]
            self.failed_tasks += 1
            async with progress_bar_lock:
                progress_bar.update(1)
            return None

        # Write PyRIT outputs to file
        data_path = write_pyrit_outputs_to_file(
            output_path=self.red_team_info[strategy_name][risk_category.value]["data_file"],
            logger=self.logger,
            prompt_to_context=self.prompt_to_context,
        )
        orchestrator.dispose_db_engine()

        # Store data file in our tracking dictionary
        self.red_team_info[strategy_name][risk_category.value]["data_file"] = data_path
        self.logger.debug(
            f"Updated red_team_info with data file: {strategy_name} -> {risk_category.value} -> {data_path}"
        )

        # Perform evaluation
        try:
            await self.evaluation_processor.evaluate(
                scan_name=scan_name,
                risk_category=risk_category,
                strategy=strategy,
                _skip_evals=_skip_evals,
                data_path=data_path,
                output_path=None,
                red_team_info=self.red_team_info,
            )
        except Exception as e:
            # BUGFIX: previously `self.logger` was passed as the first argument
            # to logger.error (a leftover from a log_error(logger, msg, exc)
            # helper), making the logger object the format string and garbling
            # the record. Log the message directly instead.
            self.logger.error(f"Error during evaluation for {strategy_name}/{risk_category.value}: {str(e)}")
            tqdm.write(f"⚠️ Evaluation error for {strategy_name}/{risk_category.value}: {str(e)}")
            self.red_team_info[strategy_name][risk_category.value]["status"] = TASK_STATUS["FAILED"]

        # Update progress (bar and console output share one lock)
        async with progress_bar_lock:
            self.completed_tasks += 1
            progress_bar.update(1)
            completion_pct = (self.completed_tasks / self.total_tasks) * 100
            elapsed_time = time.time() - start_time

            tqdm.write(
                f"✅ Completed task {self.completed_tasks}/{self.total_tasks} ({completion_pct:.1f}%) - {strategy_name}/{risk_category.value} in {elapsed_time:.1f}s"
            )
            if self.start_time:
                # Estimate remaining time from the average pace so far.
                total_elapsed = time.time() - self.start_time
                avg_time_per_task = total_elapsed / self.completed_tasks if self.completed_tasks > 0 else 0
                remaining_tasks = self.total_tasks - self.completed_tasks
                est_remaining_time = avg_time_per_task * remaining_tasks if avg_time_per_task > 0 else 0
                tqdm.write(f"   Est. remaining: {est_remaining_time/60:.1f} minutes")

        self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]

    except Exception as e:
        self.logger.error(
            f"Unexpected error processing {strategy_name} strategy for {risk_category.value}: {str(e)}"
        )
        self.task_statuses[task_key] = TASK_STATUS["FAILED"]
        self.failed_tasks += 1
        async with progress_bar_lock:
            progress_bar.update(1)

    return None
|
|
1214
|
+
|
|
1215
|
+
async def scan(
    self,
    target: Union[
        Callable,
        AzureOpenAIModelConfiguration,
        OpenAIModelConfiguration,
        PromptChatTarget,
    ],
    *,
    scan_name: Optional[str] = None,
    attack_strategies: List[Union[AttackStrategy, List[AttackStrategy]]] = [],
    skip_upload: bool = False,
    output_path: Optional[Union[str, os.PathLike]] = None,
    application_scenario: Optional[str] = None,
    parallel_execution: bool = True,
    max_parallel_tasks: int = 5,
    timeout: int = 3600,
    skip_evals: bool = False,
    **kwargs: Any,
) -> RedTeamResult:
    """Run a red team scan against the target using the specified strategies.

    :param target: The target model or function to scan
    :type target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration, PromptChatTarget]
    :param scan_name: Optional name for the evaluation
    :type scan_name: Optional[str]
    :param attack_strategies: List of attack strategies to use
    :type attack_strategies: List[Union[AttackStrategy, List[AttackStrategy]]]
    :param skip_upload: Flag to determine if the scan results should be uploaded
    :type skip_upload: bool
    :param output_path: Optional path for output
    :type output_path: Optional[Union[str, os.PathLike]]
    :param application_scenario: Optional description of the application scenario
    :type application_scenario: Optional[str]
    :param parallel_execution: Whether to execute orchestrator tasks in parallel
    :type parallel_execution: bool
    :param max_parallel_tasks: Maximum number of parallel orchestrator tasks to run (default: 5)
    :type max_parallel_tasks: int
    :param timeout: The timeout in seconds for API calls (default: 3600)
    :type timeout: int
    :param skip_evals: Whether to skip the evaluation process
    :type skip_evals: bool
    :return: The output from the red team scan
    :rtype: RedTeamResult
    """
    user_agent: Optional[str] = kwargs.get("user_agent", "(type=redteam; subtype=RedTeam)")
    run_id_override = kwargs.get("run_id") or kwargs.get("runId")
    eval_id_override = kwargs.get("eval_id") or kwargs.get("evalId")
    created_at_override = kwargs.get("created_at") or kwargs.get("createdAt")
    taxonomy_risk_categories = kwargs.get("taxonomy_risk_categories")  # key is risk category value is taxonomy
    _app_insights_configuration = kwargs.get("_app_insights_configuration")
    self._app_insights_configuration = _app_insights_configuration
    self.taxonomy_risk_categories = taxonomy_risk_categories or {}
    is_agent_target: Optional[bool] = kwargs.get("is_agent_target", False)
    client_id: Optional[str] = kwargs.get("client_id")

    # BUGFIX: the parameter default is a shared mutable list and the original
    # code mutated it (and the caller's list) in place via insert(). Work on a
    # copy so neither the default nor the caller's argument is modified.
    attack_strategies = list(attack_strategies)

    with UserAgentSingleton().add_useragent_product(user_agent):
        # Initialize scan
        self._initialize_scan(scan_name, application_scenario)

        # Setup logging and directories FIRST
        self._setup_scan_environment()

        # Setup component managers AFTER scan environment is set up
        self._setup_component_managers()

        # Update result processor with AI studio URL
        self.result_processor.ai_studio_url = getattr(self.mlflow_integration, "ai_studio_url", None)

        # Update component managers with the new logger
        self.orchestrator_manager.logger = self.logger
        self.evaluation_processor.logger = self.logger
        self.mlflow_integration.logger = self.logger
        self.result_processor.logger = self.logger

        self.mlflow_integration.set_run_identity_overrides(
            run_id=run_id_override,
            eval_id=eval_id_override,
            created_at=created_at_override,
        )

        # Validate attack objective generator
        if not self.attack_objective_generator:
            raise EvaluationException(
                message="Attack objective generator is required for red team agent.",
                internal_message="Attack objective generator is not provided.",
                target=ErrorTarget.RED_TEAM,
                category=ErrorCategory.MISSING_FIELD,
                blame=ErrorBlame.USER_ERROR,
            )

        # Set default risk categories if not specified
        if not self.attack_objective_generator.risk_categories:
            self.logger.info("No risk categories specified, using all available categories")
            self.attack_objective_generator.risk_categories = [
                RiskCategory.HateUnfairness,
                RiskCategory.Sexual,
                RiskCategory.Violence,
                RiskCategory.SelfHarm,
            ]

        self.risk_categories = self.attack_objective_generator.risk_categories
        self.result_processor.risk_categories = self.risk_categories

        # Validate risk categories for target type
        if not is_agent_target:
            # Check if any agent-only risk categories are used with model targets
            for risk_cat in self.risk_categories:
                if risk_cat == RiskCategory.SensitiveDataLeakage:
                    raise EvaluationException(
                        message=f"Risk category '{risk_cat.value}' is only available for agent targets",
                        internal_message=f"Risk category {risk_cat.value} requires agent target",
                        target=ErrorTarget.RED_TEAM,
                        category=ErrorCategory.INVALID_VALUE,
                        blame=ErrorBlame.USER_ERROR,
                    )

        # Show risk categories to user
        tqdm.write(f"📊 Risk categories: {[rc.value for rc in self.risk_categories]}")
        self.logger.info(f"Risk categories to process: {[rc.value for rc in self.risk_categories]}")

        # Setup attack strategies: baseline always runs first
        if AttackStrategy.Baseline not in attack_strategies:
            attack_strategies.insert(0, AttackStrategy.Baseline)

        # Start MLFlow run if not skipping upload
        if skip_upload:
            eval_run = {}
        else:
            eval_run = self.mlflow_integration.start_redteam_mlflow_run(self.azure_ai_project, scan_name)
            tqdm.write(f"🔗 Track your red team scan in AI Foundry: {self.mlflow_integration.ai_studio_url}")

            # Update result processor with the AI studio URL now that it's available
            self.result_processor.ai_studio_url = self.mlflow_integration.ai_studio_url

        # Process strategies and execute scan
        flattened_attack_strategies = get_flattened_attack_strategies(attack_strategies)
        self._validate_strategies(flattened_attack_strategies)

        # Calculate total tasks and initialize tracking
        self.total_tasks = len(self.risk_categories) * len(flattened_attack_strategies)
        tqdm.write(f"📋 Planning {self.total_tasks} total tasks")
        self._initialize_tracking_dict(flattened_attack_strategies)

        # Fetch attack objectives
        all_objectives = await self._fetch_all_objectives(
            flattened_attack_strategies, application_scenario, is_agent_target, client_id
        )

        chat_target = get_chat_target(target)
        self.chat_target = chat_target

        # Execute attacks
        await self._execute_attacks(
            flattened_attack_strategies,
            all_objectives,
            scan_name,
            skip_upload,
            output_path,
            timeout,
            skip_evals,
            parallel_execution,
            max_parallel_tasks,
        )

        # Process and return results
        return await self._finalize_results(skip_upload, skip_evals, eval_run, output_path, scan_name)
|
|
1382
|
+
|
|
1383
|
+
def _initialize_scan(self, scan_name: Optional[str], application_scenario: Optional[str]):
|
|
1384
|
+
"""Initialize scan-specific variables."""
|
|
1385
|
+
self.start_time = time.time()
|
|
1386
|
+
self.task_statuses = {}
|
|
1387
|
+
self.completed_tasks = 0
|
|
1388
|
+
self.failed_tasks = 0
|
|
1389
|
+
|
|
1390
|
+
# Generate unique scan ID and session ID
|
|
1391
|
+
self.scan_id = (
|
|
1392
|
+
f"scan_{scan_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
|
1393
|
+
if scan_name
|
|
1394
|
+
else f"scan_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
|
1395
|
+
)
|
|
1396
|
+
self.scan_id = self.scan_id.replace(" ", "_")
|
|
1397
|
+
self.scan_session_id = str(uuid.uuid4())
|
|
1398
|
+
self.application_scenario = application_scenario or ""
|
|
1399
|
+
|
|
1400
|
+
def _setup_scan_environment(self):
    """Create the scan output directory and rebind logging to it."""
    # The file manager owns the on-disk layout of scan artifacts.
    self.scan_output_dir = self.file_manager.get_scan_output_path(self.scan_id)

    # Rebuild the logger so the scan log is written into the scan directory,
    # then install the noise-suppressing filters on the root handlers.
    self.logger = setup_logger(output_dir=self.scan_output_dir)
    self._setup_logging_filters()

    log_section_header(self.logger, "Starting red team scan")
    tqdm.write("🚀 STARTING RED TEAM SCAN")
    tqdm.write(f"📂 Output directory: {self.scan_output_dir}")
|
|
1414
|
+
|
|
1415
|
+
def _setup_logging_filters(self):
|
|
1416
|
+
"""Setup logging filters to suppress unwanted logs."""
|
|
1417
|
+
|
|
1418
|
+
class LogFilter(logging.Filter):
|
|
1419
|
+
def filter(self, record):
|
|
1420
|
+
# Filter out promptflow logs and evaluation warnings about artifacts
|
|
1421
|
+
if record.name.startswith("promptflow"):
|
|
1422
|
+
return False
|
|
1423
|
+
if "The path to the artifact is either not a directory or does not exist" in record.getMessage():
|
|
1424
|
+
return False
|
|
1425
|
+
if "RedTeamResult object at" in record.getMessage():
|
|
1426
|
+
return False
|
|
1427
|
+
if "timeout won't take effect" in record.getMessage():
|
|
1428
|
+
return False
|
|
1429
|
+
if "Submitting run" in record.getMessage():
|
|
1430
|
+
return False
|
|
1431
|
+
return True
|
|
1432
|
+
|
|
1433
|
+
# Apply filter to root logger
|
|
1434
|
+
root_logger = logging.getLogger()
|
|
1435
|
+
log_filter = LogFilter()
|
|
1436
|
+
|
|
1437
|
+
for handler in root_logger.handlers:
|
|
1438
|
+
for filter in handler.filters:
|
|
1439
|
+
handler.removeFilter(filter)
|
|
1440
|
+
handler.addFilter(log_filter)
|
|
1441
|
+
|
|
1442
|
+
def _validate_strategies(self, flattened_attack_strategies: List):
    """Reject strategy combinations that cannot run together.

    MultiTurn and Crescendo are conversation-level strategies and may not be
    combined with additional attack strategies; raises ``ValueError`` when the
    flattened list contains more than two strategies alongside either of them.
    """
    if len(flattened_attack_strategies) > 2 and any(
        incompatible in flattened_attack_strategies
        for incompatible in (AttackStrategy.MultiTurn, AttackStrategy.Crescendo)
    ):
        message = "MultiTurn and Crescendo strategies are not compatible with multiple attack strategies."
        self.logger.warning(message)
        raise ValueError(message)
|
|
1452
|
+
|
|
1453
|
+
def _initialize_tracking_dict(self, flattened_attack_strategies: List):
    """Build the ``red_team_info`` tracking structure for every strategy/category pair.

    Each entry starts with empty file paths, no evaluation result, and a
    PENDING status; later stages fill these in as tasks run.
    """
    self.red_team_info = {}
    for strategy in flattened_attack_strategies:
        strategy_name = get_strategy_name(strategy)
        # One fresh tracking record per risk category under this strategy.
        self.red_team_info[strategy_name] = {
            risk_category.value: {
                "data_file": "",
                "evaluation_result_file": "",
                "evaluation_result": None,
                "status": TASK_STATUS["PENDING"],
            }
            for risk_category in self.risk_categories
        }
|
|
1466
|
+
|
|
1467
|
+
async def _fetch_all_objectives(
    self,
    flattened_attack_strategies: List,
    application_scenario: str,
    is_agent_target: bool,
    client_id: Optional[str] = None,
) -> Dict:
    """Fetch all attack objectives for all strategies and risk categories.

    Baseline objectives are fetched first for every risk category (other
    strategies are later filtered against the baseline IDs), then objectives
    are fetched per non-baseline strategy. The per-category objective count is
    bumped to the largest subtype count so every subtype is covered.

    :param flattened_attack_strategies: Flattened strategy list to fetch for.
    :param application_scenario: Scenario description forwarded to the service.
    :param is_agent_target: Whether the target is an agent (affects objectives).
    :param client_id: Optional client identifier forwarded to the service.
    :return: Mapping of strategy name -> {risk category value -> objectives list}.
    """
    log_section_header(self.logger, "Fetching attack objectives")
    all_objectives = {}

    # Calculate and log num_objectives_with_subtypes once globally
    num_objectives = self.attack_objective_generator.num_objectives
    max_num_subtypes = max((RISK_TO_NUM_SUBTYPE_MAP.get(rc, 0) for rc in self.risk_categories), default=0)
    num_objectives_with_subtypes = max(num_objectives, max_num_subtypes)

    if num_objectives_with_subtypes != num_objectives:
        warning_msg = (
            f"Using {num_objectives_with_subtypes} objectives per risk category instead of requested {num_objectives} "
            f"to ensure adequate coverage of {max_num_subtypes} subtypes"
        )
        self.logger.warning(warning_msg)
        tqdm.write(f"[WARNING] {warning_msg}")

    # First fetch baseline objectives for all risk categories
    self.logger.info("Fetching baseline objectives for all risk categories")
    for risk_category in self.risk_categories:
        baseline_objectives = await self._get_attack_objectives(
            risk_category=risk_category,
            application_scenario=application_scenario,
            strategy="baseline",
            is_agent_target=is_agent_target,
            client_id=client_id,
        )
        if "baseline" not in all_objectives:
            all_objectives["baseline"] = {}
        all_objectives["baseline"][risk_category.value] = baseline_objectives
        status_msg = f"📝 Fetched baseline objectives for {risk_category.value}: {len(baseline_objectives)}/{num_objectives_with_subtypes} objectives"
        if len(baseline_objectives) < num_objectives_with_subtypes:
            status_msg += f" (⚠️ fewer than expected)"
        tqdm.write(status_msg)

    # Then fetch objectives for other strategies
    strategy_count = len(flattened_attack_strategies)
    for i, strategy in enumerate(flattened_attack_strategies):
        strategy_name = get_strategy_name(strategy)
        # Baseline was already fetched above; skip it here.
        if strategy_name == "baseline":
            continue

        tqdm.write(f"🔄 Fetching objectives for strategy {i+1}/{strategy_count}: {strategy_name}")
        all_objectives[strategy_name] = {}

        for risk_category in self.risk_categories:
            objectives = await self._get_attack_objectives(
                risk_category=risk_category,
                application_scenario=application_scenario,
                strategy=strategy_name,
                is_agent_target=is_agent_target,
                client_id=client_id,
            )
            all_objectives[strategy_name][risk_category.value] = objectives

    return all_objectives
|
|
1530
|
+
|
|
1531
|
+
async def _execute_attacks(
    self,
    flattened_attack_strategies: List,
    all_objectives: Dict,
    scan_name: str,
    skip_upload: bool,
    output_path: str,
    timeout: int,
    skip_evals: bool,
    parallel_execution: bool,
    max_parallel_tasks: int,
):
    """Execute all attack combinations.

    Builds one orchestrator task per (strategy, risk category) pair that has
    objectives, then hands the whole set to ``_process_orchestrator_tasks``.
    Pairs without objectives are marked completed and counted on the shared
    progress bar immediately.
    """
    log_section_header(self.logger, "Starting orchestrator processing")

    # Single progress bar shared by every attack task; guarded by a lock
    # because tasks may update it concurrently.
    progress_bar = tqdm(
        total=self.total_tasks,
        desc="Scanning: ",
        ncols=100,
        unit="scan",
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
    )
    progress_bar.set_postfix({"current": "initializing"})
    progress_bar_lock = asyncio.Lock()

    orchestrator_tasks = []
    for strategy, risk_category in itertools.product(flattened_attack_strategies, self.risk_categories):
        strategy_name = get_strategy_name(strategy)
        objectives = all_objectives[strategy_name][risk_category.value]

        if not objectives:
            # Nothing to attack with for this pair: record it as done and advance.
            self.logger.warning(f"No objectives found for {strategy_name}+{risk_category.value}, skipping")
            tqdm.write(f"⚠️ No objectives found for {strategy_name}/{risk_category.value}, skipping")
            self.red_team_info[strategy_name][risk_category.value]["status"] = TASK_STATUS["COMPLETED"]
            async with progress_bar_lock:
                progress_bar.update(1)
            continue

        orchestrator_tasks.append(
            self._process_attack(
                all_prompts=objectives,
                strategy=strategy,
                progress_bar=progress_bar,
                progress_bar_lock=progress_bar_lock,
                scan_name=scan_name,
                skip_upload=skip_upload,
                output_path=output_path,
                risk_category=risk_category,
                timeout=timeout,
                _skip_evals=skip_evals,
            )
        )

    # Run everything (parallel or sequential, per caller's choice), then tidy up.
    await self._process_orchestrator_tasks(orchestrator_tasks, parallel_execution, max_parallel_tasks, timeout)
    progress_bar.close()
async def _process_orchestrator_tasks(
    self, orchestrator_tasks: List, parallel_execution: bool, max_parallel_tasks: int, timeout: int
):
    """Process orchestrator tasks either in parallel or sequentially.

    In parallel mode the tasks are awaited in fixed-size batches of at most
    ``max_parallel_tasks``; a batch gets twice the single-task ``timeout``.
    Timeouts and errors are logged and the remaining work continues.
    """
    if parallel_execution and orchestrator_tasks:
        tqdm.write(f"⚙️ Processing {len(orchestrator_tasks)} tasks in parallel (max {max_parallel_tasks} at a time)")

        # Await the tasks batch by batch; slicing clamps past the end for us.
        for start in range(0, len(orchestrator_tasks), max_parallel_tasks):
            batch = orchestrator_tasks[start : start + max_parallel_tasks]
            batch_no = start // max_parallel_tasks + 1
            try:
                # A whole batch gets double the per-task timeout.
                await asyncio.wait_for(asyncio.gather(*batch), timeout=timeout * 2)
            except asyncio.TimeoutError:
                self.logger.warning(f"Batch {batch_no} timed out")
                tqdm.write(f"⚠️ Batch {batch_no} timed out, continuing with next batch")
            except Exception as e:
                self.logger.error(f"Error processing batch {batch_no}: {str(e)}")
    else:
        # Sequential fallback: one task at a time, each with its own timeout.
        tqdm.write("⚙️ Processing tasks sequentially")
        for task_no, task in enumerate(orchestrator_tasks, start=1):
            try:
                await asyncio.wait_for(task, timeout=timeout)
            except asyncio.TimeoutError:
                self.logger.warning(f"Task {task_no} timed out")
                tqdm.write(f"⚠️ Task {task_no} timed out, continuing with next task")
            except Exception as e:
                self.logger.error(f"Error processing task {task_no}: {str(e)}")
async def _finalize_results(
    self, skip_upload: bool, skip_evals: bool, eval_run, output_path: str, scan_name: str
) -> RedTeamResult:
    """Process and finalize scan results.

    Builds the :class:`RedTeamResult` from the collected red-team info,
    optionally emits App Insights events, logs to AI Foundry/MLflow unless
    uploads are skipped, writes result files to ``output_path`` and/or the
    scan output directory, prints the scorecard, and closes this scan's
    file log handlers.

    :param skip_upload: When True, do not log results to AI Foundry/MLflow.
    :param skip_evals: Forwarded to MLflow logging to mark evals as skipped.
    :param eval_run: Evaluation run object (project type), passed through to
        the result processor and MLflow integration.
    :param output_path: Optional path written as a results directory.
    :param scan_name: Human-readable scan name.
    :return: The finalized RedTeamResult.
    """
    log_section_header(self.logger, "Processing results")

    # Convert results to RedTeamResult (now builds AOAI summary internally)
    red_team_result = self.result_processor.to_red_team_result(
        red_team_info=self.red_team_info,
        eval_run=eval_run,
        scan_name=scan_name,
    )

    # Extract AOAI summary for App Insights / MLflow / file output.
    # scan_result can be falsy (later branches guard on it), so don't call
    # .get on None.
    scan_result = red_team_result.scan_result
    aoai_summary = scan_result.get("AOAI_Compatible_Summary") if scan_result else None

    if self._app_insights_configuration:
        if aoai_summary:
            # Get redacted results from the result processor for App Insights logging
            redacted_results = self.result_processor.get_app_insights_redacted_results(
                aoai_summary["output_items"]["data"]
            )
            emit_eval_result_events_to_app_insights(self._app_insights_configuration, redacted_results)
        else:
            # Previously this path raised a TypeError when the summary was
            # missing; skip emission instead of crashing finalization.
            self.logger.warning("AOAI summary not available; skipping App Insights emission")

    # Log results to MLFlow if not skipping upload
    if not skip_upload:
        self.logger.info("Logging results to AI Foundry")
        await self.mlflow_integration.log_redteam_results_to_mlflow(
            redteam_result=red_team_result,
            eval_run=eval_run,
            red_team_info=self.red_team_info,
            _skip_evals=skip_evals,
            aoai_summary=aoai_summary,
        )

    # Write output to specified path
    if output_path and red_team_result.scan_result:
        abs_output_path = output_path if os.path.isabs(output_path) else os.path.abspath(output_path)
        self.logger.info(f"Writing output to {abs_output_path}")

        # Ensure output_path is treated as a directory.
        # If it exists as a file, remove it first.
        if os.path.exists(abs_output_path) and not os.path.isdir(abs_output_path):
            os.remove(abs_output_path)
        os.makedirs(abs_output_path, exist_ok=True)

        # Create a copy of scan_result without AOAI properties for eval_result.json
        scan_result_without_aoai = {
            key: value
            for key, value in red_team_result.scan_result.items()
            if not key.startswith("AOAI_Compatible")
        }

        # Write scan result without AOAI properties to eval_result.json
        _write_output(abs_output_path, scan_result_without_aoai)

        # Write the AOAI summary to results.json
        if aoai_summary:
            _write_output(os.path.join(abs_output_path, "results.json"), aoai_summary)
        else:
            self.logger.warning("AOAI summary not available for output_path write")

        # Also save a copy to the scan output directory if available
        if self.scan_output_dir:
            final_output = os.path.join(self.scan_output_dir, "final_results.json")
            _write_output(final_output, red_team_result.scan_result)
    elif red_team_result.scan_result and self.scan_output_dir:
        # If no output_path was specified but we have scan_output_dir, save there
        final_output = os.path.join(self.scan_output_dir, "final_results.json")
        _write_output(final_output, red_team_result.scan_result)

    # Display final scorecard and results
    if red_team_result.scan_result:
        scorecard = format_scorecard(red_team_result.scan_result)
        tqdm.write(scorecard)

        # Print URL for detailed results
        studio_url = red_team_result.scan_result.get("studio_url", "")
        if studio_url:
            tqdm.write(f"\nDetailed results available at:\n{studio_url}")

    # Print the output directory path
    if self.scan_output_dir:
        tqdm.write(f"\n📂 All scan files saved to: {self.scan_output_dir}")

    tqdm.write(f"✅ Scan completed successfully!")
    self.logger.info("Scan completed successfully")

    # Close file handlers so the per-scan log file is flushed and released.
    # Iterate over a copy: removeHandler mutates self.logger.handlers, and
    # mutating the list being iterated can skip handlers.
    for handler in list(self.logger.handlers):
        if isinstance(handler, logging.FileHandler):
            handler.close()
            self.logger.removeHandler(handler)

    return red_team_result