azure-ai-evaluation 1.0.0b3-py3-none-any.whl → 1.0.0b5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of azure-ai-evaluation might be problematic. (The original page linked to more details here.)

Files changed (93):
  1. azure/ai/evaluation/__init__.py +23 -1
  2. azure/ai/evaluation/{simulator/_helpers → _common}/_experimental.py +20 -9
  3. azure/ai/evaluation/_common/constants.py +9 -2
  4. azure/ai/evaluation/_common/math.py +29 -0
  5. azure/ai/evaluation/_common/rai_service.py +222 -93
  6. azure/ai/evaluation/_common/utils.py +328 -19
  7. azure/ai/evaluation/_constants.py +16 -8
  8. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/__init__.py +3 -2
  9. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/code_client.py +33 -17
  10. azure/ai/evaluation/_evaluate/{_batch_run_client/batch_run_context.py → _batch_run/eval_run_context.py} +14 -7
  11. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/proxy_client.py +22 -4
  12. azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +35 -0
  13. azure/ai/evaluation/_evaluate/_eval_run.py +47 -14
  14. azure/ai/evaluation/_evaluate/_evaluate.py +370 -188
  15. azure/ai/evaluation/_evaluate/_telemetry/__init__.py +15 -16
  16. azure/ai/evaluation/_evaluate/_utils.py +77 -25
  17. azure/ai/evaluation/_evaluators/_bleu/_bleu.py +1 -1
  18. azure/ai/evaluation/_evaluators/_coherence/_coherence.py +16 -10
  19. azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +76 -34
  20. azure/ai/evaluation/_evaluators/_common/_base_eval.py +76 -46
  21. azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +26 -19
  22. azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +62 -25
  23. azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +68 -36
  24. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +67 -46
  25. azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +33 -4
  26. azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +33 -4
  27. azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +33 -4
  28. azure/ai/evaluation/_evaluators/_content_safety/_violence.py +33 -4
  29. azure/ai/evaluation/_evaluators/_eci/_eci.py +7 -5
  30. azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +14 -6
  31. azure/ai/evaluation/_evaluators/_fluency/_fluency.py +22 -21
  32. azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +66 -36
  33. azure/ai/evaluation/_evaluators/_gleu/_gleu.py +1 -1
  34. azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +51 -16
  35. azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +113 -0
  36. azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +99 -0
  37. azure/ai/evaluation/_evaluators/_meteor/_meteor.py +3 -7
  38. azure/ai/evaluation/_evaluators/_multimodal/__init__.py +20 -0
  39. azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +130 -0
  40. azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal_base.py +57 -0
  41. azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +96 -0
  42. azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +120 -0
  43. azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +96 -0
  44. azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +96 -0
  45. azure/ai/evaluation/_evaluators/_multimodal/_violence.py +96 -0
  46. azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +46 -13
  47. azure/ai/evaluation/_evaluators/_qa/_qa.py +11 -6
  48. azure/ai/evaluation/_evaluators/_relevance/_relevance.py +23 -20
  49. azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +78 -42
  50. azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +126 -80
  51. azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +74 -24
  52. azure/ai/evaluation/_evaluators/_rouge/_rouge.py +2 -2
  53. azure/ai/evaluation/_evaluators/_service_groundedness/__init__.py +9 -0
  54. azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +150 -0
  55. azure/ai/evaluation/_evaluators/_similarity/_similarity.py +32 -15
  56. azure/ai/evaluation/_evaluators/_xpia/xpia.py +36 -10
  57. azure/ai/evaluation/_exceptions.py +26 -6
  58. azure/ai/evaluation/_http_utils.py +203 -132
  59. azure/ai/evaluation/_model_configurations.py +23 -6
  60. azure/ai/evaluation/_vendor/__init__.py +3 -0
  61. azure/ai/evaluation/_vendor/rouge_score/__init__.py +14 -0
  62. azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +328 -0
  63. azure/ai/evaluation/_vendor/rouge_score/scoring.py +63 -0
  64. azure/ai/evaluation/_vendor/rouge_score/tokenize.py +63 -0
  65. azure/ai/evaluation/_vendor/rouge_score/tokenizers.py +53 -0
  66. azure/ai/evaluation/_version.py +1 -1
  67. azure/ai/evaluation/simulator/__init__.py +2 -1
  68. azure/ai/evaluation/simulator/_adversarial_scenario.py +5 -0
  69. azure/ai/evaluation/simulator/_adversarial_simulator.py +88 -60
  70. azure/ai/evaluation/simulator/_conversation/__init__.py +13 -12
  71. azure/ai/evaluation/simulator/_conversation/_conversation.py +4 -4
  72. azure/ai/evaluation/simulator/_data_sources/__init__.py +3 -0
  73. azure/ai/evaluation/simulator/_data_sources/grounding.json +1150 -0
  74. azure/ai/evaluation/simulator/_direct_attack_simulator.py +24 -66
  75. azure/ai/evaluation/simulator/_helpers/__init__.py +1 -2
  76. azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +26 -5
  77. azure/ai/evaluation/simulator/_indirect_attack_simulator.py +98 -95
  78. azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +67 -21
  79. azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +28 -11
  80. azure/ai/evaluation/simulator/_model_tools/_template_handler.py +68 -24
  81. azure/ai/evaluation/simulator/_model_tools/models.py +10 -10
  82. azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +4 -9
  83. azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +6 -5
  84. azure/ai/evaluation/simulator/_simulator.py +222 -169
  85. azure/ai/evaluation/simulator/_tracing.py +4 -4
  86. azure/ai/evaluation/simulator/_utils.py +6 -6
  87. {azure_ai_evaluation-1.0.0b3.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/METADATA +237 -52
  88. azure_ai_evaluation-1.0.0b5.dist-info/NOTICE.txt +70 -0
  89. azure_ai_evaluation-1.0.0b5.dist-info/RECORD +120 -0
  90. {azure_ai_evaluation-1.0.0b3.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/WHEEL +1 -1
  91. azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +0 -49
  92. azure_ai_evaluation-1.0.0b3.dist-info/RECORD +0 -98
  93. {azure_ai_evaluation-1.0.0b3.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,17 @@
1
1
  # ---------------------------------------------------------
2
2
  # Copyright (c) Microsoft Corporation. All rights reserved.
3
3
  # ---------------------------------------------------------
4
+ # pylint: disable=C0301,C0114,R0913,R0903
4
5
  # noqa: E501
5
- import functools
6
6
  import logging
7
7
  from random import randint
8
- from typing import Callable, Optional
9
-
10
- from promptflow._sdk._telemetry import ActivityType, monitor_operation
8
+ from typing import Callable, Optional, cast
11
9
 
10
+ from azure.ai.evaluation._common._experimental import experimental
11
+ from azure.ai.evaluation._common.utils import validate_azure_ai_project
12
12
  from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
13
- from azure.ai.evaluation._model_configurations import AzureAIProject
14
13
  from azure.ai.evaluation.simulator import AdversarialScenario
15
- from azure.identity import DefaultAzureCredential
14
+ from azure.core.credentials import TokenCredential
16
15
 
17
16
  from ._adversarial_simulator import AdversarialSimulator
18
17
  from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope
@@ -20,35 +19,7 @@ from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenMan
20
19
  logger = logging.getLogger(__name__)
21
20
 
22
21
 
23
- def monitor_adversarial_scenario(func) -> Callable:
24
- """Decorator to monitor adversarial scenario.
25
-
26
- :param func: The function to be decorated.
27
- :type func: Callable
28
- :return: The decorated function.
29
- :rtype: Callable
30
- """
31
-
32
- @functools.wraps(func)
33
- def wrapper(*args, **kwargs):
34
- scenario = str(kwargs.get("scenario", None))
35
- max_conversation_turns = kwargs.get("max_conversation_turns", None)
36
- max_simulation_results = kwargs.get("max_simulation_results", None)
37
- decorated_func = monitor_operation(
38
- activity_name="jailbreak.adversarial.simulator.call",
39
- activity_type=ActivityType.PUBLICAPI,
40
- custom_dimensions={
41
- "scenario": scenario,
42
- "max_conversation_turns": max_conversation_turns,
43
- "max_simulation_results": max_simulation_results,
44
- },
45
- )(func)
46
-
47
- return decorated_func(*args, **kwargs)
48
-
49
- return wrapper
50
-
51
-
22
+ @experimental
52
23
  class DirectAttackSimulator:
53
24
  """
54
25
  Initialize a UPIA (user prompt injected attack) jailbreak adversarial simulator with a project scope.
@@ -61,42 +32,28 @@ class DirectAttackSimulator:
61
32
  :type credential: ~azure.core.credentials.TokenCredential
62
33
  """
63
34
 
64
- def __init__(self, *, azure_ai_project: AzureAIProject, credential=None):
35
+ def __init__(self, *, azure_ai_project: dict, credential):
65
36
  """Constructor."""
66
- # check if azure_ai_project has the keys: subscription_id, resource_group_name, project_name, credential
67
- if not all(key in azure_ai_project for key in ["subscription_id", "resource_group_name", "project_name"]):
68
- msg = "azure_ai_project must contain keys: subscription_id, resource_group_name and project_name"
69
- raise EvaluationException(
70
- message=msg,
71
- internal_message=msg,
72
- target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
73
- category=ErrorCategory.MISSING_FIELD,
74
- blame=ErrorBlame.USER_ERROR,
75
- )
76
- # check the value of the keys in azure_ai_project is not none
77
- if not all(azure_ai_project[key] for key in ["subscription_id", "resource_group_name", "project_name"]):
78
- msg = "subscription_id, resource_group_name and project_name keys cannot be None"
37
+
38
+ try:
39
+ self.azure_ai_project = validate_azure_ai_project(azure_ai_project)
40
+ except EvaluationException as e:
79
41
  raise EvaluationException(
80
- message=msg,
81
- internal_message=msg,
42
+ message=e.message,
43
+ internal_message=e.internal_message,
82
44
  target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
83
- category=ErrorCategory.MISSING_FIELD,
84
- blame=ErrorBlame.USER_ERROR,
85
- )
86
- if "credential" not in azure_ai_project and not credential:
87
- credential = DefaultAzureCredential()
88
- elif "credential" in azure_ai_project:
89
- credential = azure_ai_project["credential"]
90
- self.credential = credential
91
- self.azure_ai_project = azure_ai_project
45
+ category=e.category,
46
+ blame=e.blame,
47
+ ) from e
48
+ self.credential = cast(TokenCredential, credential)
92
49
  self.token_manager = ManagedIdentityAPITokenManager(
93
50
  token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
94
51
  logger=logging.getLogger("AdversarialSimulator"),
95
- credential=credential,
52
+ credential=self.credential,
96
53
  )
97
- self.rai_client = RAIClient(azure_ai_project=azure_ai_project, token_manager=self.token_manager)
54
+ self.rai_client = RAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager)
98
55
  self.adversarial_template_handler = AdversarialTemplateHandler(
99
- azure_ai_project=azure_ai_project, rai_client=self.rai_client
56
+ azure_ai_project=self.azure_ai_project, rai_client=self.rai_client
100
57
  )
101
58
 
102
59
  def _ensure_service_dependencies(self):
@@ -110,7 +67,6 @@ class DirectAttackSimulator:
110
67
  blame=ErrorBlame.USER_ERROR,
111
68
  )
