azure-ai-evaluation 1.0.0__py3-none-any.whl → 1.0.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of azure-ai-evaluation might be problematic. Click here for more details.

Files changed (105)
  1. azure/ai/evaluation/__init__.py +5 -31
  2. azure/ai/evaluation/_common/constants.py +2 -9
  3. azure/ai/evaluation/_common/rai_service.py +120 -300
  4. azure/ai/evaluation/_common/utils.py +23 -381
  5. azure/ai/evaluation/_constants.py +6 -19
  6. azure/ai/evaluation/_evaluate/{_batch_run → _batch_run_client}/__init__.py +2 -3
  7. azure/ai/evaluation/_evaluate/{_batch_run/eval_run_context.py → _batch_run_client/batch_run_context.py} +7 -23
  8. azure/ai/evaluation/_evaluate/{_batch_run → _batch_run_client}/code_client.py +17 -33
  9. azure/ai/evaluation/_evaluate/{_batch_run → _batch_run_client}/proxy_client.py +4 -32
  10. azure/ai/evaluation/_evaluate/_eval_run.py +24 -81
  11. azure/ai/evaluation/_evaluate/_evaluate.py +239 -393
  12. azure/ai/evaluation/_evaluate/_telemetry/__init__.py +17 -17
  13. azure/ai/evaluation/_evaluate/_utils.py +28 -82
  14. azure/ai/evaluation/_evaluators/_bleu/_bleu.py +18 -17
  15. azure/ai/evaluation/_evaluators/{_retrieval → _chat}/__init__.py +2 -2
  16. azure/ai/evaluation/_evaluators/_chat/_chat.py +357 -0
  17. azure/ai/evaluation/_evaluators/{_service_groundedness → _chat/retrieval}/__init__.py +2 -2
  18. azure/ai/evaluation/_evaluators/_chat/retrieval/_retrieval.py +157 -0
  19. azure/ai/evaluation/_evaluators/_chat/retrieval/retrieval.prompty +48 -0
  20. azure/ai/evaluation/_evaluators/_coherence/_coherence.py +88 -78
  21. azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +39 -76
  22. azure/ai/evaluation/_evaluators/_content_safety/__init__.py +4 -0
  23. azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +67 -105
  24. azure/ai/evaluation/_evaluators/{_multimodal/_content_safety_multimodal_base.py → _content_safety/_content_safety_base.py} +34 -24
  25. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +301 -0
  26. azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +54 -105
  27. azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +52 -99
  28. azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +52 -101
  29. azure/ai/evaluation/_evaluators/_content_safety/_violence.py +51 -101
  30. azure/ai/evaluation/_evaluators/_eci/_eci.py +54 -44
  31. azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +19 -34
  32. azure/ai/evaluation/_evaluators/_fluency/_fluency.py +89 -76
  33. azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +41 -66
  34. azure/ai/evaluation/_evaluators/_gleu/_gleu.py +16 -14
  35. azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +87 -113
  36. azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +54 -0
  37. azure/ai/evaluation/_evaluators/_meteor/_meteor.py +27 -20
  38. azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +80 -89
  39. azure/ai/evaluation/_evaluators/_protected_materials/__init__.py +5 -0
  40. azure/ai/evaluation/_evaluators/_protected_materials/_protected_materials.py +104 -0
  41. azure/ai/evaluation/_evaluators/_qa/_qa.py +30 -23
  42. azure/ai/evaluation/_evaluators/_relevance/_relevance.py +96 -84
  43. azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +47 -78
  44. azure/ai/evaluation/_evaluators/_rouge/_rouge.py +27 -26
  45. azure/ai/evaluation/_evaluators/_similarity/_similarity.py +38 -53
  46. azure/ai/evaluation/_evaluators/_similarity/similarity.prompty +5 -0
  47. azure/ai/evaluation/_evaluators/_xpia/xpia.py +105 -91
  48. azure/ai/evaluation/_exceptions.py +7 -28
  49. azure/ai/evaluation/_http_utils.py +132 -203
  50. azure/ai/evaluation/_model_configurations.py +8 -104
  51. azure/ai/evaluation/_version.py +1 -1
  52. azure/ai/evaluation/simulator/__init__.py +1 -2
  53. azure/ai/evaluation/simulator/_adversarial_scenario.py +1 -20
  54. azure/ai/evaluation/simulator/_adversarial_simulator.py +92 -111
  55. azure/ai/evaluation/simulator/_constants.py +1 -11
  56. azure/ai/evaluation/simulator/_conversation/__init__.py +12 -13
  57. azure/ai/evaluation/simulator/_conversation/_conversation.py +4 -4
  58. azure/ai/evaluation/simulator/_direct_attack_simulator.py +67 -33
  59. azure/ai/evaluation/simulator/_helpers/__init__.py +2 -1
  60. azure/ai/evaluation/{_common → simulator/_helpers}/_experimental.py +9 -24
  61. azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +5 -26
  62. azure/ai/evaluation/simulator/_indirect_attack_simulator.py +94 -107
  63. azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +22 -70
  64. azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +11 -28
  65. azure/ai/evaluation/simulator/_model_tools/_rai_client.py +4 -8
  66. azure/ai/evaluation/simulator/_model_tools/_template_handler.py +24 -68
  67. azure/ai/evaluation/simulator/_model_tools/models.py +10 -10
  68. azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +10 -6
  69. azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +5 -6
  70. azure/ai/evaluation/simulator/_simulator.py +207 -277
  71. azure/ai/evaluation/simulator/_tracing.py +4 -4
  72. azure/ai/evaluation/simulator/_utils.py +13 -31
  73. azure_ai_evaluation-1.0.0b2.dist-info/METADATA +449 -0
  74. azure_ai_evaluation-1.0.0b2.dist-info/RECORD +99 -0
  75. {azure_ai_evaluation-1.0.0.dist-info → azure_ai_evaluation-1.0.0b2.dist-info}/WHEEL +1 -1
  76. azure/ai/evaluation/_common/math.py +0 -89
  77. azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +0 -46
  78. azure/ai/evaluation/_evaluators/_common/__init__.py +0 -13
  79. azure/ai/evaluation/_evaluators/_common/_base_eval.py +0 -344
  80. azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +0 -88
  81. azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +0 -133
  82. azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +0 -113
  83. azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +0 -99
  84. azure/ai/evaluation/_evaluators/_multimodal/__init__.py +0 -20
  85. azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +0 -132
  86. azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +0 -100
  87. azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +0 -124
  88. azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +0 -100
  89. azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +0 -100
  90. azure/ai/evaluation/_evaluators/_multimodal/_violence.py +0 -100
  91. azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +0 -112
  92. azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +0 -93
  93. azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +0 -148
  94. azure/ai/evaluation/_vendor/__init__.py +0 -3
  95. azure/ai/evaluation/_vendor/rouge_score/__init__.py +0 -14
  96. azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +0 -328
  97. azure/ai/evaluation/_vendor/rouge_score/scoring.py +0 -63
  98. azure/ai/evaluation/_vendor/rouge_score/tokenize.py +0 -63
  99. azure/ai/evaluation/_vendor/rouge_score/tokenizers.py +0 -53
  100. azure/ai/evaluation/simulator/_data_sources/__init__.py +0 -3
  101. azure/ai/evaluation/simulator/_data_sources/grounding.json +0 -1150
  102. azure_ai_evaluation-1.0.0.dist-info/METADATA +0 -595
  103. azure_ai_evaluation-1.0.0.dist-info/NOTICE.txt +0 -70
  104. azure_ai_evaluation-1.0.0.dist-info/RECORD +0 -119
  105. {azure_ai_evaluation-1.0.0.dist-info → azure_ai_evaluation-1.0.0b2.dist-info}/top_level.txt +0 -0
