azure-ai-evaluation 1.0.0b2__py3-none-any.whl → 1.13.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of azure-ai-evaluation might be problematic. Click here for more details.

Files changed (299) hide show
  1. azure/ai/evaluation/__init__.py +100 -5
  2. azure/ai/evaluation/{_evaluators/_chat → _aoai}/__init__.py +3 -2
  3. azure/ai/evaluation/_aoai/aoai_grader.py +140 -0
  4. azure/ai/evaluation/_aoai/label_grader.py +68 -0
  5. azure/ai/evaluation/_aoai/python_grader.py +86 -0
  6. azure/ai/evaluation/_aoai/score_model_grader.py +94 -0
  7. azure/ai/evaluation/_aoai/string_check_grader.py +66 -0
  8. azure/ai/evaluation/_aoai/text_similarity_grader.py +80 -0
  9. azure/ai/evaluation/_azure/__init__.py +3 -0
  10. azure/ai/evaluation/_azure/_clients.py +204 -0
  11. azure/ai/evaluation/_azure/_envs.py +207 -0
  12. azure/ai/evaluation/_azure/_models.py +227 -0
  13. azure/ai/evaluation/_azure/_token_manager.py +129 -0
  14. azure/ai/evaluation/_common/__init__.py +9 -1
  15. azure/ai/evaluation/{simulator/_helpers → _common}/_experimental.py +24 -9
  16. azure/ai/evaluation/_common/constants.py +131 -2
  17. azure/ai/evaluation/_common/evaluation_onedp_client.py +169 -0
  18. azure/ai/evaluation/_common/math.py +89 -0
  19. azure/ai/evaluation/_common/onedp/__init__.py +32 -0
  20. azure/ai/evaluation/_common/onedp/_client.py +166 -0
  21. azure/ai/evaluation/_common/onedp/_configuration.py +72 -0
  22. azure/ai/evaluation/_common/onedp/_model_base.py +1232 -0
  23. azure/ai/evaluation/_common/onedp/_patch.py +21 -0
  24. azure/ai/evaluation/_common/onedp/_serialization.py +2032 -0
  25. azure/ai/evaluation/_common/onedp/_types.py +21 -0
  26. azure/ai/evaluation/_common/onedp/_utils/__init__.py +6 -0
  27. azure/ai/evaluation/_common/onedp/_utils/model_base.py +1232 -0
  28. azure/ai/evaluation/_common/onedp/_utils/serialization.py +2032 -0
  29. azure/ai/evaluation/_common/onedp/_validation.py +66 -0
  30. azure/ai/evaluation/_common/onedp/_vendor.py +50 -0
  31. azure/ai/evaluation/_common/onedp/_version.py +9 -0
  32. azure/ai/evaluation/_common/onedp/aio/__init__.py +29 -0
  33. azure/ai/evaluation/_common/onedp/aio/_client.py +168 -0
  34. azure/ai/evaluation/_common/onedp/aio/_configuration.py +72 -0
  35. azure/ai/evaluation/_common/onedp/aio/_patch.py +21 -0
  36. azure/ai/evaluation/_common/onedp/aio/operations/__init__.py +49 -0
  37. azure/ai/evaluation/_common/onedp/aio/operations/_operations.py +7143 -0
  38. azure/ai/evaluation/_common/onedp/aio/operations/_patch.py +21 -0
  39. azure/ai/evaluation/_common/onedp/models/__init__.py +358 -0
  40. azure/ai/evaluation/_common/onedp/models/_enums.py +447 -0
  41. azure/ai/evaluation/_common/onedp/models/_models.py +5963 -0
  42. azure/ai/evaluation/_common/onedp/models/_patch.py +21 -0
  43. azure/ai/evaluation/_common/onedp/operations/__init__.py +49 -0
  44. azure/ai/evaluation/_common/onedp/operations/_operations.py +8951 -0
  45. azure/ai/evaluation/_common/onedp/operations/_patch.py +21 -0
  46. azure/ai/evaluation/_common/onedp/py.typed +1 -0
  47. azure/ai/evaluation/_common/onedp/servicepatterns/__init__.py +1 -0
  48. azure/ai/evaluation/_common/onedp/servicepatterns/aio/__init__.py +1 -0
  49. azure/ai/evaluation/_common/onedp/servicepatterns/aio/operations/__init__.py +25 -0
  50. azure/ai/evaluation/_common/onedp/servicepatterns/aio/operations/_operations.py +34 -0
  51. azure/ai/evaluation/_common/onedp/servicepatterns/aio/operations/_patch.py +20 -0
  52. azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/__init__.py +1 -0
  53. azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/__init__.py +1 -0
  54. azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/operations/__init__.py +22 -0
  55. azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/operations/_operations.py +29 -0
  56. azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/operations/_patch.py +20 -0
  57. azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/operations/__init__.py +22 -0
  58. azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/operations/_operations.py +29 -0
  59. azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/operations/_patch.py +20 -0
  60. azure/ai/evaluation/_common/onedp/servicepatterns/operations/__init__.py +25 -0
  61. azure/ai/evaluation/_common/onedp/servicepatterns/operations/_operations.py +34 -0
  62. azure/ai/evaluation/_common/onedp/servicepatterns/operations/_patch.py +20 -0
  63. azure/ai/evaluation/_common/rai_service.py +831 -142
  64. azure/ai/evaluation/_common/raiclient/__init__.py +34 -0
  65. azure/ai/evaluation/_common/raiclient/_client.py +128 -0
  66. azure/ai/evaluation/_common/raiclient/_configuration.py +87 -0
  67. azure/ai/evaluation/_common/raiclient/_model_base.py +1235 -0
  68. azure/ai/evaluation/_common/raiclient/_patch.py +20 -0
  69. azure/ai/evaluation/_common/raiclient/_serialization.py +2050 -0
  70. azure/ai/evaluation/_common/raiclient/_version.py +9 -0
  71. azure/ai/evaluation/_common/raiclient/aio/__init__.py +29 -0
  72. azure/ai/evaluation/_common/raiclient/aio/_client.py +130 -0
  73. azure/ai/evaluation/_common/raiclient/aio/_configuration.py +87 -0
  74. azure/ai/evaluation/_common/raiclient/aio/_patch.py +20 -0
  75. azure/ai/evaluation/_common/raiclient/aio/operations/__init__.py +25 -0
  76. azure/ai/evaluation/_common/raiclient/aio/operations/_operations.py +981 -0
  77. azure/ai/evaluation/_common/raiclient/aio/operations/_patch.py +20 -0
  78. azure/ai/evaluation/_common/raiclient/models/__init__.py +60 -0
  79. azure/ai/evaluation/_common/raiclient/models/_enums.py +18 -0
  80. azure/ai/evaluation/_common/raiclient/models/_models.py +651 -0
  81. azure/ai/evaluation/_common/raiclient/models/_patch.py +20 -0
  82. azure/ai/evaluation/_common/raiclient/operations/__init__.py +25 -0
  83. azure/ai/evaluation/_common/raiclient/operations/_operations.py +1238 -0
  84. azure/ai/evaluation/_common/raiclient/operations/_patch.py +20 -0
  85. azure/ai/evaluation/_common/raiclient/py.typed +1 -0
  86. azure/ai/evaluation/_common/utils.py +870 -34
  87. azure/ai/evaluation/_constants.py +167 -6
  88. azure/ai/evaluation/_converters/__init__.py +3 -0
  89. azure/ai/evaluation/_converters/_ai_services.py +899 -0
  90. azure/ai/evaluation/_converters/_models.py +467 -0
  91. azure/ai/evaluation/_converters/_sk_services.py +495 -0
  92. azure/ai/evaluation/_eval_mapping.py +83 -0
  93. azure/ai/evaluation/_evaluate/_batch_run/__init__.py +17 -0
  94. azure/ai/evaluation/_evaluate/_batch_run/_run_submitter_client.py +176 -0
  95. azure/ai/evaluation/_evaluate/_batch_run/batch_clients.py +82 -0
  96. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/code_client.py +47 -25
  97. azure/ai/evaluation/_evaluate/{_batch_run_client/batch_run_context.py → _batch_run/eval_run_context.py} +42 -13
  98. azure/ai/evaluation/_evaluate/_batch_run/proxy_client.py +124 -0
  99. azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +62 -0
  100. azure/ai/evaluation/_evaluate/_eval_run.py +102 -59
  101. azure/ai/evaluation/_evaluate/_evaluate.py +2134 -311
  102. azure/ai/evaluation/_evaluate/_evaluate_aoai.py +992 -0
  103. azure/ai/evaluation/_evaluate/_telemetry/__init__.py +14 -99
  104. azure/ai/evaluation/_evaluate/_utils.py +289 -40
  105. azure/ai/evaluation/_evaluator_definition.py +76 -0
  106. azure/ai/evaluation/_evaluators/_bleu/_bleu.py +93 -42
  107. azure/ai/evaluation/_evaluators/_code_vulnerability/__init__.py +5 -0
  108. azure/ai/evaluation/_evaluators/_code_vulnerability/_code_vulnerability.py +119 -0
  109. azure/ai/evaluation/_evaluators/_coherence/_coherence.py +117 -91
  110. azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +76 -39
  111. azure/ai/evaluation/_evaluators/_common/__init__.py +15 -0
  112. azure/ai/evaluation/_evaluators/_common/_base_eval.py +742 -0
  113. azure/ai/evaluation/_evaluators/_common/_base_multi_eval.py +63 -0
  114. azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +345 -0
  115. azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +198 -0
  116. azure/ai/evaluation/_evaluators/_common/_conversation_aggregators.py +49 -0
  117. azure/ai/evaluation/_evaluators/_content_safety/__init__.py +0 -4
  118. azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +144 -86
  119. azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +138 -57
  120. azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +123 -55
  121. azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +133 -54
  122. azure/ai/evaluation/_evaluators/_content_safety/_violence.py +134 -54
  123. azure/ai/evaluation/_evaluators/_document_retrieval/__init__.py +7 -0
  124. azure/ai/evaluation/_evaluators/_document_retrieval/_document_retrieval.py +442 -0
  125. azure/ai/evaluation/_evaluators/_eci/_eci.py +49 -56
  126. azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +102 -60
  127. azure/ai/evaluation/_evaluators/_fluency/_fluency.py +115 -92
  128. azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +66 -41
  129. azure/ai/evaluation/_evaluators/_gleu/_gleu.py +90 -37
  130. azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +318 -82
  131. azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +114 -0
  132. azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +104 -0
  133. azure/ai/evaluation/{_evaluate/_batch_run_client → _evaluators/_intent_resolution}/__init__.py +3 -4
  134. azure/ai/evaluation/_evaluators/_intent_resolution/_intent_resolution.py +196 -0
  135. azure/ai/evaluation/_evaluators/_intent_resolution/intent_resolution.prompty +275 -0
  136. azure/ai/evaluation/_evaluators/_meteor/_meteor.py +107 -61
  137. azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +104 -77
  138. azure/ai/evaluation/_evaluators/_qa/_qa.py +115 -63
  139. azure/ai/evaluation/_evaluators/_relevance/_relevance.py +182 -98
  140. azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +178 -49
  141. azure/ai/evaluation/_evaluators/_response_completeness/__init__.py +7 -0
  142. azure/ai/evaluation/_evaluators/_response_completeness/_response_completeness.py +202 -0
  143. azure/ai/evaluation/_evaluators/_response_completeness/response_completeness.prompty +84 -0
  144. azure/ai/evaluation/_evaluators/{_chat/retrieval → _retrieval}/__init__.py +2 -2
  145. azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +148 -0
  146. azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +93 -0
  147. azure/ai/evaluation/_evaluators/_rouge/_rouge.py +189 -50
  148. azure/ai/evaluation/_evaluators/_service_groundedness/__init__.py +9 -0
  149. azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +179 -0
  150. azure/ai/evaluation/_evaluators/_similarity/_similarity.py +102 -91
  151. azure/ai/evaluation/_evaluators/_similarity/similarity.prompty +0 -5
  152. azure/ai/evaluation/_evaluators/_task_adherence/__init__.py +7 -0
  153. azure/ai/evaluation/_evaluators/_task_adherence/_task_adherence.py +226 -0
  154. azure/ai/evaluation/_evaluators/_task_adherence/task_adherence.prompty +101 -0
  155. azure/ai/evaluation/_evaluators/_task_completion/__init__.py +7 -0
  156. azure/ai/evaluation/_evaluators/_task_completion/_task_completion.py +177 -0
  157. azure/ai/evaluation/_evaluators/_task_completion/task_completion.prompty +220 -0
  158. azure/ai/evaluation/_evaluators/_task_navigation_efficiency/__init__.py +7 -0
  159. azure/ai/evaluation/_evaluators/_task_navigation_efficiency/_task_navigation_efficiency.py +384 -0
  160. azure/ai/evaluation/_evaluators/_tool_call_accuracy/__init__.py +9 -0
  161. azure/ai/evaluation/_evaluators/_tool_call_accuracy/_tool_call_accuracy.py +298 -0
  162. azure/ai/evaluation/_evaluators/_tool_call_accuracy/tool_call_accuracy.prompty +166 -0
  163. azure/ai/evaluation/_evaluators/_tool_input_accuracy/__init__.py +9 -0
  164. azure/ai/evaluation/_evaluators/_tool_input_accuracy/_tool_input_accuracy.py +263 -0
  165. azure/ai/evaluation/_evaluators/_tool_input_accuracy/tool_input_accuracy.prompty +76 -0
  166. azure/ai/evaluation/_evaluators/_tool_output_utilization/__init__.py +7 -0
  167. azure/ai/evaluation/_evaluators/_tool_output_utilization/_tool_output_utilization.py +225 -0
  168. azure/ai/evaluation/_evaluators/_tool_output_utilization/tool_output_utilization.prompty +221 -0
  169. azure/ai/evaluation/_evaluators/_tool_selection/__init__.py +9 -0
  170. azure/ai/evaluation/_evaluators/_tool_selection/_tool_selection.py +266 -0
  171. azure/ai/evaluation/_evaluators/_tool_selection/tool_selection.prompty +104 -0
  172. azure/ai/evaluation/_evaluators/_tool_success/__init__.py +7 -0
  173. azure/ai/evaluation/_evaluators/_tool_success/_tool_success.py +301 -0
  174. azure/ai/evaluation/_evaluators/_tool_success/tool_success.prompty +321 -0
  175. azure/ai/evaluation/_evaluators/_ungrounded_attributes/__init__.py +5 -0
  176. azure/ai/evaluation/_evaluators/_ungrounded_attributes/_ungrounded_attributes.py +102 -0
  177. azure/ai/evaluation/_evaluators/_xpia/xpia.py +109 -107
  178. azure/ai/evaluation/_exceptions.py +51 -7
  179. azure/ai/evaluation/_http_utils.py +210 -137
  180. azure/ai/evaluation/_legacy/__init__.py +3 -0
  181. azure/ai/evaluation/_legacy/_adapters/__init__.py +7 -0
  182. azure/ai/evaluation/_legacy/_adapters/_check.py +17 -0
  183. azure/ai/evaluation/_legacy/_adapters/_configuration.py +45 -0
  184. azure/ai/evaluation/_legacy/_adapters/_constants.py +10 -0
  185. azure/ai/evaluation/_legacy/_adapters/_errors.py +29 -0
  186. azure/ai/evaluation/_legacy/_adapters/_flows.py +28 -0
  187. azure/ai/evaluation/_legacy/_adapters/_service.py +16 -0
  188. azure/ai/evaluation/_legacy/_adapters/client.py +51 -0
  189. azure/ai/evaluation/_legacy/_adapters/entities.py +26 -0
  190. azure/ai/evaluation/_legacy/_adapters/tracing.py +28 -0
  191. azure/ai/evaluation/_legacy/_adapters/types.py +15 -0
  192. azure/ai/evaluation/_legacy/_adapters/utils.py +31 -0
  193. azure/ai/evaluation/_legacy/_batch_engine/__init__.py +9 -0
  194. azure/ai/evaluation/_legacy/_batch_engine/_config.py +48 -0
  195. azure/ai/evaluation/_legacy/_batch_engine/_engine.py +477 -0
  196. azure/ai/evaluation/_legacy/_batch_engine/_exceptions.py +88 -0
  197. azure/ai/evaluation/_legacy/_batch_engine/_openai_injector.py +132 -0
  198. azure/ai/evaluation/_legacy/_batch_engine/_result.py +107 -0
  199. azure/ai/evaluation/_legacy/_batch_engine/_run.py +127 -0
  200. azure/ai/evaluation/_legacy/_batch_engine/_run_storage.py +128 -0
  201. azure/ai/evaluation/_legacy/_batch_engine/_run_submitter.py +262 -0
  202. azure/ai/evaluation/_legacy/_batch_engine/_status.py +25 -0
  203. azure/ai/evaluation/_legacy/_batch_engine/_trace.py +97 -0
  204. azure/ai/evaluation/_legacy/_batch_engine/_utils.py +97 -0
  205. azure/ai/evaluation/_legacy/_batch_engine/_utils_deprecated.py +131 -0
  206. azure/ai/evaluation/_legacy/_common/__init__.py +3 -0
  207. azure/ai/evaluation/_legacy/_common/_async_token_provider.py +117 -0
  208. azure/ai/evaluation/_legacy/_common/_logging.py +292 -0
  209. azure/ai/evaluation/_legacy/_common/_thread_pool_executor_with_context.py +17 -0
  210. azure/ai/evaluation/_legacy/prompty/__init__.py +36 -0
  211. azure/ai/evaluation/_legacy/prompty/_connection.py +119 -0
  212. azure/ai/evaluation/_legacy/prompty/_exceptions.py +139 -0
  213. azure/ai/evaluation/_legacy/prompty/_prompty.py +430 -0
  214. azure/ai/evaluation/_legacy/prompty/_utils.py +663 -0
  215. azure/ai/evaluation/_legacy/prompty/_yaml_utils.py +99 -0
  216. azure/ai/evaluation/_model_configurations.py +130 -8
  217. azure/ai/evaluation/_safety_evaluation/__init__.py +3 -0
  218. azure/ai/evaluation/_safety_evaluation/_generated_rai_client.py +0 -0
  219. azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +917 -0
  220. azure/ai/evaluation/_user_agent.py +32 -1
  221. azure/ai/evaluation/_vendor/__init__.py +3 -0
  222. azure/ai/evaluation/_vendor/rouge_score/__init__.py +14 -0
  223. azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +324 -0
  224. azure/ai/evaluation/_vendor/rouge_score/scoring.py +59 -0
  225. azure/ai/evaluation/_vendor/rouge_score/tokenize.py +59 -0
  226. azure/ai/evaluation/_vendor/rouge_score/tokenizers.py +53 -0
  227. azure/ai/evaluation/_version.py +2 -1
  228. azure/ai/evaluation/red_team/__init__.py +22 -0
  229. azure/ai/evaluation/red_team/_agent/__init__.py +3 -0
  230. azure/ai/evaluation/red_team/_agent/_agent_functions.py +261 -0
  231. azure/ai/evaluation/red_team/_agent/_agent_tools.py +461 -0
  232. azure/ai/evaluation/red_team/_agent/_agent_utils.py +89 -0
  233. azure/ai/evaluation/red_team/_agent/_semantic_kernel_plugin.py +228 -0
  234. azure/ai/evaluation/red_team/_attack_objective_generator.py +268 -0
  235. azure/ai/evaluation/red_team/_attack_strategy.py +49 -0
  236. azure/ai/evaluation/red_team/_callback_chat_target.py +115 -0
  237. azure/ai/evaluation/red_team/_default_converter.py +21 -0
  238. azure/ai/evaluation/red_team/_evaluation_processor.py +505 -0
  239. azure/ai/evaluation/red_team/_mlflow_integration.py +430 -0
  240. azure/ai/evaluation/red_team/_orchestrator_manager.py +803 -0
  241. azure/ai/evaluation/red_team/_red_team.py +1717 -0
  242. azure/ai/evaluation/red_team/_red_team_result.py +661 -0
  243. azure/ai/evaluation/red_team/_result_processor.py +1708 -0
  244. azure/ai/evaluation/red_team/_utils/__init__.py +37 -0
  245. azure/ai/evaluation/red_team/_utils/_rai_service_eval_chat_target.py +128 -0
  246. azure/ai/evaluation/red_team/_utils/_rai_service_target.py +601 -0
  247. azure/ai/evaluation/red_team/_utils/_rai_service_true_false_scorer.py +114 -0
  248. azure/ai/evaluation/red_team/_utils/constants.py +72 -0
  249. azure/ai/evaluation/red_team/_utils/exception_utils.py +345 -0
  250. azure/ai/evaluation/red_team/_utils/file_utils.py +266 -0
  251. azure/ai/evaluation/red_team/_utils/formatting_utils.py +365 -0
  252. azure/ai/evaluation/red_team/_utils/logging_utils.py +139 -0
  253. azure/ai/evaluation/red_team/_utils/metric_mapping.py +73 -0
  254. azure/ai/evaluation/red_team/_utils/objective_utils.py +46 -0
  255. azure/ai/evaluation/red_team/_utils/progress_utils.py +252 -0
  256. azure/ai/evaluation/red_team/_utils/retry_utils.py +218 -0
  257. azure/ai/evaluation/red_team/_utils/strategy_utils.py +218 -0
  258. azure/ai/evaluation/simulator/__init__.py +2 -1
  259. azure/ai/evaluation/simulator/_adversarial_scenario.py +26 -1
  260. azure/ai/evaluation/simulator/_adversarial_simulator.py +270 -144
  261. azure/ai/evaluation/simulator/_constants.py +12 -1
  262. azure/ai/evaluation/simulator/_conversation/__init__.py +151 -23
  263. azure/ai/evaluation/simulator/_conversation/_conversation.py +10 -6
  264. azure/ai/evaluation/simulator/_conversation/constants.py +1 -1
  265. azure/ai/evaluation/simulator/_data_sources/__init__.py +3 -0
  266. azure/ai/evaluation/simulator/_data_sources/grounding.json +1150 -0
  267. azure/ai/evaluation/simulator/_direct_attack_simulator.py +54 -75
  268. azure/ai/evaluation/simulator/_helpers/__init__.py +1 -2
  269. azure/ai/evaluation/simulator/_helpers/_language_suffix_mapping.py +1 -0
  270. azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +26 -5
  271. azure/ai/evaluation/simulator/_indirect_attack_simulator.py +145 -104
  272. azure/ai/evaluation/simulator/_model_tools/__init__.py +2 -1
  273. azure/ai/evaluation/simulator/_model_tools/_generated_rai_client.py +225 -0
  274. azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +80 -30
  275. azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +117 -45
  276. azure/ai/evaluation/simulator/_model_tools/_rai_client.py +109 -7
  277. azure/ai/evaluation/simulator/_model_tools/_template_handler.py +97 -33
  278. azure/ai/evaluation/simulator/_model_tools/models.py +30 -27
  279. azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +6 -10
  280. azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +6 -5
  281. azure/ai/evaluation/simulator/_simulator.py +302 -208
  282. azure/ai/evaluation/simulator/_utils.py +31 -13
  283. azure_ai_evaluation-1.13.3.dist-info/METADATA +939 -0
  284. azure_ai_evaluation-1.13.3.dist-info/RECORD +305 -0
  285. {azure_ai_evaluation-1.0.0b2.dist-info → azure_ai_evaluation-1.13.3.dist-info}/WHEEL +1 -1
  286. azure_ai_evaluation-1.13.3.dist-info/licenses/NOTICE.txt +70 -0
  287. azure/ai/evaluation/_evaluate/_batch_run_client/proxy_client.py +0 -71
  288. azure/ai/evaluation/_evaluators/_chat/_chat.py +0 -357
  289. azure/ai/evaluation/_evaluators/_chat/retrieval/_retrieval.py +0 -157
  290. azure/ai/evaluation/_evaluators/_chat/retrieval/retrieval.prompty +0 -48
  291. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_base.py +0 -65
  292. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +0 -301
  293. azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +0 -54
  294. azure/ai/evaluation/_evaluators/_protected_materials/__init__.py +0 -5
  295. azure/ai/evaluation/_evaluators/_protected_materials/_protected_materials.py +0 -104
  296. azure/ai/evaluation/simulator/_tracing.py +0 -89
  297. azure_ai_evaluation-1.0.0b2.dist-info/METADATA +0 -449
  298. azure_ai_evaluation-1.0.0b2.dist-info/RECORD +0 -99
  299. {azure_ai_evaluation-1.0.0b2.dist-info → azure_ai_evaluation-1.13.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1717 @@
1
+ # ---------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # ---------------------------------------------------------
4
+ # Third-party imports
5
+ import asyncio
6
+ import itertools
7
+ import logging
8
+ import math
9
+ import os
10
+ from pathlib import Path
11
+ import random
12
+ import time
13
+ import uuid
14
+ from datetime import datetime
15
+ from typing import Callable, Dict, List, Optional, Union, cast, Any
16
+ from tqdm import tqdm
17
+
18
+ # Azure AI Evaluation imports
19
+ from azure.ai.evaluation._constants import TokenScope
20
+ from azure.ai.evaluation._common._experimental import experimental
21
+
22
+ from azure.ai.evaluation._evaluate._evaluate import (
23
+ emit_eval_result_events_to_app_insights,
24
+ ) # TODO: uncomment when app insights checked in
25
+ from azure.ai.evaluation._model_configurations import EvaluationResult
26
+ from azure.ai.evaluation.simulator._model_tools import ManagedIdentityAPITokenManager
27
+ from azure.ai.evaluation.simulator._model_tools._generated_rai_client import GeneratedRAIClient
28
+ from azure.ai.evaluation._user_agent import UserAgentSingleton
29
+ from azure.ai.evaluation._model_configurations import (
30
+ AzureOpenAIModelConfiguration,
31
+ OpenAIModelConfiguration,
32
+ )
33
+ from azure.ai.evaluation._exceptions import (
34
+ ErrorBlame,
35
+ ErrorCategory,
36
+ ErrorTarget,
37
+ EvaluationException,
38
+ )
39
+ from azure.ai.evaluation._common.utils import (
40
+ validate_azure_ai_project,
41
+ is_onedp_project,
42
+ )
43
+ from azure.ai.evaluation._evaluate._utils import _write_output
44
+
45
+ # Azure Core imports
46
+ from azure.core.credentials import TokenCredential
47
+
48
+ # Red Teaming imports
49
+ from ._red_team_result import RedTeamResult
50
+ from ._attack_strategy import AttackStrategy
51
+ from ._attack_objective_generator import (
52
+ RiskCategory,
53
+ SupportedLanguages,
54
+ _AttackObjectiveGenerator,
55
+ )
56
+
57
+ # PyRIT imports
58
+ from pyrit.common import initialize_pyrit, DUCK_DB
59
+ from pyrit.prompt_target import PromptChatTarget
60
+
61
+ # Local imports - constants and utilities
62
+ from ._utils.constants import TASK_STATUS, MAX_SAMPLING_ITERATIONS_MULTIPLIER, RISK_TO_NUM_SUBTYPE_MAP
63
+ from ._utils.logging_utils import (
64
+ setup_logger,
65
+ log_section_header,
66
+ log_subsection_header,
67
+ )
68
+ from ._utils.formatting_utils import (
69
+ get_strategy_name,
70
+ get_flattened_attack_strategies,
71
+ write_pyrit_outputs_to_file,
72
+ format_scorecard,
73
+ format_content_by_modality,
74
+ )
75
+ from ._utils.strategy_utils import get_chat_target, get_converter_for_strategy
76
+ from ._utils.retry_utils import create_standard_retry_manager
77
+ from ._utils.file_utils import create_file_manager
78
+ from ._utils.metric_mapping import get_attack_objective_from_risk_category
79
+ from ._utils.objective_utils import extract_risk_subtype, get_objective_id
80
+
81
+ from ._orchestrator_manager import OrchestratorManager
82
+ from ._evaluation_processor import EvaluationProcessor
83
+ from ._mlflow_integration import MLflowIntegration
84
+ from ._result_processor import ResultProcessor
85
+
86
+
87
+ @experimental
88
+ class RedTeam:
89
+ """
90
+ This class uses various attack strategies to test the robustness of AI models against adversarial inputs.
91
+ It logs the results of these evaluations and provides detailed scorecards summarizing the attack success rates.
92
+
93
+ :param azure_ai_project: The Azure AI project, which can either be a string representing the project endpoint
94
+ or an instance of AzureAIProject. It contains subscription id, resource group, and project name.
95
+ :type azure_ai_project: Union[str, ~azure.ai.evaluation.AzureAIProject]
96
+ :param credential: The credential to authenticate with Azure services
97
+ :type credential: TokenCredential
98
+ :param risk_categories: List of risk categories to generate attack objectives for (optional if custom_attack_seed_prompts is provided)
99
+ :type risk_categories: Optional[List[RiskCategory]]
100
+ :param num_objectives: Number of objectives to generate per risk category
101
+ :type num_objectives: int
102
+ :param application_scenario: Description of the application scenario for context
103
+ :type application_scenario: Optional[str]
104
+ :param custom_attack_seed_prompts: Path to a JSON file containing custom attack seed prompts (can be absolute or relative path)
105
+ :type custom_attack_seed_prompts: Optional[str]
106
+ :param language: Language to use for attack objectives generation. Defaults to English.
107
+ :type language: SupportedLanguages
108
+ :param output_dir: Directory to save output files (optional)
109
+ :type output_dir: Optional[str]
110
+ :param attack_success_thresholds: Threshold configuration for determining attack success.
111
+ Should be a dictionary mapping risk categories (RiskCategory enum values) to threshold values,
112
+ or None to use default binary evaluation (evaluation results determine success).
113
+ When using thresholds, scores >= threshold are considered successful attacks.
114
+ :type attack_success_thresholds: Optional[Dict[RiskCategory, int]]
115
+ """
116
+
117
+ def __init__(
118
+ self,
119
+ azure_ai_project: Union[dict, str],
120
+ credential,
121
+ *,
122
+ risk_categories: Optional[List[RiskCategory]] = None,
123
+ num_objectives: int = 10,
124
+ application_scenario: Optional[str] = None,
125
+ custom_attack_seed_prompts: Optional[str] = None,
126
+ language: SupportedLanguages = SupportedLanguages.English,
127
+ output_dir=".",
128
+ attack_success_thresholds: Optional[Dict[RiskCategory, int]] = None,
129
+ ):
130
+ """Initialize a new Red Team agent for AI model evaluation.
131
+
132
+ Creates a Red Team agent instance configured with the specified parameters.
133
+ This initializes the token management, attack objective generation, and logging
134
+ needed for running red team evaluations against AI models.
135
+
136
+ :param azure_ai_project: The Azure AI project, which can either be a string representing the project endpoint
137
+ or an instance of AzureAIProject. It contains subscription id, resource group, and project name.
138
+ :type azure_ai_project: Union[str, ~azure.ai.evaluation.AzureAIProject]
139
+ :param credential: Authentication credential for Azure services
140
+ :type credential: TokenCredential
141
+ :param risk_categories: List of risk categories to test (required unless custom prompts provided)
142
+ :type risk_categories: Optional[List[RiskCategory]]
143
+ :param num_objectives: Number of attack objectives to generate per risk category
144
+ :type num_objectives: int
145
+ :param application_scenario: Description of the application scenario
146
+ :type application_scenario: Optional[str]
147
+ :param custom_attack_seed_prompts: Path to a JSON file with custom attack prompts
148
+ :type custom_attack_seed_prompts: Optional[str]
149
+ :param language: Language to use for attack objectives generation. Defaults to English.
150
+ :type language: SupportedLanguages
151
+ :param output_dir: Directory to save evaluation outputs and logs. Defaults to current working directory.
152
+ :type output_dir: str
153
+ :param attack_success_thresholds: Threshold configuration for determining attack success.
154
+ Should be a dictionary mapping risk categories (RiskCategory enum values) to threshold values,
155
+ or None to use default binary evaluation (evaluation results determine success).
156
+ When using thresholds, scores >= threshold are considered successful attacks.
157
+ :type attack_success_thresholds: Optional[Dict[RiskCategory, int]]
158
+ """
159
+
160
+ self.azure_ai_project = validate_azure_ai_project(azure_ai_project)
161
+ self.credential = credential
162
+ self.output_dir = output_dir
163
+ self.language = language
164
+ self._one_dp_project = is_onedp_project(azure_ai_project)
165
+
166
+ # Configure attack success thresholds
167
+ self.attack_success_thresholds = self._configure_attack_success_thresholds(attack_success_thresholds)
168
+
169
+ # Initialize basic logger without file handler (will be properly set up during scan)
170
+ self.logger = logging.getLogger("RedTeamLogger")
171
+ self.logger.setLevel(logging.DEBUG)
172
+
173
+ # Only add console handler for now - file handler will be added during scan setup
174
+ if not any(isinstance(h, logging.StreamHandler) for h in self.logger.handlers):
175
+ console_handler = logging.StreamHandler()
176
+ console_handler.setLevel(logging.WARNING)
177
+ console_formatter = logging.Formatter("%(levelname)s - %(message)s")
178
+ console_handler.setFormatter(console_formatter)
179
+ self.logger.addHandler(console_handler)
180
+
181
+ if not self._one_dp_project:
182
+ self.token_manager = ManagedIdentityAPITokenManager(
183
+ token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
184
+ logger=logging.getLogger("RedTeamLogger"),
185
+ credential=cast(TokenCredential, credential),
186
+ )
187
+ else:
188
+ self.token_manager = ManagedIdentityAPITokenManager(
189
+ token_scope=TokenScope.COGNITIVE_SERVICES_MANAGEMENT,
190
+ logger=logging.getLogger("RedTeamLogger"),
191
+ credential=cast(TokenCredential, credential),
192
+ )
193
+
194
+ # Initialize task tracking
195
+ self.task_statuses = {}
196
+ self.total_tasks = 0
197
+ self.completed_tasks = 0
198
+ self.failed_tasks = 0
199
+ self.start_time = None
200
+ self.scan_id = None
201
+ self.scan_session_id = None
202
+ self.scan_output_dir = None
203
+
204
+ # Initialize RAI client
205
+ self.generated_rai_client = GeneratedRAIClient(
206
+ azure_ai_project=self.azure_ai_project, token_manager=self.token_manager.credential
207
+ )
208
+
209
+ # Initialize a cache for attack objectives by risk category and strategy
210
+ self.attack_objectives = {}
211
+
212
+ # Keep track of data and eval result file names
213
+ self.red_team_info = {}
214
+
215
+ # keep track of prompt content to context mapping for evaluation
216
+ self.prompt_to_context = {}
217
+
218
+ # keep track of prompt content to risk_sub_type mapping for evaluation
219
+ self.prompt_to_risk_subtype = {}
220
+
221
+ # Initialize PyRIT
222
+ initialize_pyrit(memory_db_type=DUCK_DB)
223
+
224
+ # Initialize attack objective generator
225
+ self.attack_objective_generator = _AttackObjectiveGenerator(
226
+ risk_categories=risk_categories,
227
+ num_objectives=num_objectives,
228
+ application_scenario=application_scenario,
229
+ custom_attack_seed_prompts=custom_attack_seed_prompts,
230
+ )
231
+
232
+ # Initialize component managers (will be set up during scan)
233
+ self.orchestrator_manager = None
234
+ self.evaluation_processor = None
235
+ self.mlflow_integration = None
236
+ self.result_processor = None
237
+
238
+ # Initialize utility managers
239
+ self.retry_manager = create_standard_retry_manager(logger=self.logger)
240
+ self.file_manager = create_file_manager(base_output_dir=self.output_dir, logger=self.logger)
241
+
242
+ self.logger.debug("RedTeam initialized successfully")
243
+
244
+ def _configure_attack_success_thresholds(
245
+ self, attack_success_thresholds: Optional[Dict[RiskCategory, int]]
246
+ ) -> Dict[str, int]:
247
+ """Configure attack success thresholds for different risk categories."""
248
+ if attack_success_thresholds is None:
249
+ return {}
250
+
251
+ def validate_threshold(value: int, key: str = "threshold") -> None:
252
+ if not isinstance(value, int) or not (0 <= value <= 7):
253
+ raise ValueError(f"Attack success {key} must be an integer between 0 and 7, got: {value}")
254
+
255
+ configured_thresholds = {}
256
+
257
+ if not isinstance(attack_success_thresholds, dict):
258
+ raise ValueError(
259
+ f"attack_success_thresholds must be a dictionary mapping RiskCategory instances to thresholds, or None. Got: {type(attack_success_thresholds)}"
260
+ )
261
+
262
+ # Per-category thresholds
263
+ for key, value in attack_success_thresholds.items():
264
+ validate_threshold(value, f"threshold for {key}")
265
+
266
+ # Normalize the key to string format
267
+ if hasattr(key, "value"):
268
+ category_key = key.value
269
+ else:
270
+ raise ValueError(f"attack_success_thresholds keys must be RiskCategory instance, got: {type(key)}")
271
+
272
+ configured_thresholds[category_key] = value
273
+
274
+ return configured_thresholds
275
+
276
+ def _setup_component_managers(self):
277
+ """Initialize component managers with shared configuration."""
278
+ retry_config = self.retry_manager.get_retry_config()
279
+
280
+ # Initialize orchestrator manager
281
+ self.orchestrator_manager = OrchestratorManager(
282
+ logger=self.logger,
283
+ generated_rai_client=self.generated_rai_client,
284
+ credential=self.credential,
285
+ azure_ai_project=self.azure_ai_project,
286
+ one_dp_project=self._one_dp_project,
287
+ retry_config=retry_config,
288
+ scan_output_dir=self.scan_output_dir,
289
+ red_team=self,
290
+ )
291
+
292
+ # Initialize evaluation processor
293
+ self.evaluation_processor = EvaluationProcessor(
294
+ logger=self.logger,
295
+ azure_ai_project=self.azure_ai_project,
296
+ credential=self.credential,
297
+ attack_success_thresholds=self.attack_success_thresholds,
298
+ retry_config=retry_config,
299
+ scan_session_id=self.scan_session_id,
300
+ scan_output_dir=self.scan_output_dir,
301
+ taxonomy_risk_categories=getattr(self, "taxonomy_risk_categories", None),
302
+ )
303
+
304
+ # Initialize MLflow integration
305
+ self.mlflow_integration = MLflowIntegration(
306
+ logger=self.logger,
307
+ azure_ai_project=self.azure_ai_project,
308
+ generated_rai_client=self.generated_rai_client,
309
+ one_dp_project=self._one_dp_project,
310
+ scan_output_dir=self.scan_output_dir,
311
+ )
312
+
313
+ # Initialize result processor
314
+ self.result_processor = ResultProcessor(
315
+ logger=self.logger,
316
+ attack_success_thresholds=self.attack_success_thresholds,
317
+ application_scenario=getattr(self, "application_scenario", ""),
318
+ risk_categories=getattr(self, "risk_categories", []),
319
+ ai_studio_url=getattr(self.mlflow_integration, "ai_studio_url", None),
320
+ mlflow_integration=self.mlflow_integration,
321
+ )
322
+
323
+ async def _get_attack_objectives(
324
+ self,
325
+ risk_category: Optional[RiskCategory] = None,
326
+ application_scenario: Optional[str] = None,
327
+ strategy: Optional[str] = None,
328
+ is_agent_target: Optional[bool] = None,
329
+ client_id: Optional[str] = None,
330
+ ) -> List[str]:
331
+ """Get attack objectives from the RAI client for a specific risk category or from a custom dataset.
332
+
333
+ Retrieves attack objectives based on the provided risk category and strategy. These objectives
334
+ can come from either the RAI service or from custom attack seed prompts if provided. The function
335
+ handles different strategies, including special handling for jailbreak strategy which requires
336
+ applying prefixes to messages. It also maintains a cache of objectives to ensure consistency
337
+ across different strategies for the same risk category.
338
+
339
+ :param risk_category: The specific risk category to get objectives for
340
+ :type risk_category: Optional[RiskCategory]
341
+ :param application_scenario: Optional description of the application scenario for context
342
+ :type application_scenario: Optional[str]
343
+ :param strategy: Optional attack strategy to get specific objectives for
344
+ :type strategy: Optional[str]
345
+ :param is_agent_target: Optional boolean indicating if target is an agent (True) or model (False)
346
+ :type is_agent_target: Optional[bool]
347
+ :return: A list of attack objective prompts
348
+ :rtype: List[str]
349
+ """
350
+ attack_objective_generator = self.attack_objective_generator
351
+
352
+ # Convert risk category to lowercase for consistent caching
353
+ risk_cat_value = get_attack_objective_from_risk_category(risk_category).lower()
354
+ num_objectives = attack_objective_generator.num_objectives
355
+
356
+ # Calculate num_objectives_with_subtypes based on max subtypes across all risk categories
357
+ # Use attack_objective_generator.risk_categories as self.risk_categories may not be set yet
358
+ risk_categories = getattr(self, "risk_categories", None) or attack_objective_generator.risk_categories
359
+ max_num_subtypes = max((RISK_TO_NUM_SUBTYPE_MAP.get(rc, 0) for rc in risk_categories), default=0)
360
+ num_objectives_with_subtypes = max(num_objectives, max_num_subtypes)
361
+
362
+ self.logger.debug(
363
+ f"Calculated num_objectives_with_subtypes for {risk_cat_value}: "
364
+ f"max(num_objectives={num_objectives}, max_subtypes={max_num_subtypes}) = {num_objectives_with_subtypes}"
365
+ )
366
+
367
+ log_subsection_header(
368
+ self.logger,
369
+ f"Getting attack objectives for {risk_cat_value}, strategy: {strategy}, num_objectives: {num_objectives}, num_objectives_with_subtypes: {num_objectives_with_subtypes}",
370
+ )
371
+
372
+ # Check if we already have baseline objectives for this risk category
373
+ baseline_key = ((risk_cat_value,), "baseline")
374
+ baseline_objectives_exist = baseline_key in self.attack_objectives
375
+ current_key = ((risk_cat_value,), strategy)
376
+
377
+ # Check if custom attack seed prompts are provided in the generator
378
+ if attack_objective_generator.custom_attack_seed_prompts and attack_objective_generator.validated_prompts:
379
+ # Check if this specific risk category has custom objectives
380
+ custom_objectives = attack_objective_generator.valid_prompts_by_category.get(risk_cat_value, [])
381
+
382
+ if custom_objectives:
383
+ # Use custom objectives for this risk category
384
+ return await self._get_custom_attack_objectives(
385
+ risk_cat_value, num_objectives, num_objectives_with_subtypes, strategy, current_key, is_agent_target
386
+ )
387
+ else:
388
+ # No custom objectives for this risk category, but risk_categories was specified
389
+ # Fetch from service if this risk category is in the requested list
390
+ if (
391
+ self.attack_objective_generator.risk_categories
392
+ and risk_category in self.attack_objective_generator.risk_categories
393
+ ):
394
+ self.logger.info(
395
+ f"No custom objectives found for risk category {risk_cat_value}, fetching from service"
396
+ )
397
+ return await self._get_rai_attack_objectives(
398
+ risk_category,
399
+ risk_cat_value,
400
+ application_scenario,
401
+ strategy,
402
+ baseline_objectives_exist,
403
+ baseline_key,
404
+ current_key,
405
+ num_objectives,
406
+ num_objectives_with_subtypes,
407
+ is_agent_target,
408
+ client_id,
409
+ )
410
+ else:
411
+ # Risk category not in requested list, return empty
412
+ self.logger.warning(
413
+ f"No custom objectives found for risk category {risk_cat_value} and it's not in the requested risk categories"
414
+ )
415
+ return []
416
+ else:
417
+ return await self._get_rai_attack_objectives(
418
+ risk_category,
419
+ risk_cat_value,
420
+ application_scenario,
421
+ strategy,
422
+ baseline_objectives_exist,
423
+ baseline_key,
424
+ current_key,
425
+ num_objectives,
426
+ num_objectives_with_subtypes,
427
+ is_agent_target,
428
+ client_id,
429
+ )
430
+
431
+ async def _get_custom_attack_objectives(
432
+ self,
433
+ risk_cat_value: str,
434
+ num_objectives: int,
435
+ num_objectives_with_subtypes: int,
436
+ strategy: str,
437
+ current_key: tuple,
438
+ is_agent_target: Optional[bool] = None,
439
+ ) -> List[str]:
440
+ """Get attack objectives from custom seed prompts."""
441
+ attack_objective_generator = self.attack_objective_generator
442
+
443
+ self.logger.info(
444
+ f"Using custom attack seed prompts from {attack_objective_generator.custom_attack_seed_prompts}"
445
+ )
446
+
447
+ # Get the prompts for this risk category
448
+ custom_objectives = attack_objective_generator.valid_prompts_by_category.get(risk_cat_value, [])
449
+
450
+ if not custom_objectives:
451
+ self.logger.warning(f"No custom objectives found for risk category {risk_cat_value}")
452
+ return []
453
+
454
+ self.logger.info(f"Found {len(custom_objectives)} custom objectives for {risk_cat_value}")
455
+
456
+ # Deduplicate objectives by ID to avoid selecting the same logical objective multiple times
457
+ seen_ids = set()
458
+ deduplicated_objectives = []
459
+ for obj in custom_objectives:
460
+ obj_id = get_objective_id(obj)
461
+ if obj_id not in seen_ids:
462
+ seen_ids.add(obj_id)
463
+ deduplicated_objectives.append(obj)
464
+
465
+ if len(deduplicated_objectives) < len(custom_objectives):
466
+ self.logger.debug(
467
+ f"Deduplicated {len(custom_objectives)} objectives to {len(deduplicated_objectives)} unique objectives by ID"
468
+ )
469
+
470
+ # Group objectives by risk_subtype if present
471
+ objectives_by_subtype = {}
472
+ objectives_without_subtype = []
473
+
474
+ for obj in deduplicated_objectives:
475
+ risk_subtype = extract_risk_subtype(obj)
476
+
477
+ if risk_subtype:
478
+ if risk_subtype not in objectives_by_subtype:
479
+ objectives_by_subtype[risk_subtype] = []
480
+ objectives_by_subtype[risk_subtype].append(obj)
481
+ else:
482
+ objectives_without_subtype.append(obj)
483
+
484
+ # Determine sampling strategy based on risk_subtype presence
485
+ # Use num_objectives_with_subtypes for initial sampling to ensure coverage
486
+ if objectives_by_subtype:
487
+ # We have risk subtypes - sample evenly across them
488
+ num_subtypes = len(objectives_by_subtype)
489
+ objectives_per_subtype = max(1, num_objectives_with_subtypes // num_subtypes)
490
+
491
+ self.logger.info(
492
+ f"Found {num_subtypes} risk subtypes in custom objectives. "
493
+ f"Sampling {objectives_per_subtype} objectives per subtype to reach ~{num_objectives_with_subtypes} total."
494
+ )
495
+
496
+ selected_cat_objectives = []
497
+ for subtype, subtype_objectives in objectives_by_subtype.items():
498
+ num_to_sample = min(objectives_per_subtype, len(subtype_objectives))
499
+ sampled = random.sample(subtype_objectives, num_to_sample)
500
+ selected_cat_objectives.extend(sampled)
501
+ self.logger.debug(
502
+ f"Sampled {num_to_sample} objectives from risk_subtype '{subtype}' "
503
+ f"({len(subtype_objectives)} available)"
504
+ )
505
+
506
+ # If we need more objectives to reach num_objectives_with_subtypes, sample from objectives without subtype
507
+ if len(selected_cat_objectives) < num_objectives_with_subtypes and objectives_without_subtype:
508
+ remaining = num_objectives_with_subtypes - len(selected_cat_objectives)
509
+ num_to_sample = min(remaining, len(objectives_without_subtype))
510
+ selected_cat_objectives.extend(random.sample(objectives_without_subtype, num_to_sample))
511
+ self.logger.debug(f"Added {num_to_sample} objectives without risk_subtype to reach target count")
512
+
513
+ # If we still need more, round-robin through subtypes again
514
+ if len(selected_cat_objectives) < num_objectives_with_subtypes:
515
+ remaining = num_objectives_with_subtypes - len(selected_cat_objectives)
516
+ subtype_list = list(objectives_by_subtype.keys())
517
+ # Track selected objective IDs in a set for O(1) membership checks
518
+ # Use the objective's 'id' field if available, generate UUID-based ID otherwise
519
+ selected_ids = {get_objective_id(obj) for obj in selected_cat_objectives}
520
+ idx = 0
521
+ while remaining > 0 and subtype_list:
522
+ subtype = subtype_list[idx % len(subtype_list)]
523
+ available = [
524
+ obj for obj in objectives_by_subtype[subtype] if get_objective_id(obj) not in selected_ids
525
+ ]
526
+ if available:
527
+ selected_obj = random.choice(available)
528
+ selected_cat_objectives.append(selected_obj)
529
+ selected_ids.add(get_objective_id(selected_obj))
530
+ remaining -= 1
531
+ idx += 1
532
+ # Prevent infinite loop if we run out of unique objectives
533
+ if idx > len(subtype_list) * MAX_SAMPLING_ITERATIONS_MULTIPLIER:
534
+ break
535
+
536
+ self.logger.info(f"Sampled {len(selected_cat_objectives)} objectives across {num_subtypes} risk subtypes")
537
+ else:
538
+ # No risk subtypes - use num_objectives_with_subtypes for sampling
539
+ if len(custom_objectives) > num_objectives_with_subtypes:
540
+ selected_cat_objectives = random.sample(custom_objectives, num_objectives_with_subtypes)
541
+ self.logger.info(
542
+ f"Sampled {num_objectives_with_subtypes} objectives from {len(custom_objectives)} available for {risk_cat_value}"
543
+ )
544
+ else:
545
+ selected_cat_objectives = custom_objectives
546
+ self.logger.info(f"Using all {len(custom_objectives)} available objectives for {risk_cat_value}")
547
+ target_type_str = "agent" if is_agent_target else "model" if is_agent_target is not None else None
548
+ # Handle jailbreak strategy - need to apply jailbreak prefixes to messages
549
+ if strategy == "jailbreak":
550
+ selected_cat_objectives = await self._apply_jailbreak_prefixes(selected_cat_objectives)
551
+ elif strategy == "indirect_jailbreak":
552
+ selected_cat_objectives = await self._apply_xpia_prompts(selected_cat_objectives, target_type_str)
553
+
554
+ # Extract content from selected objectives
555
+ selected_prompts = []
556
+ for obj in selected_cat_objectives:
557
+ # Extract risk-subtype from target_harms if present
558
+ risk_subtype = extract_risk_subtype(obj)
559
+
560
+ if "messages" in obj and len(obj["messages"]) > 0:
561
+ message = obj["messages"][0]
562
+ if isinstance(message, dict) and "content" in message:
563
+ content = message["content"]
564
+ context = message.get("context", "")
565
+ selected_prompts.append(content)
566
+ # Store mapping of content to context for later evaluation
567
+ self.prompt_to_context[content] = context
568
+ # Store risk_subtype mapping if it exists
569
+ if risk_subtype:
570
+ self.prompt_to_risk_subtype[content] = risk_subtype
571
+
572
+ # Store in cache and return
573
+ self._cache_attack_objectives(current_key, risk_cat_value, strategy, selected_prompts, selected_cat_objectives)
574
+ return selected_prompts
575
+
576
+ async def _get_rai_attack_objectives(
577
+ self,
578
+ risk_category: RiskCategory,
579
+ risk_cat_value: str,
580
+ application_scenario: str,
581
+ strategy: str,
582
+ baseline_objectives_exist: bool,
583
+ baseline_key: tuple,
584
+ current_key: tuple,
585
+ num_objectives: int,
586
+ num_objectives_with_subtypes: int,
587
+ is_agent_target: Optional[bool] = None,
588
+ client_id: Optional[str] = None,
589
+ ) -> List[str]:
590
+ """Get attack objectives from the RAI service."""
591
+ content_harm_risk = None
592
+ other_risk = ""
593
+ if risk_cat_value in ["hate_unfairness", "violence", "self_harm", "sexual"]:
594
+ content_harm_risk = risk_cat_value
595
+ else:
596
+ other_risk = risk_cat_value
597
+
598
+ try:
599
+ self.logger.debug(
600
+ f"API call: get_attack_objectives({risk_cat_value}, app: {application_scenario}, strategy: {strategy})"
601
+ )
602
+
603
+ # Get objectives from RAI service
604
+ target_type_str = "agent" if is_agent_target else "model" if is_agent_target is not None else None
605
+
606
+ objectives_response = await self.generated_rai_client.get_attack_objectives(
607
+ risk_type=content_harm_risk,
608
+ risk_category=other_risk,
609
+ application_scenario=application_scenario or "",
610
+ strategy=None,
611
+ language=self.language.value,
612
+ scan_session_id=self.scan_session_id,
613
+ target=target_type_str,
614
+ client_id=client_id,
615
+ )
616
+
617
+ if isinstance(objectives_response, list):
618
+ self.logger.debug(f"API returned {len(objectives_response)} objectives")
619
+ # Handle jailbreak strategy
620
+ if strategy == "jailbreak":
621
+ objectives_response = await self._apply_jailbreak_prefixes(objectives_response)
622
+ elif strategy == "indirect_jailbreak":
623
+ objectives_response = await self._apply_xpia_prompts(objectives_response, target_type_str)
624
+
625
+ except Exception as e:
626
+ self.logger.warning(f"Error calling get_attack_objectives: {str(e)}")
627
+ objectives_response = {}
628
+
629
+ # Check if the response is valid
630
+ if not objectives_response or (
631
+ isinstance(objectives_response, dict) and not objectives_response.get("objectives")
632
+ ):
633
+ # If we got no agent objectives, fallback to model objectives
634
+ if is_agent_target:
635
+ self.logger.warning(
636
+ f"No agent-type attack objectives found for {risk_cat_value}. "
637
+ "Falling back to model-type objectives."
638
+ )
639
+ try:
640
+ # Retry with model target type
641
+ objectives_response = await self.generated_rai_client.get_attack_objectives(
642
+ risk_type=content_harm_risk,
643
+ risk_category=other_risk,
644
+ application_scenario=application_scenario or "",
645
+ strategy=None,
646
+ language=self.language.value,
647
+ scan_session_id=self.scan_session_id,
648
+ target="model",
649
+ client_id=client_id,
650
+ )
651
+
652
+ if isinstance(objectives_response, list):
653
+ self.logger.debug(f"Fallback API returned {len(objectives_response)} model-type objectives")
654
+
655
+ # Apply strategy-specific transformations to fallback objectives
656
+ # Still try agent-type attack techniques (jailbreak/XPIA) even with model-type baseline objectives
657
+ if strategy == "jailbreak":
658
+ objectives_response = await self._apply_jailbreak_prefixes(objectives_response)
659
+ elif strategy == "indirect_jailbreak":
660
+ objectives_response = await self._apply_xpia_prompts(objectives_response, target_type_str)
661
+
662
+ # Check if fallback response is also empty
663
+ if not objectives_response or (
664
+ isinstance(objectives_response, dict) and not objectives_response.get("objectives")
665
+ ):
666
+ self.logger.warning("Fallback to model-type objectives also returned empty list")
667
+ return []
668
+
669
+ except Exception as fallback_error:
670
+ self.logger.error(f"Error calling get_attack_objectives with model fallback: {str(fallback_error)}")
671
+ self.logger.warning("Fallback API call failed, returning empty objectives list")
672
+ return []
673
+ else:
674
+ self.logger.warning("Empty or invalid response, returning empty list")
675
+ return []
676
+
677
+ # Filter and select objectives using num_objectives_with_subtypes
678
+ selected_cat_objectives = self._filter_and_select_objectives(
679
+ objectives_response, strategy, baseline_objectives_exist, baseline_key, num_objectives_with_subtypes
680
+ )
681
+
682
+ # Extract content and cache
683
+ selected_prompts = self._extract_objective_content(selected_cat_objectives)
684
+ self._cache_attack_objectives(current_key, risk_cat_value, strategy, selected_prompts, selected_cat_objectives)
685
+
686
+ return selected_prompts
687
+
688
+ async def _apply_xpia_prompts(self, objectives_list: List, target_type_str: str) -> List:
689
+ """Apply XPIA prompt formatting to objectives for indirect jailbreak strategy.
690
+
691
+ XPIA prompts are wrapper structures that contain:
692
+ - content: benign user query to trigger tool use
693
+ - context: attack vehicle with {attack_text} placeholder
694
+ - context_type: modality for formatting (email, document, html, code)
695
+ - tool_name: name for mock tool
696
+
697
+ We inject the baseline attack objectives into these XPIA wrapper prompts.
698
+ """
699
+ self.logger.debug(f"Applying XPIA prompts to objectives for indirect jailbreak (target_type={target_type_str})")
700
+
701
+ try:
702
+ # Fetch XPIA wrapper prompts from RAI service
703
+ @self.retry_manager.create_retry_decorator(context="xpia_prompts")
704
+ async def get_xpia_prompts_with_retry():
705
+ return await self.generated_rai_client.get_attack_objectives(
706
+ risk_type=None,
707
+ risk_category="xpia",
708
+ application_scenario="",
709
+ strategy=None,
710
+ language=self.language.value,
711
+ scan_session_id=self.scan_session_id,
712
+ target=target_type_str,
713
+ )
714
+
715
+ xpia_prompts = await get_xpia_prompts_with_retry()
716
+
717
+ # If no agent XPIA prompts and we're trying agent, fallback to model
718
+ if (not xpia_prompts or len(xpia_prompts) == 0) and target_type_str == "agent":
719
+ self.logger.debug("No agent-type XPIA prompts available, falling back to model-type XPIA prompts")
720
+ try:
721
+ xpia_prompts = await self.generated_rai_client.get_attack_objectives(
722
+ risk_type=None,
723
+ risk_category="xpia",
724
+ application_scenario="",
725
+ strategy=None,
726
+ language=self.language.value,
727
+ scan_session_id=self.scan_session_id,
728
+ target="model",
729
+ )
730
+ if xpia_prompts and len(xpia_prompts) > 0:
731
+ self.logger.debug(f"Fetched {len(xpia_prompts)} model-type XPIA wrapper prompts as fallback")
732
+ except Exception as fallback_error:
733
+ self.logger.error(f"Error fetching model-type XPIA prompts as fallback: {str(fallback_error)}")
734
+
735
+ if not xpia_prompts or len(xpia_prompts) == 0:
736
+ self.logger.warning("No XPIA prompts available (even after fallback), returning objectives unchanged")
737
+ return objectives_list
738
+
739
+ self.logger.debug(f"Fetched {len(xpia_prompts)} XPIA wrapper prompts")
740
+
741
+ # Apply XPIA wrapping to each baseline objective
742
+ for objective in objectives_list:
743
+ if "messages" in objective and len(objective["messages"]) > 0:
744
+ message = objective["messages"][0]
745
+ if isinstance(message, dict) and "content" in message:
746
+ # Get the baseline attack content to inject
747
+ baseline_attack_content = message["content"]
748
+ # Preserve the original baseline context if it exists
749
+ baseline_context = message.get("context", "")
750
+
751
+ # Normalize baseline_context to a list of context dicts
752
+ baseline_contexts = []
753
+ if baseline_context:
754
+ # Extract baseline context from RAI service format
755
+ context_dict = {"content": baseline_context}
756
+ if message.get("tool_name"):
757
+ context_dict["tool_name"] = message["tool_name"]
758
+ if message.get("context_type"):
759
+ context_dict["context_type"] = message["context_type"]
760
+ baseline_contexts = [context_dict]
761
+
762
+ # Check if baseline contexts have agent fields (context_type, tool_name)
763
+ baseline_contexts_with_agent_fields = []
764
+ baseline_contexts_without_agent_fields = []
765
+
766
+ for ctx in baseline_contexts:
767
+ if isinstance(ctx, dict):
768
+ if "context_type" in ctx or "tool_name" in ctx:
769
+ # This baseline context has agent fields - preserve it separately
770
+ baseline_contexts_with_agent_fields.append(ctx)
771
+ self.logger.debug(
772
+ f"Found baseline context with agent fields: tool_name={ctx.get('tool_name')}, context_type={ctx.get('context_type')}"
773
+ )
774
+ else:
775
+ # This baseline context has no agent fields - can be embedded
776
+ baseline_contexts_without_agent_fields.append(ctx)
777
+ else:
778
+ baseline_contexts_without_agent_fields.append({"content": str(ctx)})
779
+
780
+ # For baseline contexts without agent fields, embed them in the attack content
781
+ if baseline_contexts_without_agent_fields:
782
+ context_texts = [
783
+ ctx.get("content", "")
784
+ for ctx in baseline_contexts_without_agent_fields
785
+ if ctx.get("content")
786
+ ]
787
+ if context_texts:
788
+ combined_context = "\n\n".join(context_texts)
789
+ baseline_attack_content = f"{baseline_attack_content}\n\nContext:\n{combined_context}"
790
+ self.logger.debug(
791
+ f"Embedded {len(context_texts)} baseline context(s) without agent fields into attack content"
792
+ )
793
+
794
+ # Randomly select an XPIA wrapper prompt
795
+ xpia_prompt = random.choice(xpia_prompts)
796
+ xpia_message = xpia_prompt.get("messages", [{}])[0]
797
+
798
+ # Extract XPIA wrapper structure
799
+ user_query = xpia_message.get("content", "")
800
+ attack_vehicle_context = xpia_message.get("context", "")
801
+ context_type = xpia_message.get("context_type") or "text"
802
+ tool_name = xpia_message.get("tool_name", "")
803
+
804
+ # Inject baseline attack (now with appended context) into the {attack_text} placeholder
805
+ if "{attack_text}" in attack_vehicle_context:
806
+ injected_context = attack_vehicle_context.replace("{attack_text}", baseline_attack_content)
807
+ else:
808
+ # No placeholder found, append to end
809
+ injected_context = f"{attack_vehicle_context}\n\n{baseline_attack_content}"
810
+
811
+ # Apply modality-based formatting
812
+ formatted_context = format_content_by_modality(injected_context, context_type)
813
+
814
+ # Update the message with benign user query
815
+ message["content"] = user_query
816
+
817
+ # Build the contexts list: XPIA context + any baseline contexts with agent fields
818
+ contexts = [
819
+ {"content": formatted_context, "context_type": context_type, "tool_name": tool_name}
820
+ ]
821
+
822
+ # Add baseline contexts with agent fields as separate context entries
823
+ if baseline_contexts_with_agent_fields:
824
+ contexts.extend(baseline_contexts_with_agent_fields)
825
+ self.logger.debug(
826
+ f"Preserved {len(baseline_contexts_with_agent_fields)} baseline context(s) with agent fields"
827
+ )
828
+
829
+ message["context"] = contexts
830
+ message["context_type"] = (
831
+ context_type # Keep at message level for backward compat (XPIA primary)
832
+ )
833
+ message["tool_name"] = tool_name
834
+
835
+ self.logger.debug(
836
+ f"Wrapped baseline attack in XPIA: total contexts={len(contexts)}, xpia_tool={tool_name}, xpia_type={context_type}"
837
+ )
838
+
839
+ except Exception as e:
840
+ self.logger.error(f"Error applying XPIA prompts: {str(e)}")
841
+ self.logger.warning("XPIA prompt application failed, returning original objectives")
842
+
843
+ return objectives_list
844
+
845
+ async def _apply_jailbreak_prefixes(self, objectives_list: List) -> List:
846
+ """Apply jailbreak prefixes to objectives."""
847
+ self.logger.debug("Applying jailbreak prefixes to objectives")
848
+ try:
849
+ # Use centralized retry decorator
850
+ @self.retry_manager.create_retry_decorator(context="jailbreak_prefixes")
851
+ async def get_jailbreak_prefixes_with_retry():
852
+ return await self.generated_rai_client.get_jailbreak_prefixes()
853
+
854
+ jailbreak_prefixes = await get_jailbreak_prefixes_with_retry()
855
+ for objective in objectives_list:
856
+ if "messages" in objective and len(objective["messages"]) > 0:
857
+ message = objective["messages"][0]
858
+ if isinstance(message, dict) and "content" in message:
859
+ message["content"] = f"{random.choice(jailbreak_prefixes)} {message['content']}"
860
+ except Exception as e:
861
+ self.logger.error(f"Error applying jailbreak prefixes: {str(e)}")
862
+
863
+ return objectives_list
864
+
865
+ def _filter_and_select_objectives(
866
+ self,
867
+ objectives_response: List,
868
+ strategy: str,
869
+ baseline_objectives_exist: bool,
870
+ baseline_key: tuple,
871
+ num_objectives: int,
872
+ ) -> List:
873
+ """Filter and select objectives based on strategy and baseline requirements."""
874
+ # For non-baseline strategies, filter by baseline IDs if they exist
875
+ if strategy != "baseline" and baseline_objectives_exist:
876
+ self.logger.debug(f"Found existing baseline objectives, will filter {strategy} by baseline IDs")
877
+ baseline_selected_objectives = self.attack_objectives[baseline_key].get("selected_objectives", [])
878
+ baseline_objective_ids = [obj.get("id") for obj in baseline_selected_objectives if "id" in obj]
879
+
880
+ if baseline_objective_ids:
881
+ self.logger.debug(f"Filtering by {len(baseline_objective_ids)} baseline objective IDs for {strategy}")
882
+ # Filter by baseline IDs
883
+ filtered_objectives = [obj for obj in objectives_response if obj.get("id") in baseline_objective_ids]
884
+ self.logger.debug(f"Found {len(filtered_objectives)} matching objectives with baseline IDs")
885
+
886
+ # For strategies like indirect_jailbreak, the RAI service may return multiple
887
+ # objectives per baseline ID (e.g., multiple XPIA variations for one baseline objective).
888
+ # We should select num_objectives total, ensuring each baseline objective gets an XPIA attack.
889
+ # Group by baseline ID and select one objective per baseline ID up to num_objectives.
890
+ selected_by_id = {}
891
+ for obj in filtered_objectives:
892
+ obj_id = obj.get("id")
893
+ if obj_id not in selected_by_id:
894
+ selected_by_id[obj_id] = []
895
+ selected_by_id[obj_id].append(obj)
896
+
897
+ # Select objectives to match num_objectives
898
+ selected_cat_objectives = []
899
+ baseline_ids = list(selected_by_id.keys())
900
+
901
+ # If we have enough baseline IDs to cover num_objectives, select one per baseline ID
902
+ if len(baseline_ids) >= num_objectives:
903
+ # Select from the first num_objectives baseline IDs
904
+ for i in range(num_objectives):
905
+ obj_id = baseline_ids[i]
906
+ selected_cat_objectives.append(random.choice(selected_by_id[obj_id]))
907
+ else:
908
+ # If we have fewer baseline IDs than num_objectives, select all and cycle through
909
+ for i in range(num_objectives):
910
+ obj_id = baseline_ids[i % len(baseline_ids)]
911
+ # For repeated IDs, try to select different variations if available
912
+ available_variations = selected_by_id[obj_id].copy()
913
+ # Remove already selected variations for this baseline ID
914
+ already_selected = [obj for obj in selected_cat_objectives if obj.get("id") == obj_id]
915
+ for selected_obj in already_selected:
916
+ if selected_obj in available_variations:
917
+ available_variations.remove(selected_obj)
918
+
919
+ if available_variations:
920
+ selected_cat_objectives.append(random.choice(available_variations))
921
+ else:
922
+ # If no more variations, reuse one (shouldn't happen with proper XPIA generation)
923
+ selected_cat_objectives.append(random.choice(selected_by_id[obj_id]))
924
+
925
+ self.logger.debug(
926
+ f"Selected {len(selected_cat_objectives)} objectives from {len(baseline_ids)} baseline IDs and {len(filtered_objectives)} total variations for {strategy} strategy"
927
+ )
928
+ else:
929
+ self.logger.warning("No baseline objective IDs found, using random selection")
930
+ selected_cat_objectives = random.sample(
931
+ objectives_response, min(num_objectives, len(objectives_response))
932
+ )
933
+ else:
934
+ # This is the baseline strategy or we don't have baseline objectives yet
935
+ self.logger.debug(f"Using random selection for {strategy} strategy")
936
+ selected_cat_objectives = random.sample(objectives_response, min(num_objectives, len(objectives_response)))
937
+ selection_msg = (
938
+ f"Selected {len(selected_cat_objectives)} objectives using num_objectives={num_objectives} "
939
+ f"(available: {len(objectives_response)})"
940
+ )
941
+ self.logger.info(selection_msg)
942
+ tqdm.write(f"[INFO] {selection_msg}")
943
+
944
+ if len(selected_cat_objectives) < num_objectives:
945
+ self.logger.warning(
946
+ f"Only found {len(selected_cat_objectives)} objectives, fewer than requested {num_objectives}"
947
+ )
948
+
949
+ return selected_cat_objectives
950
+
951
+ def _extract_objective_content(self, selected_objectives: List) -> List[str]:
952
+ """Extract content from selected objectives and build prompt-to-context mapping."""
953
+ selected_prompts = []
954
+ for obj in selected_objectives:
955
+ risk_subtype = extract_risk_subtype(obj)
956
+ if "messages" in obj and len(obj["messages"]) > 0:
957
+ message = obj["messages"][0]
958
+ if isinstance(message, dict) and "content" in message:
959
+ content = message["content"]
960
+ context_raw = message.get("context", "")
961
+ # TODO is first if necessary?
962
+ # Normalize context to always be a list of dicts with 'content' key
963
+ if isinstance(context_raw, list):
964
+ # Already a list - ensure each item is a dict with 'content' key
965
+ contexts = []
966
+ for ctx in context_raw:
967
+ if isinstance(ctx, dict) and "content" in ctx:
968
+ # Preserve all keys including context_type, tool_name if present
969
+ contexts.append(ctx)
970
+ elif isinstance(ctx, str):
971
+ contexts.append({"content": ctx})
972
+ elif context_raw:
973
+ # Single string value - wrap in dict
974
+ contexts = [{"content": context_raw}]
975
+ if message.get("tool_name"):
976
+ contexts[0]["tool_name"] = message["tool_name"]
977
+ if message.get("context_type"):
978
+ contexts[0]["context_type"] = message["context_type"]
979
+ else:
980
+ contexts = []
981
+
982
+ # Check if any context has agent-specific fields
983
+ has_agent_fields = any(
984
+ isinstance(ctx, dict)
985
+ and ("context_type" in ctx and "tool_name" in ctx and ctx["tool_name"] is not None)
986
+ for ctx in contexts
987
+ )
988
+
989
+ # For contexts without agent fields, append them to the content
990
+ # This applies to baseline and any other attack objectives with plain context
991
+ if contexts and not has_agent_fields:
992
+ # Extract all context content and append to the attack content
993
+ context_texts = []
994
+ for ctx in contexts:
995
+ if isinstance(ctx, dict):
996
+ ctx_content = ctx.get("content", "")
997
+ if ctx_content:
998
+ context_texts.append(ctx_content)
999
+
1000
+ if context_texts:
1001
+ # Append context to content
1002
+ combined_context = "\n\n".join(context_texts)
1003
+ content = f"{content}\n\nContext:\n{combined_context}"
1004
+ self.logger.debug(
1005
+ f"Appended {len(context_texts)} context source(s) to attack content (total context length={len(combined_context)})"
1006
+ )
1007
+
1008
+ selected_prompts.append(content)
1009
+
1010
+ # Store risk_subtype mapping if it exists
1011
+ if risk_subtype:
1012
+ self.prompt_to_risk_subtype[content] = risk_subtype
1013
+
1014
+ # Always store contexts if they exist (whether or not they have agent fields)
1015
+ if contexts:
1016
+ context_dict = {"contexts": contexts}
1017
+ if has_agent_fields:
1018
+ self.logger.debug(f"Stored context with agent fields: {len(contexts)} context source(s)")
1019
+ else:
1020
+ self.logger.debug(
1021
+ f"Stored context without agent fields: {len(contexts)} context source(s) (also embedded in content)"
1022
+ )
1023
+ self.prompt_to_context[content] = context_dict
1024
+ else:
1025
+ self.logger.debug(f"No context to store")
1026
+ return selected_prompts
1027
+
1028
+ def _cache_attack_objectives(
1029
+ self,
1030
+ current_key: tuple,
1031
+ risk_cat_value: str,
1032
+ strategy: str,
1033
+ selected_prompts: List[str],
1034
+ selected_objectives: List,
1035
+ ) -> None:
1036
+ """Cache attack objectives for reuse."""
1037
+ objectives_by_category = {risk_cat_value: []}
1038
+
1039
+ # Process list format and organize by category for caching
1040
+ for obj in selected_objectives:
1041
+ obj_id = obj.get("id", f"obj-{uuid.uuid4()}")
1042
+ content = ""
1043
+ context = ""
1044
+ risk_subtype = extract_risk_subtype(obj)
1045
+
1046
+ if "messages" in obj and len(obj["messages"]) > 0:
1047
+
1048
+ message = obj["messages"][0]
1049
+ content = message.get("content", "")
1050
+ context = message.get("context", "")
1051
+ if content:
1052
+ obj_data = {"id": obj_id, "content": content, "context": context}
1053
+ # Add risk_subtype to obj_data if it exists
1054
+ if risk_subtype:
1055
+ obj_data["risk_subtype"] = risk_subtype
1056
+ objectives_by_category[risk_cat_value].append(obj_data)
1057
+
1058
+ self.attack_objectives[current_key] = {
1059
+ "objectives_by_category": objectives_by_category,
1060
+ "strategy": strategy,
1061
+ "risk_category": risk_cat_value,
1062
+ "selected_prompts": selected_prompts,
1063
+ "selected_objectives": selected_objectives,
1064
+ }
1065
+ self.logger.info(f"Selected {len(selected_prompts)} objectives for {risk_cat_value}")
1066
+
1067
+ async def _process_attack(
1068
+ self,
1069
+ strategy: Union[AttackStrategy, List[AttackStrategy]],
1070
+ risk_category: RiskCategory,
1071
+ all_prompts: List[str],
1072
+ progress_bar: tqdm,
1073
+ progress_bar_lock: asyncio.Lock,
1074
+ scan_name: Optional[str] = None,
1075
+ skip_upload: bool = False,
1076
+ output_path: Optional[Union[str, os.PathLike]] = None,
1077
+ timeout: int = 120,
1078
+ _skip_evals: bool = False,
1079
+ ) -> Optional[EvaluationResult]:
1080
+ """Process a red team scan with the given orchestrator, converter, and prompts.
1081
+
1082
+ Executes a red team attack process using the specified strategy and risk category against the
1083
+ target model or function. This includes creating an orchestrator, applying prompts through the
1084
+ appropriate converter, saving results to files, and optionally evaluating the results.
1085
+ The function handles progress tracking, logging, and error handling throughout the process.
1086
+
1087
+ :param strategy: The attack strategy to use
1088
+ :type strategy: Union[AttackStrategy, List[AttackStrategy]]
1089
+ :param risk_category: The risk category to evaluate
1090
+ :type risk_category: RiskCategory
1091
+ :param all_prompts: List of prompts to use for the scan
1092
+ :type all_prompts: List[str]
1093
+ :param progress_bar: Progress bar to update
1094
+ :type progress_bar: tqdm
1095
+ :param progress_bar_lock: Lock for the progress bar
1096
+ :type progress_bar_lock: asyncio.Lock
1097
+ :param scan_name: Optional name for the evaluation
1098
+ :type scan_name: Optional[str]
1099
+ :param skip_upload: Whether to return only data without evaluation
1100
+ :type skip_upload: bool
1101
+ :param output_path: Optional path for output
1102
+ :type output_path: Optional[Union[str, os.PathLike]]
1103
+ :param timeout: The timeout in seconds for API calls
1104
+ :type timeout: int
1105
+ :param _skip_evals: Whether to skip the actual evaluation process
1106
+ :type _skip_evals: bool
1107
+ :return: Evaluation result if available
1108
+ :rtype: Optional[EvaluationResult]
1109
+ """
1110
+ strategy_name = get_strategy_name(strategy)
1111
+ task_key = f"{strategy_name}_{risk_category.value}_attack"
1112
+ self.task_statuses[task_key] = TASK_STATUS["RUNNING"]
1113
+
1114
+ try:
1115
+ start_time = time.time()
1116
+ tqdm.write(f"▶️ Starting task: {strategy_name} strategy for {risk_category.value} risk category")
1117
+
1118
+ # Get converter and orchestrator function
1119
+ converter = get_converter_for_strategy(
1120
+ strategy, self.generated_rai_client, self._one_dp_project, self.logger
1121
+ )
1122
+ call_orchestrator = self.orchestrator_manager.get_orchestrator_for_attack_strategy(strategy)
1123
+
1124
+ try:
1125
+ self.logger.debug(f"Calling orchestrator for {strategy_name} strategy")
1126
+ orchestrator = await call_orchestrator(
1127
+ chat_target=self.chat_target,
1128
+ all_prompts=all_prompts,
1129
+ converter=converter,
1130
+ strategy_name=strategy_name,
1131
+ risk_category=risk_category,
1132
+ risk_category_name=risk_category.value,
1133
+ timeout=timeout,
1134
+ red_team_info=self.red_team_info,
1135
+ task_statuses=self.task_statuses,
1136
+ prompt_to_context=self.prompt_to_context,
1137
+ )
1138
+ except Exception as e:
1139
+ self.logger.error(f"Error calling orchestrator for {strategy_name} strategy: {str(e)}")
1140
+ self.task_statuses[task_key] = TASK_STATUS["FAILED"]
1141
+ self.failed_tasks += 1
1142
+ async with progress_bar_lock:
1143
+ progress_bar.update(1)
1144
+ return None
1145
+
1146
+ # Write PyRIT outputs to file
1147
+ data_path = write_pyrit_outputs_to_file(
1148
+ output_path=self.red_team_info[strategy_name][risk_category.value]["data_file"],
1149
+ logger=self.logger,
1150
+ prompt_to_context=self.prompt_to_context,
1151
+ )
1152
+ orchestrator.dispose_db_engine()
1153
+
1154
+ # Store data file in our tracking dictionary
1155
+ self.red_team_info[strategy_name][risk_category.value]["data_file"] = data_path
1156
+ self.logger.debug(
1157
+ f"Updated red_team_info with data file: {strategy_name} -> {risk_category.value} -> {data_path}"
1158
+ )
1159
+
1160
+ # Perform evaluation
1161
+ try:
1162
+ await self.evaluation_processor.evaluate(
1163
+ scan_name=scan_name,
1164
+ risk_category=risk_category,
1165
+ strategy=strategy,
1166
+ _skip_evals=_skip_evals,
1167
+ data_path=data_path,
1168
+ output_path=None,
1169
+ red_team_info=self.red_team_info,
1170
+ )
1171
+ except Exception as e:
1172
+ self.logger.error(
1173
+ self.logger,
1174
+ f"Error during evaluation for {strategy_name}/{risk_category.value}",
1175
+ e,
1176
+ )
1177
+ tqdm.write(f"⚠️ Evaluation error for {strategy_name}/{risk_category.value}: {str(e)}")
1178
+ self.red_team_info[strategy_name][risk_category.value]["status"] = TASK_STATUS["FAILED"]
1179
+
1180
+ # Update progress
1181
+ async with progress_bar_lock:
1182
+ self.completed_tasks += 1
1183
+ progress_bar.update(1)
1184
+ completion_pct = (self.completed_tasks / self.total_tasks) * 100
1185
+ elapsed_time = time.time() - start_time
1186
+
1187
+ if self.start_time:
1188
+ total_elapsed = time.time() - self.start_time
1189
+ avg_time_per_task = total_elapsed / self.completed_tasks if self.completed_tasks > 0 else 0
1190
+ remaining_tasks = self.total_tasks - self.completed_tasks
1191
+ est_remaining_time = avg_time_per_task * remaining_tasks if avg_time_per_task > 0 else 0
1192
+
1193
+ tqdm.write(
1194
+ f"✅ Completed task {self.completed_tasks}/{self.total_tasks} ({completion_pct:.1f}%) - {strategy_name}/{risk_category.value} in {elapsed_time:.1f}s"
1195
+ )
1196
+ tqdm.write(f" Est. remaining: {est_remaining_time/60:.1f} minutes")
1197
+ else:
1198
+ tqdm.write(
1199
+ f"✅ Completed task {self.completed_tasks}/{self.total_tasks} ({completion_pct:.1f}%) - {strategy_name}/{risk_category.value} in {elapsed_time:.1f}s"
1200
+ )
1201
+
1202
+ self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
1203
+
1204
+ except Exception as e:
1205
+ self.logger.error(
1206
+ f"Unexpected error processing {strategy_name} strategy for {risk_category.value}: {str(e)}"
1207
+ )
1208
+ self.task_statuses[task_key] = TASK_STATUS["FAILED"]
1209
+ self.failed_tasks += 1
1210
+ async with progress_bar_lock:
1211
+ progress_bar.update(1)
1212
+
1213
+ return None
1214
+
1215
+ async def scan(
1216
+ self,
1217
+ target: Union[
1218
+ Callable,
1219
+ AzureOpenAIModelConfiguration,
1220
+ OpenAIModelConfiguration,
1221
+ PromptChatTarget,
1222
+ ],
1223
+ *,
1224
+ scan_name: Optional[str] = None,
1225
+ attack_strategies: List[Union[AttackStrategy, List[AttackStrategy]]] = [],
1226
+ skip_upload: bool = False,
1227
+ output_path: Optional[Union[str, os.PathLike]] = None,
1228
+ application_scenario: Optional[str] = None,
1229
+ parallel_execution: bool = True,
1230
+ max_parallel_tasks: int = 5,
1231
+ timeout: int = 3600,
1232
+ skip_evals: bool = False,
1233
+ **kwargs: Any,
1234
+ ) -> RedTeamResult:
1235
+ """Run a red team scan against the target using the specified strategies.
1236
+
1237
+ :param target: The target model or function to scan
1238
+ :type target: Union[Callable, AzureOpenAIModelConfiguration, OpenAIModelConfiguration, PromptChatTarget]
1239
+ :param scan_name: Optional name for the evaluation
1240
+ :type scan_name: Optional[str]
1241
+ :param attack_strategies: List of attack strategies to use
1242
+ :type attack_strategies: List[Union[AttackStrategy, List[AttackStrategy]]]
1243
+ :param skip_upload: Flag to determine if the scan results should be uploaded
1244
+ :type skip_upload: bool
1245
+ :param output_path: Optional path for output
1246
+ :type output_path: Optional[Union[str, os.PathLike]]
1247
+ :param application_scenario: Optional description of the application scenario
1248
+ :type application_scenario: Optional[str]
1249
+ :param parallel_execution: Whether to execute orchestrator tasks in parallel
1250
+ :type parallel_execution: bool
1251
+ :param max_parallel_tasks: Maximum number of parallel orchestrator tasks to run (default: 5)
1252
+ :type max_parallel_tasks: int
1253
+ :param timeout: The timeout in seconds for API calls (default: 120)
1254
+ :type timeout: int
1255
+ :param skip_evals: Whether to skip the evaluation process
1256
+ :type skip_evals: bool
1257
+ :return: The output from the red team scan
1258
+ :rtype: RedTeamResult
1259
+ """
1260
+ user_agent: Optional[str] = kwargs.get("user_agent", "(type=redteam; subtype=RedTeam)")
1261
+ run_id_override = kwargs.get("run_id") or kwargs.get("runId")
1262
+ eval_id_override = kwargs.get("eval_id") or kwargs.get("evalId")
1263
+ created_at_override = kwargs.get("created_at") or kwargs.get("createdAt")
1264
+ taxonomy_risk_categories = kwargs.get("taxonomy_risk_categories") # key is risk category value is taxonomy
1265
+ _app_insights_configuration = kwargs.get("_app_insights_configuration")
1266
+ self._app_insights_configuration = _app_insights_configuration
1267
+ self.taxonomy_risk_categories = taxonomy_risk_categories or {}
1268
+ is_agent_target: Optional[bool] = kwargs.get("is_agent_target", False)
1269
+ client_id: Optional[str] = kwargs.get("client_id")
1270
+
1271
+ with UserAgentSingleton().add_useragent_product(user_agent):
1272
+ # Initialize scan
1273
+ self._initialize_scan(scan_name, application_scenario)
1274
+
1275
+ # Setup logging and directories FIRST
1276
+ self._setup_scan_environment()
1277
+
1278
+ # Setup component managers AFTER scan environment is set up
1279
+ self._setup_component_managers()
1280
+
1281
+ # Update result processor with AI studio URL
1282
+ self.result_processor.ai_studio_url = getattr(self.mlflow_integration, "ai_studio_url", None)
1283
+
1284
+ # Update component managers with the new logger
1285
+ self.orchestrator_manager.logger = self.logger
1286
+ self.evaluation_processor.logger = self.logger
1287
+ self.mlflow_integration.logger = self.logger
1288
+ self.result_processor.logger = self.logger
1289
+
1290
+ self.mlflow_integration.set_run_identity_overrides(
1291
+ run_id=run_id_override,
1292
+ eval_id=eval_id_override,
1293
+ created_at=created_at_override,
1294
+ )
1295
+
1296
+ # Validate attack objective generator
1297
+ if not self.attack_objective_generator:
1298
+ raise EvaluationException(
1299
+ message="Attack objective generator is required for red team agent.",
1300
+ internal_message="Attack objective generator is not provided.",
1301
+ target=ErrorTarget.RED_TEAM,
1302
+ category=ErrorCategory.MISSING_FIELD,
1303
+ blame=ErrorBlame.USER_ERROR,
1304
+ )
1305
+
1306
+ # Set default risk categories if not specified
1307
+ if not self.attack_objective_generator.risk_categories:
1308
+ self.logger.info("No risk categories specified, using all available categories")
1309
+ self.attack_objective_generator.risk_categories = [
1310
+ RiskCategory.HateUnfairness,
1311
+ RiskCategory.Sexual,
1312
+ RiskCategory.Violence,
1313
+ RiskCategory.SelfHarm,
1314
+ ]
1315
+
1316
+ self.risk_categories = self.attack_objective_generator.risk_categories
1317
+ self.result_processor.risk_categories = self.risk_categories
1318
+
1319
+ # Validate risk categories for target type
1320
+ if not is_agent_target:
1321
+ # Check if any agent-only risk categories are used with model targets
1322
+ for risk_cat in self.risk_categories:
1323
+ if risk_cat == RiskCategory.SensitiveDataLeakage:
1324
+ raise EvaluationException(
1325
+ message=f"Risk category '{risk_cat.value}' is only available for agent targets",
1326
+ internal_message=f"Risk category {risk_cat.value} requires agent target",
1327
+ target=ErrorTarget.RED_TEAM,
1328
+ category=ErrorCategory.INVALID_VALUE,
1329
+ blame=ErrorBlame.USER_ERROR,
1330
+ )
1331
+
1332
+ # Show risk categories to user
1333
+ tqdm.write(f"📊 Risk categories: {[rc.value for rc in self.risk_categories]}")
1334
+ self.logger.info(f"Risk categories to process: {[rc.value for rc in self.risk_categories]}")
1335
+
1336
+ # Setup attack strategies
1337
+ if AttackStrategy.Baseline not in attack_strategies:
1338
+ attack_strategies.insert(0, AttackStrategy.Baseline)
1339
+
1340
+ # Start MLFlow run if not skipping upload
1341
+ if skip_upload:
1342
+ eval_run = {}
1343
+ else:
1344
+ eval_run = self.mlflow_integration.start_redteam_mlflow_run(self.azure_ai_project, scan_name)
1345
+ tqdm.write(f"🔗 Track your red team scan in AI Foundry: {self.mlflow_integration.ai_studio_url}")
1346
+
1347
+ # Update result processor with the AI studio URL now that it's available
1348
+ self.result_processor.ai_studio_url = self.mlflow_integration.ai_studio_url
1349
+
1350
+ # Process strategies and execute scan
1351
+ flattened_attack_strategies = get_flattened_attack_strategies(attack_strategies)
1352
+ self._validate_strategies(flattened_attack_strategies)
1353
+
1354
+ # Calculate total tasks and initialize tracking
1355
+ self.total_tasks = len(self.risk_categories) * len(flattened_attack_strategies)
1356
+ tqdm.write(f"📋 Planning {self.total_tasks} total tasks")
1357
+ self._initialize_tracking_dict(flattened_attack_strategies)
1358
+
1359
+ # Fetch attack objectives
1360
+ all_objectives = await self._fetch_all_objectives(
1361
+ flattened_attack_strategies, application_scenario, is_agent_target, client_id
1362
+ )
1363
+
1364
+ chat_target = get_chat_target(target)
1365
+ self.chat_target = chat_target
1366
+
1367
+ # Execute attacks
1368
+ await self._execute_attacks(
1369
+ flattened_attack_strategies,
1370
+ all_objectives,
1371
+ scan_name,
1372
+ skip_upload,
1373
+ output_path,
1374
+ timeout,
1375
+ skip_evals,
1376
+ parallel_execution,
1377
+ max_parallel_tasks,
1378
+ )
1379
+
1380
+ # Process and return results
1381
+ return await self._finalize_results(skip_upload, skip_evals, eval_run, output_path, scan_name)
1382
+
1383
+ def _initialize_scan(self, scan_name: Optional[str], application_scenario: Optional[str]):
1384
+ """Initialize scan-specific variables."""
1385
+ self.start_time = time.time()
1386
+ self.task_statuses = {}
1387
+ self.completed_tasks = 0
1388
+ self.failed_tasks = 0
1389
+
1390
+ # Generate unique scan ID and session ID
1391
+ self.scan_id = (
1392
+ f"scan_{scan_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
1393
+ if scan_name
1394
+ else f"scan_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
1395
+ )
1396
+ self.scan_id = self.scan_id.replace(" ", "_")
1397
+ self.scan_session_id = str(uuid.uuid4())
1398
+ self.application_scenario = application_scenario or ""
1399
+
1400
+ def _setup_scan_environment(self):
1401
+ """Setup scan output directory and logging."""
1402
+ # Use file manager to create scan output directory
1403
+ self.scan_output_dir = self.file_manager.get_scan_output_path(self.scan_id)
1404
+
1405
+ # Re-initialize logger with the scan output directory
1406
+ self.logger = setup_logger(output_dir=self.scan_output_dir)
1407
+
1408
+ # Setup logging filters
1409
+ self._setup_logging_filters()
1410
+
1411
+ log_section_header(self.logger, "Starting red team scan")
1412
+ tqdm.write(f"🚀 STARTING RED TEAM SCAN")
1413
+ tqdm.write(f"📂 Output directory: {self.scan_output_dir}")
1414
+
1415
+ def _setup_logging_filters(self):
1416
+ """Setup logging filters to suppress unwanted logs."""
1417
+
1418
+ class LogFilter(logging.Filter):
1419
+ def filter(self, record):
1420
+ # Filter out promptflow logs and evaluation warnings about artifacts
1421
+ if record.name.startswith("promptflow"):
1422
+ return False
1423
+ if "The path to the artifact is either not a directory or does not exist" in record.getMessage():
1424
+ return False
1425
+ if "RedTeamResult object at" in record.getMessage():
1426
+ return False
1427
+ if "timeout won't take effect" in record.getMessage():
1428
+ return False
1429
+ if "Submitting run" in record.getMessage():
1430
+ return False
1431
+ return True
1432
+
1433
+ # Apply filter to root logger
1434
+ root_logger = logging.getLogger()
1435
+ log_filter = LogFilter()
1436
+
1437
+ for handler in root_logger.handlers:
1438
+ for filter in handler.filters:
1439
+ handler.removeFilter(filter)
1440
+ handler.addFilter(log_filter)
1441
+
1442
+ def _validate_strategies(self, flattened_attack_strategies: List):
1443
+ """Validate attack strategies."""
1444
+ if len(flattened_attack_strategies) > 2 and (
1445
+ AttackStrategy.MultiTurn in flattened_attack_strategies
1446
+ or AttackStrategy.Crescendo in flattened_attack_strategies
1447
+ ):
1448
+ self.logger.warning(
1449
+ "MultiTurn and Crescendo strategies are not compatible with multiple attack strategies."
1450
+ )
1451
+ raise ValueError("MultiTurn and Crescendo strategies are not compatible with multiple attack strategies.")
1452
+
1453
+ def _initialize_tracking_dict(self, flattened_attack_strategies: List):
1454
+ """Initialize the red_team_info tracking dictionary."""
1455
+ self.red_team_info = {}
1456
+ for strategy in flattened_attack_strategies:
1457
+ strategy_name = get_strategy_name(strategy)
1458
+ self.red_team_info[strategy_name] = {}
1459
+ for risk_category in self.risk_categories:
1460
+ self.red_team_info[strategy_name][risk_category.value] = {
1461
+ "data_file": "",
1462
+ "evaluation_result_file": "",
1463
+ "evaluation_result": None,
1464
+ "status": TASK_STATUS["PENDING"],
1465
+ }
1466
+
1467
+ async def _fetch_all_objectives(
1468
+ self,
1469
+ flattened_attack_strategies: List,
1470
+ application_scenario: str,
1471
+ is_agent_target: bool,
1472
+ client_id: Optional[str] = None,
1473
+ ) -> Dict:
1474
+ """Fetch all attack objectives for all strategies and risk categories."""
1475
+ log_section_header(self.logger, "Fetching attack objectives")
1476
+ all_objectives = {}
1477
+
1478
+ # Calculate and log num_objectives_with_subtypes once globally
1479
+ num_objectives = self.attack_objective_generator.num_objectives
1480
+ max_num_subtypes = max((RISK_TO_NUM_SUBTYPE_MAP.get(rc, 0) for rc in self.risk_categories), default=0)
1481
+ num_objectives_with_subtypes = max(num_objectives, max_num_subtypes)
1482
+
1483
+ if num_objectives_with_subtypes != num_objectives:
1484
+ warning_msg = (
1485
+ f"Using {num_objectives_with_subtypes} objectives per risk category instead of requested {num_objectives} "
1486
+ f"to ensure adequate coverage of {max_num_subtypes} subtypes"
1487
+ )
1488
+ self.logger.warning(warning_msg)
1489
+ tqdm.write(f"[WARNING] {warning_msg}")
1490
+
1491
+ # First fetch baseline objectives for all risk categories
1492
+ self.logger.info("Fetching baseline objectives for all risk categories")
1493
+ for risk_category in self.risk_categories:
1494
+ baseline_objectives = await self._get_attack_objectives(
1495
+ risk_category=risk_category,
1496
+ application_scenario=application_scenario,
1497
+ strategy="baseline",
1498
+ is_agent_target=is_agent_target,
1499
+ client_id=client_id,
1500
+ )
1501
+ if "baseline" not in all_objectives:
1502
+ all_objectives["baseline"] = {}
1503
+ all_objectives["baseline"][risk_category.value] = baseline_objectives
1504
+ status_msg = f"📝 Fetched baseline objectives for {risk_category.value}: {len(baseline_objectives)}/{num_objectives_with_subtypes} objectives"
1505
+ if len(baseline_objectives) < num_objectives_with_subtypes:
1506
+ status_msg += f" (⚠️ fewer than expected)"
1507
+ tqdm.write(status_msg)
1508
+
1509
+ # Then fetch objectives for other strategies
1510
+ strategy_count = len(flattened_attack_strategies)
1511
+ for i, strategy in enumerate(flattened_attack_strategies):
1512
+ strategy_name = get_strategy_name(strategy)
1513
+ if strategy_name == "baseline":
1514
+ continue
1515
+
1516
+ tqdm.write(f"🔄 Fetching objectives for strategy {i+1}/{strategy_count}: {strategy_name}")
1517
+ all_objectives[strategy_name] = {}
1518
+
1519
+ for risk_category in self.risk_categories:
1520
+ objectives = await self._get_attack_objectives(
1521
+ risk_category=risk_category,
1522
+ application_scenario=application_scenario,
1523
+ strategy=strategy_name,
1524
+ is_agent_target=is_agent_target,
1525
+ client_id=client_id,
1526
+ )
1527
+ all_objectives[strategy_name][risk_category.value] = objectives
1528
+
1529
+ return all_objectives
1530
+
1531
+ async def _execute_attacks(
1532
+ self,
1533
+ flattened_attack_strategies: List,
1534
+ all_objectives: Dict,
1535
+ scan_name: str,
1536
+ skip_upload: bool,
1537
+ output_path: str,
1538
+ timeout: int,
1539
+ skip_evals: bool,
1540
+ parallel_execution: bool,
1541
+ max_parallel_tasks: int,
1542
+ ):
1543
+ """Execute all attack combinations."""
1544
+ log_section_header(self.logger, "Starting orchestrator processing")
1545
+
1546
+ # Create progress bar
1547
+ progress_bar = tqdm(
1548
+ total=self.total_tasks,
1549
+ desc="Scanning: ",
1550
+ ncols=100,
1551
+ unit="scan",
1552
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
1553
+ )
1554
+ progress_bar.set_postfix({"current": "initializing"})
1555
+ progress_bar_lock = asyncio.Lock()
1556
+
1557
+ # Create all tasks for parallel processing
1558
+ orchestrator_tasks = []
1559
+ combinations = list(itertools.product(flattened_attack_strategies, self.risk_categories))
1560
+
1561
+ for combo_idx, (strategy, risk_category) in enumerate(combinations):
1562
+ strategy_name = get_strategy_name(strategy)
1563
+ objectives = all_objectives[strategy_name][risk_category.value]
1564
+
1565
+ if not objectives:
1566
+ self.logger.warning(f"No objectives found for {strategy_name}+{risk_category.value}, skipping")
1567
+ tqdm.write(f"⚠️ No objectives found for {strategy_name}/{risk_category.value}, skipping")
1568
+ self.red_team_info[strategy_name][risk_category.value]["status"] = TASK_STATUS["COMPLETED"]
1569
+ async with progress_bar_lock:
1570
+ progress_bar.update(1)
1571
+ continue
1572
+
1573
+ orchestrator_tasks.append(
1574
+ self._process_attack(
1575
+ all_prompts=objectives,
1576
+ strategy=strategy,
1577
+ progress_bar=progress_bar,
1578
+ progress_bar_lock=progress_bar_lock,
1579
+ scan_name=scan_name,
1580
+ skip_upload=skip_upload,
1581
+ output_path=output_path,
1582
+ risk_category=risk_category,
1583
+ timeout=timeout,
1584
+ _skip_evals=skip_evals,
1585
+ )
1586
+ )
1587
+
1588
+ # Process tasks
1589
+ await self._process_orchestrator_tasks(orchestrator_tasks, parallel_execution, max_parallel_tasks, timeout)
1590
+ progress_bar.close()
1591
+
1592
+ async def _process_orchestrator_tasks(
1593
+ self, orchestrator_tasks: List, parallel_execution: bool, max_parallel_tasks: int, timeout: int
1594
+ ):
1595
+ """Process orchestrator tasks either in parallel or sequentially."""
1596
+ if parallel_execution and orchestrator_tasks:
1597
+ tqdm.write(f"⚙️ Processing {len(orchestrator_tasks)} tasks in parallel (max {max_parallel_tasks} at a time)")
1598
+
1599
+ # Process tasks in batches
1600
+ for i in range(0, len(orchestrator_tasks), max_parallel_tasks):
1601
+ end_idx = min(i + max_parallel_tasks, len(orchestrator_tasks))
1602
+ batch = orchestrator_tasks[i:end_idx]
1603
+
1604
+ try:
1605
+ await asyncio.wait_for(asyncio.gather(*batch), timeout=timeout * 2)
1606
+ except asyncio.TimeoutError:
1607
+ self.logger.warning(f"Batch {i//max_parallel_tasks+1} timed out")
1608
+ tqdm.write(f"⚠️ Batch {i//max_parallel_tasks+1} timed out, continuing with next batch")
1609
+ continue
1610
+ except Exception as e:
1611
+ self.logger.error(f"Error processing batch {i//max_parallel_tasks+1}: {str(e)}")
1612
+ continue
1613
+ else:
1614
+ # Sequential execution
1615
+ tqdm.write("⚙️ Processing tasks sequentially")
1616
+ for i, task in enumerate(orchestrator_tasks):
1617
+ try:
1618
+ await asyncio.wait_for(task, timeout=timeout)
1619
+ except asyncio.TimeoutError:
1620
+ self.logger.warning(f"Task {i+1} timed out")
1621
+ tqdm.write(f"⚠️ Task {i+1} timed out, continuing with next task")
1622
+ continue
1623
+ except Exception as e:
1624
+ self.logger.error(f"Error processing task {i+1}: {str(e)}")
1625
+ continue
1626
+
1627
+ async def _finalize_results(
1628
+ self, skip_upload: bool, skip_evals: bool, eval_run, output_path: str, scan_name: str
1629
+ ) -> RedTeamResult:
1630
+ """Process and finalize scan results."""
1631
+ log_section_header(self.logger, "Processing results")
1632
+
1633
+ # Convert results to RedTeamResult (now builds AOAI summary internally)
1634
+ red_team_result = self.result_processor.to_red_team_result(
1635
+ red_team_info=self.red_team_info,
1636
+ eval_run=eval_run,
1637
+ scan_name=scan_name,
1638
+ )
1639
+
1640
+ # Extract AOAI summary for passing to MLflow logging
1641
+ aoai_summary = red_team_result.scan_result.get("AOAI_Compatible_Summary")
1642
+ if self._app_insights_configuration:
1643
+ # Get redacted results from the result processor for App Insights logging
1644
+ redacted_results = self.result_processor.get_app_insights_redacted_results(
1645
+ aoai_summary["output_items"]["data"]
1646
+ )
1647
+ emit_eval_result_events_to_app_insights(self._app_insights_configuration, redacted_results)
1648
+ # Log results to MLFlow if not skipping upload
1649
+ if not skip_upload:
1650
+ self.logger.info("Logging results to AI Foundry")
1651
+ await self.mlflow_integration.log_redteam_results_to_mlflow(
1652
+ redteam_result=red_team_result,
1653
+ eval_run=eval_run,
1654
+ red_team_info=self.red_team_info,
1655
+ _skip_evals=skip_evals,
1656
+ aoai_summary=aoai_summary,
1657
+ )
1658
+ # Write output to specified path
1659
+ if output_path and red_team_result.scan_result:
1660
+ abs_output_path = output_path if os.path.isabs(output_path) else os.path.abspath(output_path)
1661
+ self.logger.info(f"Writing output to {abs_output_path}")
1662
+
1663
+ # Ensure output_path is treated as a directory
1664
+ # If it exists as a file, remove it first
1665
+ if os.path.exists(abs_output_path) and not os.path.isdir(abs_output_path):
1666
+ os.remove(abs_output_path)
1667
+ os.makedirs(abs_output_path, exist_ok=True)
1668
+
1669
+ # Create a copy of scan_result without AOAI properties for eval_result.json
1670
+ scan_result_without_aoai = {
1671
+ key: value
1672
+ for key, value in red_team_result.scan_result.items()
1673
+ if not key.startswith("AOAI_Compatible")
1674
+ }
1675
+
1676
+ # Write scan result without AOAI properties to eval_result.json
1677
+ _write_output(abs_output_path, scan_result_without_aoai)
1678
+
1679
+ # Write the AOAI summary to results.json
1680
+ if aoai_summary:
1681
+ _write_output(os.path.join(abs_output_path, "results.json"), aoai_summary)
1682
+ else:
1683
+ self.logger.warning("AOAI summary not available for output_path write")
1684
+
1685
+ # Also save a copy to the scan output directory if available
1686
+ if self.scan_output_dir:
1687
+ final_output = os.path.join(self.scan_output_dir, "final_results.json")
1688
+ _write_output(final_output, red_team_result.scan_result)
1689
+ elif red_team_result.scan_result and self.scan_output_dir:
1690
+ # If no output_path was specified but we have scan_output_dir, save there
1691
+ final_output = os.path.join(self.scan_output_dir, "final_results.json")
1692
+ _write_output(final_output, red_team_result.scan_result)
1693
+
1694
+ # Display final scorecard and results
1695
+ if red_team_result.scan_result:
1696
+ scorecard = format_scorecard(red_team_result.scan_result)
1697
+ tqdm.write(scorecard)
1698
+
1699
+ # Print URL for detailed results
1700
+ studio_url = red_team_result.scan_result.get("studio_url", "")
1701
+ if studio_url:
1702
+ tqdm.write(f"\nDetailed results available at:\n{studio_url}")
1703
+
1704
+ # Print the output directory path
1705
+ if self.scan_output_dir:
1706
+ tqdm.write(f"\n📂 All scan files saved to: {self.scan_output_dir}")
1707
+
1708
+ tqdm.write(f"✅ Scan completed successfully!")
1709
+ self.logger.info("Scan completed successfully")
1710
+
1711
+ # Close file handlers
1712
+ for handler in self.logger.handlers:
1713
+ if isinstance(handler, logging.FileHandler):
1714
+ handler.close()
1715
+ self.logger.removeHandler(handler)
1716
+
1717
+ return red_team_result