112
69
 
113
- # @monitor_adversarial_scenario
114
70
  async def __call__(
115
71
  self,
116
72
  *,
@@ -222,7 +178,9 @@ class DirectAttackSimulator:
222
178
  if not randomization_seed:
223
179
  randomization_seed = randint(0, 1000000)
224
180
 
225
- regular_sim = AdversarialSimulator(azure_ai_project=self.azure_ai_project, credential=self.credential)
181
+ regular_sim = AdversarialSimulator(
182
+ azure_ai_project=cast(dict, self.azure_ai_project), credential=self.credential
183
+ )
226
184
  regular_sim_results = await regular_sim(
227
185
  scenario=scenario,
228
186
  target=target,
@@ -235,7 +193,7 @@ class DirectAttackSimulator:
235
193
  randomize_order=True,
236
194
  randomization_seed=randomization_seed,
237
195
  )
238
- jb_sim = AdversarialSimulator(azure_ai_project=self.azure_ai_project, credential=self.credential)
196
+ jb_sim = AdversarialSimulator(azure_ai_project=cast(dict, self.azure_ai_project), credential=self.credential)
239
197
  jb_sim_results = await jb_sim(
240
198
  scenario=scenario,
241
199
  target=target,
@@ -1,5 +1,4 @@
1
- from ._experimental import experimental
2
1
  from ._language_suffix_mapping import SUPPORTED_LANGUAGES_MAPPING
3
2
  from ._simulator_data_classes import ConversationHistory, Turn
4
3
 
5
- __all__ = ["ConversationHistory", "Turn", "SUPPORTED_LANGUAGES_MAPPING", "experimental"]
4
+ __all__ = ["ConversationHistory", "Turn", "SUPPORTED_LANGUAGES_MAPPING"]
@@ -18,7 +18,7 @@ class Turn:
18
18
 
19
19
  role: Union[str, ConversationRole]
20
20
  content: str
21
- context: str = None
21
+ context: Optional[str] = None
22
22
 
23
23
  def to_dict(self) -> Dict[str, Optional[str]]:
24
24
  """
@@ -30,7 +30,19 @@ class Turn:
30
30
  return {
31
31
  "role": self.role.value if isinstance(self.role, ConversationRole) else self.role,
32
32
  "content": self.content,
33
- "context": self.context,
33
+ "context": str(self.context),
34
+ }
35
+
36
+ def to_context_free_dict(self) -> Dict[str, Optional[str]]:
37
+ """
38
+ Convert the conversation turn to a dictionary without context.
39
+
40
+ :returns: A dictionary representation of the conversation turn without context.
41
+ :rtype: Dict[str, Optional[str]]
42
+ """
43
+ return {
44
+ "role": self.role.value if isinstance(self.role, ConversationRole) else self.role,
45
+ "content": self.content,
34
46
  }
35
47
 
36
48
  def __repr__(self):
@@ -42,13 +54,13 @@ class ConversationHistory:
42
54
  Conversation history class to keep track of the conversation turns in a conversation.
43
55
  """
44
56
 
45
- def __init__(self):
57
+ def __init__(self) -> None:
46
58
  """
47
59
  Initializes the conversation history with an empty list of turns.
48
60
  """
49
61
  self.history: List[Turn] = []
50
62
 
51
- def add_to_history(self, turn: Turn):
63
+ def add_to_history(self, turn: Turn) -> None:
52
64
  """
53
65
  Adds a turn to the conversation history.
54
66
 
@@ -57,7 +69,7 @@ class ConversationHistory:
57
69
  """
58
70
  self.history.append(turn)
59
71
 
60
- def to_list(self) -> List[Dict[str, str]]:
72
+ def to_list(self) -> List[Dict[str, Optional[str]]]:
61
73
  """
62
74
  Converts the conversation history to a list of dictionaries.
63
75
 
@@ -66,6 +78,15 @@ class ConversationHistory:
66
78
  """
67
79
  return [turn.to_dict() for turn in self.history]
68
80
 
81
+ def to_context_free_list(self) -> List[Dict[str, Optional[str]]]:
82
+ """
83
+ Converts the conversation history to a list of dictionaries without context.
84
+
85
+ :returns: A list of dictionaries representing the conversation turns without context.
86
+ :rtype: List[Dict[str, str]]
87
+ """
88
+ return [turn.to_context_free_dict() for turn in self.history]
89
+
69
90
  def __len__(self) -> int:
70
91
  return len(self.history)
71
92
 
@@ -1,54 +1,29 @@
1
1
  # ---------------------------------------------------------
2
2
  # Copyright (c) Microsoft Corporation. All rights reserved.
3
3
  # ---------------------------------------------------------
4
+ # pylint: disable=C0301,C0114,R0913,R0903
4
5
  # noqa: E501
5
- import functools
6
+ import asyncio
6
7
  import logging
7
- from typing import Callable
8
+ from typing import Callable, cast
8
9
 
9
- from promptflow._sdk._telemetry import ActivityType, monitor_operation
10
+ from tqdm import tqdm
10
11
 
12
+ from azure.ai.evaluation._common.utils import validate_azure_ai_project
13
+ from azure.ai.evaluation._common._experimental import experimental
11
14
  from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
12
- from azure.ai.evaluation._model_configurations import AzureAIProject
13
- from azure.ai.evaluation.simulator import AdversarialScenario
14
- from azure.identity import DefaultAzureCredential
15
+ from azure.ai.evaluation.simulator import AdversarialScenarioJailbreak, SupportedLanguages
16
+ from azure.core.credentials import TokenCredential
17
+
18
+ from ._adversarial_simulator import AdversarialSimulator, JsonLineList
15
19
 
16
- from ._adversarial_simulator import AdversarialSimulator
17
20
  from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope
18
21
 
19
22
  logger = logging.getLogger(__name__)
20
23
 
21
24
 
22
- def monitor_adversarial_scenario(func) -> Callable:
23
- """Decorator to monitor adversarial scenario.
24
-
25
- :param func: The function to be decorated.
26
- :type func: Callable
27
- :return: The decorated function.
28
- :rtype: Callable
29
- """
30
-
31
- @functools.wraps(func)
32
- def wrapper(*args, **kwargs):
33
- scenario = str(kwargs.get("scenario", None))
34
- max_conversation_turns = kwargs.get("max_conversation_turns", None)
35
- max_simulation_results = kwargs.get("max_simulation_results", None)
36
- decorated_func = monitor_operation(
37
- activity_name="xpia.adversarial.simulator.call",
38
- activity_type=ActivityType.PUBLICAPI,
39
- custom_dimensions={
40
- "scenario": scenario,
41
- "max_conversation_turns": max_conversation_turns,
42
- "max_simulation_results": max_simulation_results,
43
- },
44
- )(func)
45
-
46
- return decorated_func(*args, **kwargs)
47
-
48
- return wrapper
49
-
50
-
51
- class IndirectAttackSimulator:
25
+ @experimental
26
+ class IndirectAttackSimulator(AdversarialSimulator):
52
27
  """
53
28
  Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
54
29
 
@@ -59,42 +34,31 @@ class IndirectAttackSimulator:
59
34
  :type credential: ~azure.core.credentials.TokenCredential
60
35
  """
61
36
 
62
- def __init__(self, *, azure_ai_project: AzureAIProject, credential=None):
37
+ def __init__(self, *, azure_ai_project: dict, credential):
63
38
  """Constructor."""
64
- # check if azure_ai_project has the keys: subscription_id, resource_group_name, project_name, credential
65
- if not all(key in azure_ai_project for key in ["subscription_id", "resource_group_name", "project_name"]):
66
- msg = "azure_ai_project must contain keys: subscription_id, resource_group_name and project_name"
39
+
40
+ try:
41
+ self.azure_ai_project = validate_azure_ai_project(azure_ai_project)
42
+ except EvaluationException as e:
67
43
  raise EvaluationException(
68
- message=msg,
69
- internal_message=msg,
44
+ message=e.message,
45
+ internal_message=e.internal_message,
70
46
  target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
71
- category=ErrorCategory.MISSING_FIELD,
72
- blame=ErrorBlame.USER_ERROR,
73
- )
74
- if not all(azure_ai_project[key] for key in ["subscription_id", "resource_group_name", "project_name"]):
75
- msg = "subscription_id, resource_group_name and project_name keys cannot be None"
76
- raise EvaluationException(
77
- message=msg,
78
- internal_message=msg,
79
- target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
80
- category=ErrorCategory.MISSING_FIELD,
81
- blame=ErrorBlame.USER_ERROR,
82
- )
83
- if "credential" not in azure_ai_project and not credential:
84
- credential = DefaultAzureCredential()
85
- elif "credential" in azure_ai_project:
86
- credential = azure_ai_project["credential"]
87
- self.credential = credential
88
- self.azure_ai_project = azure_ai_project
47
+ category=e.category,
48
+ blame=e.blame,
49
+ ) from e
50
+
51
+ self.credential = cast(TokenCredential, credential)
89
52
  self.token_manager = ManagedIdentityAPITokenManager(
90
53
  token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
91
54
  logger=logging.getLogger("AdversarialSimulator"),
92
- credential=credential,
55
+ credential=self.credential,
93
56
  )
94
- self.rai_client = RAIClient(azure_ai_project=azure_ai_project, token_manager=self.token_manager)
57
+ self.rai_client = RAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager)
95
58
  self.adversarial_template_handler = AdversarialTemplateHandler(
96
- azure_ai_project=azure_ai_project, rai_client=self.rai_client
59
+ azure_ai_project=self.azure_ai_project, rai_client=self.rai_client
97
60
  )