@@ -2,14 +2,13 @@
2
2
  # Copyright (c) Microsoft Corporation. All rights reserved.
3
3
  # ---------------------------------------------------------
4
4
 
5
- import os
6
5
  import functools
7
6
  import inspect
8
7
  import logging
9
8
  import sys
10
- from typing import Callable, Type, TypeVar, Union, overload
9
+ from typing import Callable, Type, TypeVar, Union
11
10
 
12
- from typing_extensions import ParamSpec, TypeGuard
11
+ from typing_extensions import ParamSpec
13
12
 
14
13
  DOCSTRING_TEMPLATE = ".. note:: {0} {1}\n\n"
15
14
  DOCSTRING_DEFAULT_INDENTATION = 8
@@ -23,31 +22,20 @@ EXPERIMENTAL_LINK_MESSAGE = (
23
22
  _warning_cache = set()
24
23
  module_logger = logging.getLogger(__name__)
25
24
 
25
+ TExperimental = TypeVar("TExperimental", bound=Union[Type, Callable])
26
26
  P = ParamSpec("P")
27
27
  T = TypeVar("T")
28
28
 
29
29
 
30
- @overload
31
- def experimental(wrapped: Type[T]) -> Type[T]: ...
32
-
33
-
34
- @overload
35
- def experimental(wrapped: Callable[P, T]) -> Callable[P, T]: ...
36
-
37
-
38
- def experimental(wrapped: Union[Type[T], Callable[P, T]]) -> Union[Type[T], Callable[P, T]]:
30
+ def experimental(wrapped: TExperimental) -> TExperimental:
39
31
  """Add experimental tag to a class or a method.
40
32
 
41
33
  :param wrapped: Either a Class or Function to mark as experimental
42
- :type wrapped: Union[Type[T], Callable[P, T]]
34
+ :type wrapped: TExperimental
43
35
  :return: The wrapped class or method
44
- :rtype: Union[Type[T], Callable[P, T]]
36
+ :rtype: TExperimental
45
37
  """
46
-
47
- def is_class(t: Union[Type[T], Callable[P, T]]) -> TypeGuard[Type[T]]:
48
- return isinstance(t, type)
49
-
50
- if is_class(wrapped):
38
+ if inspect.isclass(wrapped):
51
39
  return _add_class_docstring(wrapped)
52
40
  if inspect.isfunction(wrapped):
53
41
  return _add_method_docstring(wrapped)
@@ -86,11 +74,11 @@ def _add_class_docstring(cls: Type[T]) -> Type[T]:
86
74
  cls.__doc__ = _add_note_to_docstring(cls.__doc__, doc_string)
87
75
  else:
88
76
  cls.__doc__ = doc_string + ">"
89
- cls.__init__ = _add_class_warning(cls.__init__) # type: ignore[method-assign]
77
+ cls.__init__ = _add_class_warning(cls.__init__)
90
78
  return cls
91
79
 
92
80
 
93
- def _add_method_docstring(func: Callable[P, T]) -> Callable[P, T]:
81
+ def _add_method_docstring(func: Callable[P, T] = None) -> Callable[P, T]:
94
82
  """Add experimental tag to the method doc string.
95
83
 
96
84
  :param func: The function to update
@@ -150,9 +138,6 @@ def _get_indentation_size(doc_string: str) -> int:
150
138
  def _should_skip_warning():
151
139
  skip_warning_msg = False
152
140
 
153
- if os.getenv("AI_EVALS_DISABLE_EXPERIMENTAL_WARNING", "false").lower() == "true":
154
- skip_warning_msg = True
155
-
156
141
  # Cases where we want to suppress the warning:
157
142
  # 1. When converting from REST object to SDK object
158
143
  for frame in inspect.stack():
@@ -18,7 +18,7 @@ class Turn:
18
18
 
19
19
  role: Union[str, ConversationRole]
20
20
  content: str
21
- context: Optional[str] = None
21
+ context: str = None
22
22
 
23
23
  def to_dict(self) -> Dict[str, Optional[str]]:
24
24
  """
@@ -30,19 +30,7 @@ class Turn:
30
30
  return {
31
31
  "role": self.role.value if isinstance(self.role, ConversationRole) else self.role,
32
32
  "content": self.content,
33
- "context": str(self.context),
34
- }
35
-
36
- def to_context_free_dict(self) -> Dict[str, Optional[str]]:
37
- """
38
- Convert the conversation turn to a dictionary without context.
39
-
40
- :returns: A dictionary representation of the conversation turn without context.
41
- :rtype: Dict[str, Optional[str]]
42
- """
43
- return {
44
- "role": self.role.value if isinstance(self.role, ConversationRole) else self.role,
45
- "content": self.content,
33
+ "context": self.context,
46
34
  }
47
35
 
48
36
  def __repr__(self):
@@ -54,13 +42,13 @@ class ConversationHistory:
54
42
  Conversation history class to keep track of the conversation turns in a conversation.
55
43
  """
56
44
 
57
- def __init__(self) -> None:
45
+ def __init__(self):
58
46
  """
59
47
  Initializes the conversation history with an empty list of turns.
60
48
  """
61
49
  self.history: List[Turn] = []
62
50
 
63
- def add_to_history(self, turn: Turn) -> None:
51
+ def add_to_history(self, turn: Turn):
64
52
  """
65
53
  Adds a turn to the conversation history.
66
54
 
@@ -69,7 +57,7 @@ class ConversationHistory:
69
57
  """
70
58
  self.history.append(turn)
71
59
 
72
- def to_list(self) -> List[Dict[str, Optional[str]]]:
60
+ def to_list(self) -> List[Dict[str, str]]:
73
61
  """
74
62
  Converts the conversation history to a list of dictionaries.
75
63
 
@@ -78,15 +66,6 @@ class ConversationHistory:
78
66
  """
79
67
  return [turn.to_dict() for turn in self.history]
80
68
 
81
- def to_context_free_list(self) -> List[Dict[str, Optional[str]]]:
82
- """
83
- Converts the conversation history to a list of dictionaries without context.
84
-
85
- :returns: A list of dictionaries representing the conversation turns without context.
86
- :rtype: List[Dict[str, str]]
87
- """
88
- return [turn.to_context_free_dict() for turn in self.history]
89
-
90
69
  def __len__(self) -> int:
91
70
  return len(self.history)
92
71
 
@@ -1,30 +1,54 @@
1
1
  # ---------------------------------------------------------
2
2
  # Copyright (c) Microsoft Corporation. All rights reserved.
3
3
  # ---------------------------------------------------------
4
- # pylint: disable=C0301,C0114,R0913,R0903
5
4
  # noqa: E501
6
- import asyncio
5
+ import functools
7
6
  import logging
8
- from typing import Callable, cast
7
+ from typing import Callable
9
8
 
10
- from tqdm import tqdm
9
+ from promptflow._sdk._telemetry import ActivityType, monitor_operation
11
10
 
12
- from azure.ai.evaluation._common.utils import validate_azure_ai_project
13
- from azure.ai.evaluation._common._experimental import experimental
14
11
  from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
15
- from azure.ai.evaluation.simulator import AdversarialScenarioJailbreak, SupportedLanguages
16
12
  from azure.ai.evaluation._model_configurations import AzureAIProject
17
- from azure.core.credentials import TokenCredential
18
-
19
- from ._adversarial_simulator import AdversarialSimulator, JsonLineList
13
+ from azure.ai.evaluation.simulator import AdversarialScenario
14
+ from azure.identity import DefaultAzureCredential
20
15
 
16
+ from ._adversarial_simulator import AdversarialSimulator
21
17
  from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope
22
18
 
23
19
  logger = logging.getLogger(__name__)
24
20
 
25
21
 
26
- @experimental
27
- class IndirectAttackSimulator(AdversarialSimulator):
22
+ def monitor_adversarial_scenario(func) -> Callable:
23
+ """Decorator to monitor adversarial scenario.
24
+
25
+ :param func: The function to be decorated.
26
+ :type func: Callable
27
+ :return: The decorated function.
28
+ :rtype: Callable
29
+ """
30
+
31
+ @functools.wraps(func)
32
+ def wrapper(*args, **kwargs):
33
+ scenario = str(kwargs.get("scenario", None))
34
+ max_conversation_turns = kwargs.get("max_conversation_turns", None)
35
+ max_simulation_results = kwargs.get("max_simulation_results", None)
36
+ decorated_func = monitor_operation(
37
+ activity_name="xpia.adversarial.simulator.call",
38
+ activity_type=ActivityType.PUBLICAPI,
39
+ custom_dimensions={
40
+ "scenario": scenario,
41
+ "max_conversation_turns": max_conversation_turns,
42
+ "max_simulation_results": max_simulation_results,
43
+ },
44
+ )(func)
45
+
46
+ return decorated_func(*args, **kwargs)
47
+
48
+ return wrapper
49
+
50
+
51
+ class IndirectAttackSimulator:
28
52
  """
29
53
  Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
30
54
 
@@ -33,42 +57,44 @@ class IndirectAttackSimulator(AdversarialSimulator):
33
57
  :type azure_ai_project: ~azure.ai.evaluation.AzureAIProject
34
58
  :param credential: The credential for connecting to Azure AI project.
35
59
  :type credential: ~azure.core.credentials.TokenCredential
36
-
37
- .. admonition:: Example:
38
-
39
- .. literalinclude:: ../samples/evaluation_samples_simulate.py
40
- :start-after: [START indirect_attack_simulator]
41
- :end-before: [END indirect_attack_simulator]
42
- :language: python
43
- :dedent: 8
44
- :caption: Run the IndirectAttackSimulator to produce 1 result with 1 conversation turn (2 messages in the result).
45
60
  """
46
61
 
47
- def __init__(self, *, azure_ai_project: AzureAIProject, credential: TokenCredential):
62
+ def __init__(self, *, azure_ai_project: AzureAIProject, credential=None):
48
63
  """Constructor."""
49
-
50
- try:
51
- self.azure_ai_project = validate_azure_ai_project(azure_ai_project)
52
- except EvaluationException as e:
64
+ # check if azure_ai_project has the keys: subscription_id, resource_group_name, project_name, credential
65
+ if not all(key in azure_ai_project for key in ["subscription_id", "resource_group_name", "project_name"]):
66
+ msg = "azure_ai_project must contain keys: subscription_id, resource_group_name and project_name"
53
67
  raise EvaluationException(
54
- message=e.message,
55
- internal_message=e.internal_message,
68
+ message=msg,
69
+ internal_message=msg,
56
70
  target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
57
- category=e.category,
58
- blame=e.blame,
59
- ) from e
60
-
61
- self.credential = cast(TokenCredential, credential)
71
+ category=ErrorCategory.MISSING_FIELD,
72
+ blame=ErrorBlame.USER_ERROR,
73
+ )
74
+ if not all(azure_ai_project[key] for key in ["subscription_id", "resource_group_name", "project_name"]):
75
+ msg = "subscription_id, resource_group_name and project_name keys cannot be None"
76
+ raise EvaluationException(
77
+ message=msg,
78
+ internal_message=msg,
79
+ target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
80
+ category=ErrorCategory.MISSING_FIELD,
81
+ blame=ErrorBlame.USER_ERROR,
82
+ )
83
+ if "credential" not in azure_ai_project and not credential:
84
+ credential = DefaultAzureCredential()
85
+ elif "credential" in azure_ai_project:
86
+ credential = azure_ai_project["credential"]
87
+ self.credential = credential
88
+ self.azure_ai_project = azure_ai_project
62
89
  self.token_manager = ManagedIdentityAPITokenManager(
63
90
  token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
64
91
  logger=logging.getLogger("AdversarialSimulator"),
65
- credential=self.credential,
92
+ credential=credential,
66
93
  )
67
- self.rai_client = RAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager)
94
+ self.rai_client = RAIClient(azure_ai_project=azure_ai_project, token_manager=self.token_manager)
68
95
  self.adversarial_template_handler = AdversarialTemplateHandler(
69
- azure_ai_project=self.azure_ai_project, rai_client=self.rai_client
96
+ azure_ai_project=azure_ai_project, rai_client=self.rai_client
70
97
  )
71
- super().__init__(azure_ai_project=azure_ai_project, credential=credential)
72
98
 
73
99
  def _ensure_service_dependencies(self):
74
100
  if self.rai_client is None:
@@ -81,25 +107,33 @@ class IndirectAttackSimulator(AdversarialSimulator):
81
107
  blame=ErrorBlame.USER_ERROR,
82
108
  )
83
109
 
110
+ # @monitor_adversarial_scenario
84
111
  async def __call__(
85
112
  self,
86
113
  *,
114
+ scenario: AdversarialScenario,
87
115
  target: Callable,
116
+ max_conversation_turns: int = 1,
88
117
  max_simulation_results: int = 3,
89
118
  api_call_retry_limit: int = 3,
90
119
  api_call_retry_sleep_sec: int = 1,
91
120
  api_call_delay_sec: int = 0,
92
121
  concurrent_async_task: int = 3,
93
- **kwargs,
94
122
  ):
95
123
  """
96
124
  Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
97
125
  This simulator converses with your AI system using prompts injected into the context to interrupt normal
98
126
  expected functionality by eliciting manipulated content, intrusion and attempting to gather information outside
99
127
  the scope of your AI system.
128
+
129
+ :keyword scenario: Enum value specifying the adversarial scenario used for generating inputs.
130
+ :paramtype scenario: azure.ai.evaluation.simulator.AdversarialScenario
100
131
  :keyword target: The target function to simulate adversarial inputs against.
101
132
  This function should be asynchronous and accept a dictionary representing the adversarial input.
102
133
  :paramtype target: Callable
134
+ :keyword max_conversation_turns: The maximum number of conversation turns to simulate.
135
+ Defaults to 1.
136
+ :paramtype max_conversation_turns: int
103
137
  :keyword max_simulation_results: The maximum number of simulation results to return.
104
138
  Defaults to 3.
105
139
  :paramtype max_simulation_results: int
@@ -136,11 +170,11 @@ class IndirectAttackSimulator(AdversarialSimulator):
136
170
  'template_parameters': {},
137
171
  'messages': [
138
172
  {
139
- 'content': '<adversarial query>',
173
+ 'content': '<jailbreak prompt> <adversarial query>',
140
174
  'role': 'user'
141
175
  },
142
176
  {
143
- 'content': "<response from your callback>",
177
+ 'content': "<response from endpoint>",
144
178
  'role': 'assistant',
145
179
  'context': None
146
180
  }
@@ -149,72 +183,25 @@ class IndirectAttackSimulator(AdversarialSimulator):
149
183
  }]
150
184
  }
151
185
  """
152
- # values that cannot be changed:
153
- scenario = AdversarialScenarioJailbreak.ADVERSARIAL_INDIRECT_JAILBREAK
154
- max_conversation_turns = 2
155
- language = SupportedLanguages.English
156
- self._ensure_service_dependencies()
157
- templates = await self.adversarial_template_handler._get_content_harm_template_collections(scenario.value)
158
- concurrent_async_task = min(concurrent_async_task, 1000)
159
- semaphore = asyncio.Semaphore(concurrent_async_task)
160
- sim_results = []
161
- tasks = []
162
- total_tasks = sum(len(t.template_parameters) for t in templates)
163
- if max_simulation_results > total_tasks:
164
- logger.warning(
165
- "Cannot provide %s results due to maximum number of adversarial simulations that can be generated: %s."
166
- "\n %s simulations will be generated.",
167
- max_simulation_results,
168
- total_tasks,
169
- total_tasks,
186
+ if scenario not in AdversarialScenario.__members__.values():
187
+ msg = f"Invalid scenario: {scenario}. Supported scenarios: {AdversarialScenario.__members__.values()}"
188
+ raise EvaluationException(
189
+ message=msg,
190
+ internal_message=msg,
191
+ target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
192
+ category=ErrorCategory.INVALID_VALUE,
193
+ blame=ErrorBlame.USER_ERROR,
170
194
  )
171
- total_tasks = min(total_tasks, max_simulation_results)
172
- progress_bar = tqdm(
173
- total=total_tasks,
174
- desc="generating jailbreak simulations",
175
- ncols=100,
176
- unit="simulations",
195
+ jb_sim = AdversarialSimulator(azure_ai_project=self.azure_ai_project, credential=self.credential)
196
+ jb_sim_results = await jb_sim(
197
+ scenario=scenario,
198
+ target=target,
199
+ max_conversation_turns=max_conversation_turns,
200
+ max_simulation_results=max_simulation_results,
201
+ api_call_retry_limit=api_call_retry_limit,
202
+ api_call_retry_sleep_sec=api_call_retry_sleep_sec,
203
+ api_call_delay_sec=api_call_delay_sec,
204
+ concurrent_async_task=concurrent_async_task,
205
+ _jailbreak_type="xpia",
177
206
  )
178
- for template in templates:
179
- for parameter in template.template_parameters:
180
- tasks.append(
181
- asyncio.create_task(
182
- self._simulate_async(
183
- target=target,
184
- template=template,
185
- parameters=parameter,
186
- max_conversation_turns=max_conversation_turns,
187
- api_call_retry_limit=api_call_retry_limit,
188
- api_call_retry_sleep_sec=api_call_retry_sleep_sec,
189
- api_call_delay_sec=api_call_delay_sec,
190
- language=language,
191
- semaphore=semaphore,
192
- )
193
- )
194
- )
195
- if len(tasks) >= max_simulation_results:
196
- break
197
- if len(tasks) >= max_simulation_results:
198
- break
199
- for task in asyncio.as_completed(tasks):
200
- completed_task = await task # type: ignore
201
- template_parameters = completed_task.get("template_parameters", {}) # type: ignore
202
- xpia_attack_type = template_parameters.get("xpia_attack_type", "") # type: ignore
203
- action = template_parameters.get("action", "") # type: ignore
204
- document_type = template_parameters.get("document_type", "") # type: ignore
205
- sim_results.append(
206
- {
207
- "messages": completed_task["messages"], # type: ignore
208
- "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
209
- "template_parameters": {
210
- "metadata": {
211
- "xpia_attack_type": xpia_attack_type,
212
- "action": action,
213
- "document_type": document_type,
214
- },
215
- },
216
- }
217
- )
218
- progress_bar.update(1)
219
- progress_bar.close()
220
- return JsonLineList(sim_results)
207
+ return jb_sim_results
@@ -3,20 +3,16 @@
3
3
  # ---------------------------------------------------------
4
4
 
5
5
  import asyncio
6
- import inspect
7
6
  import logging
8
7
  import os
9
8
  import time
10
9
  from abc import ABC, abstractmethod
11
10
  from enum import Enum
12
- from typing import Optional, Union
11
+ from typing import Dict, Optional, Union
13
12
 
14
- from azure.core.credentials import AccessToken, TokenCredential
15
13
  from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
16
14
 
17
- AZURE_TOKEN_REFRESH_INTERVAL = int(
18
- os.getenv("AZURE_TOKEN_REFRESH_INTERVAL", "600")
19
- ) # token refresh interval in seconds
15
+ AZURE_TOKEN_REFRESH_INTERVAL = 600 # seconds
20
16
 
21
17
 
22
18
  class TokenScope(Enum):
@@ -33,24 +29,24 @@ class APITokenManager(ABC):
33
29
  :param auth_header: Authorization header prefix. Defaults to "Bearer"
34
30
  :type auth_header: str
35
31
  :param credential: Azure credential object
36
- :type credential: Optional[TokenCredential]
32
+ :type credential: Optional[Union[azure.identity.DefaultAzureCredential, azure.identity.ManagedIdentityCredential]
37
33
  """