61
+ super().__init__(azure_ai_project=azure_ai_project, credential=credential)
98
62
 
99
63
  def _ensure_service_dependencies(self):
100
64
  if self.rai_client is None:
@@ -107,33 +71,25 @@ class IndirectAttackSimulator:
107
71
  blame=ErrorBlame.USER_ERROR,
108
72
  )
109
73
 
110
- # @monitor_adversarial_scenario
111
74
  async def __call__(
112
75
  self,
113
76
  *,
114
- scenario: AdversarialScenario,
115
77
  target: Callable,
116
- max_conversation_turns: int = 1,
117
78
  max_simulation_results: int = 3,
118
79
  api_call_retry_limit: int = 3,
119
80
  api_call_retry_sleep_sec: int = 1,
120
81
  api_call_delay_sec: int = 0,
121
82
  concurrent_async_task: int = 3,
83
+ **kwargs,
122
84
  ):
123
85
  """
124
86
  Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
125
87
  This simulator converses with your AI system using prompts injected into the context to interrupt normal
126
88
  expected functionality by eliciting manipulated content, intrusion and attempting to gather information outside
127
89
  the scope of your AI system.
128
-
129
- :keyword scenario: Enum value specifying the adversarial scenario used for generating inputs.
130
- :paramtype scenario: azure.ai.evaluation.simulator.AdversarialScenario
131
90
  :keyword target: The target function to simulate adversarial inputs against.
132
91
  This function should be asynchronous and accept a dictionary representing the adversarial input.
133
92
  :paramtype target: Callable
134
- :keyword max_conversation_turns: The maximum number of conversation turns to simulate.
135
- Defaults to 1.
136
- :paramtype max_conversation_turns: int
137
93
  :keyword max_simulation_results: The maximum number of simulation results to return.
138
94
  Defaults to 3.
139
95
  :paramtype max_simulation_results: int
@@ -170,11 +126,11 @@ class IndirectAttackSimulator:
170
126
  'template_parameters': {},
171
127
  'messages': [
172
128
  {
173
- 'content': '<jailbreak prompt> <adversarial query>',
129
+ 'content': '<adversarial query>',
174
130
  'role': 'user'
175
131
  },
176
132
  {
177
- 'content': "<response from endpoint>",
133
+ 'content': "<response from your callback>",
178
134
  'role': 'assistant',
179
135
  'context': None
180
136
  }
@@ -183,25 +139,72 @@ class IndirectAttackSimulator:
183
139
  }]
184
140
  }
185
141
  """