38
34
 
39
35
  def __init__(
40
36
  self,
41
37
  logger: logging.Logger,
42
38
  auth_header: str = "Bearer",
43
- credential: Optional[TokenCredential] = None,
39
+ credential: Optional[Union[DefaultAzureCredential, ManagedIdentityCredential]] = None,
44
40
  ) -> None:
45
41
  self.logger = logger
46
42
  self.auth_header = auth_header
47
- self._lock: Optional[asyncio.Lock] = None
43
+ self._lock = None
48
44
  if credential is not None:
49
45
  self.credential = credential
50
46
  else:
51
47
  self.credential = self.get_aad_credential()
52
- self.token: Optional[str] = None
53
- self.last_refresh_time: Optional[float] = None
48
+ self.token = None
49
+ self.last_refresh_time = None
54
50
 
55
51
  @property
56
52
  def lock(self) -> asyncio.Lock:
@@ -77,26 +73,20 @@ class APITokenManager(ABC):
77
73
  identity_client_id = os.environ.get("DEFAULT_IDENTITY_CLIENT_ID", None)
78
74
  if identity_client_id is not None:
79
75
  self.logger.info(f"Using DEFAULT_IDENTITY_CLIENT_ID: {identity_client_id}")
80
- return ManagedIdentityCredential(client_id=identity_client_id)
81
-
82
- self.logger.info("Environment variable DEFAULT_IDENTITY_CLIENT_ID is not set, using DefaultAzureCredential")
83
- return DefaultAzureCredential()
84
-
85
- @abstractmethod
86
- def get_token(self) -> str:
87
- """Async method to get the API token. Subclasses should implement this method.
88
-
89
- :return: API token
90
- :rtype: str
91
- """
76
+ credential = ManagedIdentityCredential(client_id=identity_client_id)
77
+ else:
78
+ self.logger.info("Environment variable DEFAULT_IDENTITY_CLIENT_ID is not set, using DefaultAzureCredential")
79
+ credential = DefaultAzureCredential()
80
+ return credential
92
81
 
93
82
  @abstractmethod
94
- async def get_token_async(self) -> str:
83
+ async def get_token(self) -> str:
95
84
  """Async method to get the API token. Subclasses should implement this method.
96
85
 
97
86
  :return: API token
98
87
  :rtype: str
99
88
  """
89
+ pass # pylint: disable=unnecessary-pass
100
90
 
101
91
 
102
92
  class ManagedIdentityAPITokenManager(APITokenManager):
@@ -110,18 +100,12 @@ class ManagedIdentityAPITokenManager(APITokenManager):
110
100
  :paramtype kwargs: Dict
111
101
  """
112
102
 
113
- def __init__(
114
- self,
115
- token_scope: TokenScope,
116
- logger: logging.Logger,
117
- *,
118
- auth_header: str = "Bearer",
119
- credential: Optional[TokenCredential] = None,
120
- ):
121
- super().__init__(logger, auth_header=auth_header, credential=credential)
103
+ def __init__(self, token_scope: TokenScope, logger: logging.Logger, **kwargs: Dict):
104
+ super().__init__(logger, **kwargs)
122
105
  self.token_scope = token_scope
123
106
 
124
- def get_token(self) -> str:
107
+ # Bug 3353724: This get_token is sync method, but it is defined as async method in the base class
108
+ def get_token(self) -> str: # pylint: disable=invalid-overridden-method
125
109
  """Get the API token. If the token is not available or has expired, refresh the token.