186
- if scenario not in AdversarialScenario.__members__.values():
187
- msg = f"Invalid scenario: {scenario}. Supported scenarios: {AdversarialScenario.__members__.values()}"
188
- raise EvaluationException(
189
- message=msg,
190
- internal_message=msg,
191
- target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
192
- category=ErrorCategory.INVALID_VALUE,
193
- blame=ErrorBlame.USER_ERROR,
142
+ # values that cannot be changed:
143
+ scenario = AdversarialScenarioJailbreak.ADVERSARIAL_INDIRECT_JAILBREAK
144
+ max_conversation_turns = 2
145
+ language = SupportedLanguages.English
146
+ self._ensure_service_dependencies()
147
+ templates = await self.adversarial_template_handler._get_content_harm_template_collections(scenario.value)
148
+ concurrent_async_task = min(concurrent_async_task, 1000)
149
+ semaphore = asyncio.Semaphore(concurrent_async_task)
150
+ sim_results = []
151
+ tasks = []
152
+ total_tasks = sum(len(t.template_parameters) for t in templates)
153
+ if max_simulation_results > total_tasks:
154
+ logger.warning(
155
+ "Cannot provide %s results due to maximum number of adversarial simulations that can be generated: %s."
156
+ "\n %s simulations will be generated.",
157
+ max_simulation_results,
158
+ total_tasks,
159
+ total_tasks,
194
160
  )
195
- jb_sim = AdversarialSimulator(azure_ai_project=self.azure_ai_project, credential=self.credential)
196
- jb_sim_results = await jb_sim(
197
- scenario=scenario,
198
- target=target,
199
- max_conversation_turns=max_conversation_turns,
200
- max_simulation_results=max_simulation_results,
201
- api_call_retry_limit=api_call_retry_limit,
202
- api_call_retry_sleep_sec=api_call_retry_sleep_sec,
203
- api_call_delay_sec=api_call_delay_sec,
204
- concurrent_async_task=concurrent_async_task,
205
- _jailbreak_type="xpia",
161
+ total_tasks = min(total_tasks, max_simulation_results)
162
+ progress_bar = tqdm(
163
+ total=total_tasks,
164
+ desc="generating jailbreak simulations",
165
+ ncols=100,
166
+ unit="simulations",
206
167
  )
207
- return jb_sim_results
168
+ for template in templates:
169
+ for parameter in template.template_parameters:
170
+ tasks.append(
171
+ asyncio.create_task(
172
+ self._simulate_async(
173
+ target=target,
174
+ template=template,
175
+ parameters=parameter,
176
+ max_conversation_turns=max_conversation_turns,
177
+ api_call_retry_limit=api_call_retry_limit,
178
+ api_call_retry_sleep_sec=api_call_retry_sleep_sec,
179
+ api_call_delay_sec=api_call_delay_sec,
180
+ language=language,
181
+ semaphore=semaphore,
182
+ )
183
+ )
184
+ )
185
+ if len(tasks) >= max_simulation_results:
186
+ break
187
+ if len(tasks) >= max_simulation_results:
188
+ break
189
+ for task in asyncio.as_completed(tasks):
190
+ completed_task = await task # type: ignore
191
+ template_parameters = completed_task.get("template_parameters", {}) # type: ignore
192
+ xpia_attack_type = template_parameters.get("xpia_attack_type", "") # type: ignore
193
+ action = template_parameters.get("action", "") # type: ignore
194
+ document_type = template_parameters.get("document_type", "") # type: ignore
195
+ sim_results.append(
196
+ {
197
+ "messages": completed_task["messages"], # type: ignore
198
+ "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
199
+ "template_parameters": {
200
+ "metadata": {
201
+ "xpia_attack_type": xpia_attack_type,
202
+ "action": action,
203
+ "document_type": document_type,
204
+ },
205
+ },
206
+ }
207
+ )
208
+ progress_bar.update(1)
209
+ progress_bar.close()
210
+ return JsonLineList(sim_results)
@@ -3,13 +3,15 @@
3
3
  # ---------------------------------------------------------
4
4
 
5
5
  import asyncio
6
+ import inspect
6
7
  import logging
7
8
  import os
8
9
  import time
9
10
  from abc import ABC, abstractmethod
10
11
  from enum import Enum
11
- from typing import Dict, Optional, Union
12
+ from typing import Optional, Union
12
13
 
14
+ from azure.core.credentials import AccessToken, TokenCredential
13
15
  from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
14
16
 
15
17
  AZURE_TOKEN_REFRESH_INTERVAL = 600 # seconds
@@ -29,24 +31,24 @@ class APITokenManager(ABC):
29
31
  :param auth_header: Authorization header prefix. Defaults to "Bearer"
30
32
  :type auth_header: str
31
33
  :param credential: Azure credential object
32
- :type credential: Optional[Union[azure.identity.DefaultAzureCredential, azure.identity.ManagedIdentityCredential]
34
+ :type credential: Optional[TokenCredential]
33
35
  """
34
36
 
35
37
  def __init__(
36
38
  self,
37
39
  logger: logging.Logger,
38
40
  auth_header: str = "Bearer",
39
- credential: Optional[Union[DefaultAzureCredential, ManagedIdentityCredential]] = None,
41
+ credential: Optional[TokenCredential] = None,
40
42
  ) -> None:
41
43
  self.logger = logger
42
44
  self.auth_header = auth_header
43
- self._lock = None
45
+ self._lock: Optional[asyncio.Lock] = None
44
46
  if credential is not None:
45
47
  self.credential = credential
46
48
  else:
47
49
  self.credential = self.get_aad_credential()
48
- self.token = None
49
- self.last_refresh_time = None
50
+ self.token: Optional[str] = None
51
+ self.last_refresh_time: Optional[float] = None
50
52
 
51
53
  @property
52
54
  def lock(self) -> asyncio.Lock:
@@ -73,20 +75,26 @@ class APITokenManager(ABC):
73
75
  identity_client_id = os.environ.get("DEFAULT_IDENTITY_CLIENT_ID", None)
74
76
  if identity_client_id is not None:
75
77
  self.logger.info(f"Using DEFAULT_IDENTITY_CLIENT_ID: {identity_client_id}")
76
- credential = ManagedIdentityCredential(client_id=identity_client_id)
77
- else:
78
- self.logger.info("Environment variable DEFAULT_IDENTITY_CLIENT_ID is not set, using DefaultAzureCredential")
79
- credential = DefaultAzureCredential()
80
- return credential
78
+ return ManagedIdentityCredential(client_id=identity_client_id)
79
+
80
+ self.logger.info("Environment variable DEFAULT_IDENTITY_CLIENT_ID is not set, using DefaultAzureCredential")
81
+ return DefaultAzureCredential()
82
+
83
+ @abstractmethod
84
+ def get_token(self) -> str:
85
+ """Async method to get the API token. Subclasses should implement this method.
86
+
87
+ :return: API token
88
+ :rtype: str
89
+ """
81
90
 
82
91
  @abstractmethod
83
- async def get_token(self) -> str:
92
+ async def get_token_async(self) -> str:
84
93
  """Async method to get the API token. Subclasses should implement this method.
85
94
 
86
95
  :return: API token
87
96
  :rtype: str
88
97
  """
89
- pass # pylint: disable=unnecessary-pass
90
98
 
91
99
 
92
100
  class ManagedIdentityAPITokenManager(APITokenManager):
@@ -100,12 +108,18 @@ class ManagedIdentityAPITokenManager(APITokenManager):
100
108
  :paramtype kwargs: Dict
101
109
  """
102
110
 
103
- def __init__(self, token_scope: TokenScope, logger: logging.Logger, **kwargs: Dict):
104
- super().__init__(logger, **kwargs)
111
+ def __init__(
112
+ self,
113
+ token_scope: TokenScope,
114
+ logger: logging.Logger,
115
+ *,
116
+ auth_header: str = "Bearer",
117
+ credential: Optional[TokenCredential] = None,
118
+ ):
119
+ super().__init__(logger, auth_header=auth_header, credential=credential)
105
120
  self.token_scope = token_scope
106
121
 
107
- # Bug 3353724: This get_token is sync method, but it is defined as async method in the base class
108
- def get_token(self) -> str: # pylint: disable=invalid-overridden-method
122
+ def get_token(self) -> str:
109
123
  """Get the API token. If the token is not available or has expired, refresh the token.
110
124
 
111
125
  :return: API token
@@ -122,6 +136,31 @@ class ManagedIdentityAPITokenManager(APITokenManager):
122
136
 
123
137
  return self.token
124
138
 
139
+ async def get_token_async(self) -> str:
140
+ """Get the API token synchronously. If the token is not available or has expired, refresh it.
141
+
142
+ :return: API token
143
+ :rtype: str
144
+ """
145
+ if (
146
+ self.token is None
147
+ or self.last_refresh_time is None
148
+ or time.time() - self.last_refresh_time > AZURE_TOKEN_REFRESH_INTERVAL
149
+ ):
150
+ self.last_refresh_time = time.time()
151
+ get_token_method = self.credential.get_token(self.token_scope.value)
152
+ if inspect.isawaitable(get_token_method):
153
+ # If it's awaitable, await it
154
+ token_response: AccessToken = await get_token_method
155
+ else:
156
+ # Otherwise, call it synchronously
157
+ token_response = get_token_method
158
+
159
+ self.token = token_response.token
160
+ self.logger.info("Refreshed Azure endpoint token.")
161
+
162
+ return self.token
163
+
125
164
 
126
165
  class PlainTokenManager(APITokenManager):
127
166
  """Plain API Token Manager
@@ -134,11 +173,18 @@ class PlainTokenManager(APITokenManager):
134
173
  :paramtype kwargs: Dict
135
174
  """
136
175
 
137
- def __init__(self, openapi_key: str, logger: logging.Logger, **kwargs: Dict):
138
- super().__init__(logger, **kwargs)
139
- self.token = openapi_key
176
+ def __init__(
177
+ self,
178
+ openapi_key: str,
179
+ logger: logging.Logger,
180
+ *,
181
+ auth_header: str = "Bearer",
182
+ credential: Optional[TokenCredential] = None,
183
+ ) -> None:
184
+ super().__init__(logger, auth_header=auth_header, credential=credential)
185
+ self.token: str = openapi_key
140
186
 
141
- async def get_token(self) -> str:
187
+ def get_token(self) -> str:
142
188
  """Get the API token
143
189
 
144
190
  :return: API token