126
110
 
127
111
  :return: API token
@@ -138,31 +122,6 @@ class ManagedIdentityAPITokenManager(APITokenManager):
138
122
 
139
123
  return self.token
140
124
 
141
- async def get_token_async(self) -> str:
142
- """Get the API token synchronously. If the token is not available or has expired, refresh it.
143
-
144
- :return: API token
145
- :rtype: str
146
- """
147
- if (
148
- self.token is None
149
- or self.last_refresh_time is None
150
- or time.time() - self.last_refresh_time > AZURE_TOKEN_REFRESH_INTERVAL
151
- ):
152
- self.last_refresh_time = time.time()
153
- get_token_method = self.credential.get_token(self.token_scope.value)
154
- if inspect.isawaitable(get_token_method):
155
- # If it's awaitable, await it
156
- token_response: AccessToken = await get_token_method
157
- else:
158
- # Otherwise, call it synchronously
159
- token_response = get_token_method
160
-
161
- self.token = token_response.token
162
- self.logger.info("Refreshed Azure endpoint token.")
163
-
164
- return self.token
165
-
166
125
 
167
126
  class PlainTokenManager(APITokenManager):
168
127
  """Plain API Token Manager
@@ -175,18 +134,11 @@ class PlainTokenManager(APITokenManager):
175
134
  :paramtype kwargs: Dict
176
135
  """
177
136
 
178
- def __init__(
179
- self,
180
- openapi_key: str,
181
- logger: logging.Logger,
182
- *,
183
- auth_header: str = "Bearer",
184
- credential: Optional[TokenCredential] = None,
185
- ) -> None:
186
- super().__init__(logger, auth_header=auth_header, credential=credential)
187
- self.token: str = openapi_key
137
+ def __init__(self, openapi_key: str, logger: logging.Logger, **kwargs: Dict):
138
+ super().__init__(logger, **kwargs)
139
+ self.token = openapi_key
188
140
 
189
- def get_token(self) -> str:
141
+ async def get_token(self) -> str:
190
142
  """Get the API token
191
143
 
192
144
  :return: API token