azure-ai-evaluation 1.0.0b2__py3-none-any.whl → 1.0.0b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of azure-ai-evaluation might be problematic. Click here for more details.

Files changed (78) hide show
  1. azure/ai/evaluation/__init__.py +9 -5
  2. azure/ai/evaluation/_common/constants.py +4 -2
  3. azure/ai/evaluation/_common/math.py +18 -0
  4. azure/ai/evaluation/_common/rai_service.py +54 -62
  5. azure/ai/evaluation/_common/utils.py +201 -16
  6. azure/ai/evaluation/_constants.py +12 -0
  7. azure/ai/evaluation/_evaluate/_batch_run_client/batch_run_context.py +10 -3
  8. azure/ai/evaluation/_evaluate/_batch_run_client/code_client.py +33 -17
  9. azure/ai/evaluation/_evaluate/_batch_run_client/proxy_client.py +17 -2
  10. azure/ai/evaluation/_evaluate/_eval_run.py +26 -10
  11. azure/ai/evaluation/_evaluate/_evaluate.py +161 -89
  12. azure/ai/evaluation/_evaluate/_telemetry/__init__.py +16 -17
  13. azure/ai/evaluation/_evaluate/_utils.py +44 -25
  14. azure/ai/evaluation/_evaluators/_coherence/_coherence.py +33 -79
  15. azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +0 -5
  16. azure/ai/evaluation/_evaluators/_common/__init__.py +13 -0
  17. azure/ai/evaluation/_evaluators/_common/_base_eval.py +331 -0
  18. azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +76 -0
  19. azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +97 -0
  20. azure/ai/evaluation/_evaluators/_content_safety/__init__.py +0 -4
  21. azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +15 -20
  22. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +63 -42
  23. azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +18 -41
  24. azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +18 -39
  25. azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +18 -39
  26. azure/ai/evaluation/_evaluators/_content_safety/_violence.py +18 -39
  27. azure/ai/evaluation/_evaluators/_eci/_eci.py +18 -55
  28. azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +14 -6
  29. azure/ai/evaluation/_evaluators/_fluency/_fluency.py +30 -74
  30. azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +0 -5
  31. azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +34 -80
  32. azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +0 -5
  33. azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +18 -65
  34. azure/ai/evaluation/_evaluators/_qa/_qa.py +4 -3
  35. azure/ai/evaluation/_evaluators/_relevance/_relevance.py +35 -83
  36. azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +0 -5
  37. azure/ai/evaluation/_evaluators/{_chat → _retrieval}/__init__.py +2 -2
  38. azure/ai/evaluation/_evaluators/{_chat/retrieval → _retrieval}/_retrieval.py +25 -28
  39. azure/ai/evaluation/_evaluators/{_chat/retrieval → _retrieval}/retrieval.prompty +0 -5
  40. azure/ai/evaluation/_evaluators/_rouge/_rouge.py +1 -1
  41. azure/ai/evaluation/_evaluators/_similarity/_similarity.py +23 -17
  42. azure/ai/evaluation/_evaluators/_similarity/similarity.prompty +0 -5
  43. azure/ai/evaluation/_evaluators/_xpia/xpia.py +15 -90
  44. azure/ai/evaluation/_exceptions.py +9 -7
  45. azure/ai/evaluation/_http_utils.py +203 -132
  46. azure/ai/evaluation/_model_configurations.py +37 -9
  47. azure/ai/evaluation/{_evaluators/_chat/retrieval → _vendor}/__init__.py +0 -6
  48. azure/ai/evaluation/_vendor/rouge_score/__init__.py +14 -0
  49. azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +328 -0
  50. azure/ai/evaluation/_vendor/rouge_score/scoring.py +63 -0
  51. azure/ai/evaluation/_vendor/rouge_score/tokenize.py +63 -0
  52. azure/ai/evaluation/_vendor/rouge_score/tokenizers.py +53 -0
  53. azure/ai/evaluation/_version.py +1 -1
  54. azure/ai/evaluation/simulator/_adversarial_simulator.py +85 -60
  55. azure/ai/evaluation/simulator/_conversation/__init__.py +13 -12
  56. azure/ai/evaluation/simulator/_conversation/_conversation.py +4 -4
  57. azure/ai/evaluation/simulator/_direct_attack_simulator.py +24 -66
  58. azure/ai/evaluation/simulator/_helpers/_experimental.py +20 -9
  59. azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +4 -4
  60. azure/ai/evaluation/simulator/_indirect_attack_simulator.py +22 -64
  61. azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +67 -21
  62. azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +28 -11
  63. azure/ai/evaluation/simulator/_model_tools/_template_handler.py +68 -24
  64. azure/ai/evaluation/simulator/_model_tools/models.py +10 -10
  65. azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +2 -6
  66. azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +0 -4
  67. azure/ai/evaluation/simulator/_simulator.py +127 -117
  68. azure/ai/evaluation/simulator/_tracing.py +4 -4
  69. {azure_ai_evaluation-1.0.0b2.dist-info → azure_ai_evaluation-1.0.0b4.dist-info}/METADATA +129 -43
  70. azure_ai_evaluation-1.0.0b4.dist-info/NOTICE.txt +50 -0
  71. azure_ai_evaluation-1.0.0b4.dist-info/RECORD +106 -0
  72. azure/ai/evaluation/_evaluators/_chat/_chat.py +0 -357
  73. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_base.py +0 -65
  74. azure/ai/evaluation/_evaluators/_protected_materials/__init__.py +0 -5
  75. azure/ai/evaluation/_evaluators/_protected_materials/_protected_materials.py +0 -104
  76. azure_ai_evaluation-1.0.0b2.dist-info/RECORD +0 -99
  77. {azure_ai_evaluation-1.0.0b2.dist-info → azure_ai_evaluation-1.0.0b4.dist-info}/WHEEL +0 -0
  78. {azure_ai_evaluation-1.0.0b2.dist-info → azure_ai_evaluation-1.0.0b4.dist-info}/top_level.txt +0 -0
@@ -1,53 +1,24 @@
1
1
  # ---------------------------------------------------------
2
2
  # Copyright (c) Microsoft Corporation. All rights reserved.
3
3
  # ---------------------------------------------------------
4
+ # pylint: disable=C0301,C0114,R0913,R0903
4
5
  # noqa: E501
5
- import functools
6
6
  import logging
7
- from typing import Callable
8
-
9
- from promptflow._sdk._telemetry import ActivityType, monitor_operation
7
+ from typing import Callable, cast
10
8
 
9
+ from azure.ai.evaluation._common.utils import validate_azure_ai_project
11
10
  from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
12
- from azure.ai.evaluation._model_configurations import AzureAIProject
13
11
  from azure.ai.evaluation.simulator import AdversarialScenario
14
- from azure.identity import DefaultAzureCredential
12
+ from azure.core.credentials import TokenCredential
15
13
 
16
14
  from ._adversarial_simulator import AdversarialSimulator
15
+ from ._helpers import experimental
17
16
  from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope
18
17
 
19
18
  logger = logging.getLogger(__name__)
20
19
 
21
20
 
22
- def monitor_adversarial_scenario(func) -> Callable:
23
- """Decorator to monitor adversarial scenario.
24
-
25
- :param func: The function to be decorated.
26
- :type func: Callable
27
- :return: The decorated function.
28
- :rtype: Callable
29
- """
30
-
31
- @functools.wraps(func)
32
- def wrapper(*args, **kwargs):
33
- scenario = str(kwargs.get("scenario", None))
34
- max_conversation_turns = kwargs.get("max_conversation_turns", None)
35
- max_simulation_results = kwargs.get("max_simulation_results", None)
36
- decorated_func = monitor_operation(
37
- activity_name="xpia.adversarial.simulator.call",
38
- activity_type=ActivityType.PUBLICAPI,
39
- custom_dimensions={
40
- "scenario": scenario,
41
- "max_conversation_turns": max_conversation_turns,
42
- "max_simulation_results": max_simulation_results,
43
- },
44
- )(func)
45
-
46
- return decorated_func(*args, **kwargs)
47
-
48
- return wrapper
49
-
50
-
21
+ @experimental
51
22
  class IndirectAttackSimulator:
52
23
  """
53
24
  Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
@@ -59,41 +30,29 @@ class IndirectAttackSimulator:
59
30
  :type credential: ~azure.core.credentials.TokenCredential
60
31
  """
61
32
 
62
- def __init__(self, *, azure_ai_project: AzureAIProject, credential=None):
33
+ def __init__(self, *, azure_ai_project: dict, credential):
63
34
  """Constructor."""
64
- # check if azure_ai_project has the keys: subscription_id, resource_group_name, project_name, credential
65
- if not all(key in azure_ai_project for key in ["subscription_id", "resource_group_name", "project_name"]):
66
- msg = "azure_ai_project must contain keys: subscription_id, resource_group_name and project_name"
67
- raise EvaluationException(
68
- message=msg,
69
- internal_message=msg,
70
- target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
71
- category=ErrorCategory.MISSING_FIELD,
72
- blame=ErrorBlame.USER_ERROR,
73
- )
74
- if not all(azure_ai_project[key] for key in ["subscription_id", "resource_group_name", "project_name"]):
75
- msg = "subscription_id, resource_group_name and project_name keys cannot be None"
35
+
36
+ try:
37
+ self.azure_ai_project = validate_azure_ai_project(azure_ai_project)
38
+ except EvaluationException as e:
76
39
  raise EvaluationException(
77
- message=msg,
78
- internal_message=msg,
40
+ message=e.message,
41
+ internal_message=e.internal_message,
79
42
  target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
80
- category=ErrorCategory.MISSING_FIELD,
81
- blame=ErrorBlame.USER_ERROR,
82
- )
83
- if "credential" not in azure_ai_project and not credential:
84
- credential = DefaultAzureCredential()
85
- elif "credential" in azure_ai_project:
86
- credential = azure_ai_project["credential"]
87
- self.credential = credential
88
- self.azure_ai_project = azure_ai_project
43
+ category=e.category,
44
+ blame=e.blame,
45
+ ) from e
46
+
47
+ self.credential = cast(TokenCredential, credential)
89
48
  self.token_manager = ManagedIdentityAPITokenManager(
90
49
  token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
91
50
  logger=logging.getLogger("AdversarialSimulator"),
92
- credential=credential,
51
+ credential=self.credential,
93
52
  )
94
- self.rai_client = RAIClient(azure_ai_project=azure_ai_project, token_manager=self.token_manager)
53
+ self.rai_client = RAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager)
95
54
  self.adversarial_template_handler = AdversarialTemplateHandler(
96
- azure_ai_project=azure_ai_project, rai_client=self.rai_client
55
+ azure_ai_project=self.azure_ai_project, rai_client=self.rai_client
97
56
  )
98
57
 
99
58
  def _ensure_service_dependencies(self):
@@ -107,7 +66,6 @@ class IndirectAttackSimulator:
107
66
  blame=ErrorBlame.USER_ERROR,
108
67
  )
109
68
 
110
- # @monitor_adversarial_scenario
111
69
  async def __call__(
112
70
  self,
113
71
  *,
@@ -192,7 +150,7 @@ class IndirectAttackSimulator:
192
150
  category=ErrorCategory.INVALID_VALUE,
193
151
  blame=ErrorBlame.USER_ERROR,
194
152
  )
195
- jb_sim = AdversarialSimulator(azure_ai_project=self.azure_ai_project, credential=self.credential)
153
+ jb_sim = AdversarialSimulator(azure_ai_project=cast(dict, self.azure_ai_project), credential=self.credential)
196
154
  jb_sim_results = await jb_sim(
197
155
  scenario=scenario,
198
156
  target=target,
@@ -3,13 +3,15 @@
3
3
  # ---------------------------------------------------------
4
4
 
5
5
  import asyncio
6
+ import inspect
6
7
  import logging
7
8
  import os
8
9
  import time
9
10
  from abc import ABC, abstractmethod
10
11
  from enum import Enum
11
- from typing import Dict, Optional, Union
12
+ from typing import Optional, Union
12
13
 
14
+ from azure.core.credentials import TokenCredential, AccessToken
13
15
  from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
14
16
 
15
17
  AZURE_TOKEN_REFRESH_INTERVAL = 600 # seconds
@@ -29,24 +31,24 @@ class APITokenManager(ABC):
29
31
  :param auth_header: Authorization header prefix. Defaults to "Bearer"
30
32
  :type auth_header: str
31
33
  :param credential: Azure credential object
32
- :type credential: Optional[Union[azure.identity.DefaultAzureCredential, azure.identity.ManagedIdentityCredential]
34
+ :type credential: Optional[TokenCredential]
33
35
  """
34
36
 
35
37
  def __init__(
36
38
  self,
37
39
  logger: logging.Logger,
38
40
  auth_header: str = "Bearer",
39
- credential: Optional[Union[DefaultAzureCredential, ManagedIdentityCredential]] = None,
41
+ credential: Optional[TokenCredential] = None,
40
42
  ) -> None:
41
43
  self.logger = logger
42
44
  self.auth_header = auth_header
43
- self._lock = None
45
+ self._lock: Optional[asyncio.Lock] = None
44
46
  if credential is not None:
45
47
  self.credential = credential
46
48
  else:
47
49
  self.credential = self.get_aad_credential()
48
- self.token = None
49
- self.last_refresh_time = None
50
+ self.token: Optional[str] = None
51
+ self.last_refresh_time: Optional[float] = None
50
52
 
51
53
  @property
52
54
  def lock(self) -> asyncio.Lock:
@@ -73,20 +75,26 @@ class APITokenManager(ABC):
73
75
  identity_client_id = os.environ.get("DEFAULT_IDENTITY_CLIENT_ID", None)
74
76
  if identity_client_id is not None:
75
77
  self.logger.info(f"Using DEFAULT_IDENTITY_CLIENT_ID: {identity_client_id}")
76
- credential = ManagedIdentityCredential(client_id=identity_client_id)
77
- else:
78
- self.logger.info("Environment variable DEFAULT_IDENTITY_CLIENT_ID is not set, using DefaultAzureCredential")
79
- credential = DefaultAzureCredential()
80
- return credential
78
+ return ManagedIdentityCredential(client_id=identity_client_id)
79
+
80
+ self.logger.info("Environment variable DEFAULT_IDENTITY_CLIENT_ID is not set, using DefaultAzureCredential")
81
+ return DefaultAzureCredential()
82
+
83
+ @abstractmethod
84
+ def get_token(self) -> str:
85
+ """Method to get the API token synchronously. Subclasses should implement this method.
86
+
87
+ :return: API token
88
+ :rtype: str
89
+ """
81
90
 
82
91
  @abstractmethod
83
- async def get_token(self) -> str:
92
+ async def get_token_async(self) -> str:
84
93
  """Async method to get the API token. Subclasses should implement this method.
85
94
 
86
95
  :return: API token
87
96
  :rtype: str
88
97
  """
89
- pass # pylint: disable=unnecessary-pass
90
98
 
91
99
 
92
100
  class ManagedIdentityAPITokenManager(APITokenManager):
@@ -100,12 +108,18 @@ class ManagedIdentityAPITokenManager(APITokenManager):
100
108
  :paramtype kwargs: Dict
101
109
  """
102
110
 
103
- def __init__(self, token_scope: TokenScope, logger: logging.Logger, **kwargs: Dict):
104
- super().__init__(logger, **kwargs)
111
+ def __init__(
112
+ self,
113
+ token_scope: TokenScope,
114
+ logger: logging.Logger,
115
+ *,
116
+ auth_header: str = "Bearer",
117
+ credential: Optional[TokenCredential] = None,
118
+ ):
119
+ super().__init__(logger, auth_header=auth_header, credential=credential)
105
120
  self.token_scope = token_scope
106
121
 
107
- # Bug 3353724: This get_token is sync method, but it is defined as async method in the base class
108
- def get_token(self) -> str: # pylint: disable=invalid-overridden-method
122
+ def get_token(self) -> str:
109
123
  """Get the API token. If the token is not available or has expired, refresh the token.
110
124
 
111
125
  :return: API token
@@ -122,6 +136,31 @@ class ManagedIdentityAPITokenManager(APITokenManager):
122
136
 
123
137
  return self.token
124
138
 
139
+ async def get_token_async(self) -> str:
140
+ """Get the API token asynchronously. If the token is not available or has expired, refresh it.
141
+
142
+ :return: API token
143
+ :rtype: str
144
+ """
145
+ if (
146
+ self.token is None
147
+ or self.last_refresh_time is None
148
+ or time.time() - self.last_refresh_time > AZURE_TOKEN_REFRESH_INTERVAL
149
+ ):
150
+ self.last_refresh_time = time.time()
151
+ get_token_method = self.credential.get_token(self.token_scope.value)
152
+ if inspect.isawaitable(get_token_method):
153
+ # If it's awaitable, await it
154
+ token_response: AccessToken = await get_token_method
155
+ else:
156
+ # Otherwise, call it synchronously
157
+ token_response = get_token_method
158
+
159
+ self.token = token_response.token
160
+ self.logger.info("Refreshed Azure endpoint token.")
161
+
162
+ return self.token
163
+
125
164
 
126
165
  class PlainTokenManager(APITokenManager):
127
166
  """Plain API Token Manager
@@ -134,11 +173,18 @@ class PlainTokenManager(APITokenManager):
134
173
  :paramtype kwargs: Dict
135
174
  """
136
175
 
137
- def __init__(self, openapi_key: str, logger: logging.Logger, **kwargs: Dict):
138
- super().__init__(logger, **kwargs)
139
- self.token = openapi_key
176
+ def __init__(
177
+ self,
178
+ openapi_key: str,
179
+ logger: logging.Logger,
180
+ *,
181
+ auth_header: str = "Bearer",
182
+ credential: Optional[TokenCredential] = None,
183
+ ) -> None:
184
+ super().__init__(logger, auth_header=auth_header, credential=credential)
185
+ self.token: str = openapi_key
140
186
 
141
- async def get_token(self) -> str:
187
+ def get_token(self) -> str:
142
188
  """Get the API token
143
189
 
144
190
  :return: API token
@@ -6,13 +6,14 @@ import copy
6
6
  import json
7
7
  import time
8
8
  import uuid
9
- from typing import Dict, List
9
+ from typing import Any, Dict, List, Optional, cast
10
10
 
11
11
  from azure.ai.evaluation._http_utils import AsyncHttpPipeline, get_async_http_client
12
12
  from azure.ai.evaluation._user_agent import USER_AGENT
13
13
  from azure.core.exceptions import HttpResponseError
14
14
  from azure.core.pipeline.policies import AsyncRetryPolicy, RetryMode
15
15
 
16
+ from .._model_tools._template_handler import TemplateParameters
16
17
  from .models import OpenAIChatCompletionsModel
17
18
 
18
19
 
@@ -33,7 +34,15 @@ class SimulationRequestDTO:
33
34
  :type template_parameters: Dict
34
35
  """
35
36
 
36
- def __init__(self, url, headers, payload, params, templatekey, template_parameters):
37
+ def __init__(
38
+ self,
39
+ url: str,
40
+ headers: Dict[str, str],
41
+ payload: Dict[str, Any],
42
+ params: Dict[str, str],
43
+ templatekey: str,
44
+ template_parameters: Optional[TemplateParameters],
45
+ ):
37
46
  self.url = url
38
47
  self.headers = headers
39
48
  self.json = json.dumps(payload)
@@ -47,9 +56,12 @@ class SimulationRequestDTO:
47
56
  :return: The DTO as a dictionary.
48
57
  :rtype: Dict
49
58
  """
50
- if self.templateParameters is not None:
51
- self.templateParameters = {str(k): str(v) for k, v in self.templateParameters.items()}
52
- return self.__dict__
59
+ toReturn = self.__dict__.copy()
60
+
61
+ if toReturn["templateParameters"] is not None:
62
+ toReturn["templateParameters"] = {str(k): str(v) for k, v in toReturn["templateParameters"].items()}
63
+
64
+ return toReturn
53
65
 
54
66
  def to_json(self):
55
67
  """Convert the DTO to a JSON string.
@@ -73,12 +85,12 @@ class ProxyChatCompletionsModel(OpenAIChatCompletionsModel):
73
85
  :keyword kwargs: Additional keyword arguments to pass to the parent class.
74
86
  """
75
87
 
76
- def __init__(self, name: str, template_key: str, template_parameters, *args, **kwargs) -> None:
88
+ def __init__(self, name: str, template_key: str, template_parameters: TemplateParameters, **kwargs) -> None:
77
89
  self.tkey = template_key
78
90
  self.tparam = template_parameters
79
- self.result_url = None
91
+ self.result_url: Optional[str] = None
80
92
 
81
- super().__init__(name=name, *args, **kwargs)
93
+ super().__init__(name=name, **kwargs)
82
94
 
83
95
  def format_request_data(self, messages: List[Dict], **request_params) -> Dict: # type: ignore[override]
84
96
  """Format the request data to query the model with.
@@ -160,7 +172,6 @@ class ProxyChatCompletionsModel(OpenAIChatCompletionsModel):
160
172
  }
161
173
  # add all additional headers
162
174
  headers.update(self.additional_headers) # type: ignore[arg-type]
163
-
164
175
  params = {}
165
176
  if self.api_version:
166
177
  params["api-version"] = self.api_version
@@ -184,8 +195,8 @@ class ProxyChatCompletionsModel(OpenAIChatCompletionsModel):
184
195
  message=f"Received unexpected HTTP status: {response.status_code} {response.text()}", response=response
185
196
  )
186
197
 
187
- response = response.json()
188
- self.result_url = response["location"]
198
+ response_data = response.json()
199
+ self.result_url = cast(str, response_data["location"])
189
200
 
190
201
  retry_policy = AsyncRetryPolicy( # set up retry configuration
191
202
  retry_on_status_codes=[202], # on which statuses to retry
@@ -202,6 +213,12 @@ class ProxyChatCompletionsModel(OpenAIChatCompletionsModel):
202
213
  time.sleep(15)
203
214
 
204
215
  async with get_async_http_client().with_policies(retry_policy=retry_policy) as exp_retry_client:
216
+ token = await self.token_manager.get_token_async()
217
+ proxy_headers = {
218
+ "Authorization": f"Bearer {token}",
219
+ "Content-Type": "application/json",
220
+ "User-Agent": USER_AGENT,
221
+ }
205
222
  response = await exp_retry_client.get( # pylint: disable=too-many-function-args,unexpected-keyword-arg
206
223
  self.result_url, headers=proxy_headers
207
224
  )
@@ -2,25 +2,66 @@
2
2
  # Copyright (c) Microsoft Corporation. All rights reserved.
3
3
  # ---------------------------------------------------------
4
4
 
5
- from typing import Optional
5
+ from typing import Dict, List, Optional, TypedDict, cast
6
+
7
+ from typing_extensions import NotRequired
6
8
 
7
9
  from azure.ai.evaluation._model_configurations import AzureAIProject
8
10
 
9
11
  from ._rai_client import RAIClient
10
12
 
11
- CONTENT_HARM_TEMPLATES_COLLECTION_KEY = set(
12
- [
13
- "adv_qa",
14
- "adv_conversation",
15
- "adv_summarization",
16
- "adv_search",
17
- "adv_rewrite",
18
- "adv_content_gen_ungrounded",
19
- "adv_content_gen_grounded",
20
- "adv_content_protected_material",
21
- "adv_politics",
22
- ]
23
- )
13
+ CONTENT_HARM_TEMPLATES_COLLECTION_KEY = {
14
+ "adv_qa",
15
+ "adv_conversation",
16
+ "adv_summarization",
17
+ "adv_search",
18
+ "adv_rewrite",
19
+ "adv_content_gen_ungrounded",
20
+ "adv_content_gen_grounded",
21
+ "adv_content_protected_material",
22
+ "adv_politics",
23
+ }
24
+
25
+
26
+ class TemplateParameters(TypedDict):
27
+ """Parameters used in Templates
28
+
29
+ .. note::
30
+
31
+ This type is good enough to type check, but is incorrect. It's meant to represent a dictionary with a known
32
+ `metadata` key (Dict[str, str]), a known `ch_template_placeholder` key (str), and an unknown number of keys
33
+ that map to `str` values.
34
+
35
+ In typescript, this type would be spelled:
36
+
37
+ .. code-block:: typescript
38
+
39
+ type AdversarialTemplateParameters = {
40
+ [key: string]: string
41
+ ch_template_placeholder: string
42
+ metadata: {[index: string]: string} // Doesn't typecheck but gets the point across
43
+ }
44
+
45
+ At time of writing, this isn't possible to express with a TypedDict. TypedDicts must be "closed" in that
46
+ they fully specify all the keys they can contain.
47
+
48
+ `PEP 728 – TypedDict with Typed Extra Items <https://peps.python.org/pep-0728/>`_ is a proposal to support
49
+ this, but would only be available in Python 3.13 at the earliest.
50
+ """
51
+
52
+ metadata: Dict[str, str]
53
+ conversation_starter: str
54
+ ch_template_placeholder: str
55
+ group_of_people: NotRequired[str]
56
+ category: NotRequired[str]
57
+ target_population: NotRequired[str]
58
+ topic: NotRequired[str]
59
+
60
+
61
+ class _CategorizedParameter(TypedDict):
62
+ parameters: List[TemplateParameters]
63
+ category: str
64
+ parameters_key: str
24
65
 
25
66
 
26
67
  class ContentHarmTemplatesUtils:
@@ -85,13 +126,19 @@ class AdversarialTemplate:
85
126
  :param template_parameters: The template parameters.
86
127
  """
87
128
 
88
- def __init__(self, template_name, text, context_key, template_parameters=None) -> None:
129
+ def __init__(
130
+ self,
131
+ template_name: str,
132
+ text: Optional[str],
133
+ context_key: List,
134
+ template_parameters: Optional[List[TemplateParameters]] = None,
135
+ ) -> None:
89
136
  self.text = text
90
137
  self.context_key = context_key
91
138
  self.template_name = template_name
92
- self.template_parameters = template_parameters
139
+ self.template_parameters = template_parameters or []
93
140
 
94
- def __str__(self):
141
+ def __str__(self) -> str:
95
142
  return "{{ch_template_placeholder}}"
96
143
 
97
144
 
@@ -106,16 +153,13 @@ class AdversarialTemplateHandler:
106
153
  """
107
154
 
108
155
  def __init__(self, azure_ai_project: AzureAIProject, rai_client: RAIClient) -> None:
109
- self.cached_templates_source = {}
110
- # self.template_env = JinjaEnvironment(loader=JinjaFileSystemLoader(searchpath=template_dir))
111
156
  self.azure_ai_project = azure_ai_project
112
- self.categorized_ch_parameters = None
157
+ self.categorized_ch_parameters: Optional[Dict[str, _CategorizedParameter]] = None
113
158
  self.rai_client = rai_client
114
159
 
115
- async def _get_content_harm_template_collections(self, collection_key):
116
-
160
+ async def _get_content_harm_template_collections(self, collection_key: str) -> List[AdversarialTemplate]:
117
161
  if self.categorized_ch_parameters is None:
118
- categorized_parameters = {}
162
+ categorized_parameters: Dict[str, _CategorizedParameter] = {}
119
163
  util = ContentHarmTemplatesUtils
120
164
 
121
165
  parameters = await self.rai_client.get_contentharm_parameters()
@@ -123,7 +167,7 @@ class AdversarialTemplateHandler:
123
167
  for k in parameters.keys():
124
168
  template_key = util.get_template_key(k)
125
169
  categorized_parameters[template_key] = {
126
- "parameters": parameters[k],
170
+ "parameters": cast(List[TemplateParameters], parameters[k]),
127
171
  "category": util.get_template_category(k),
128
172
  "parameters_key": k,
129
173
  }
@@ -49,10 +49,10 @@ class LLMBase(ABC):
49
49
  Base class for all LLM models.
50
50
  """
51
51
 
52
- def __init__(self, endpoint_url: str, name: str = "unknown", additional_headers: Optional[dict] = {}):
52
+ def __init__(self, endpoint_url: str, name: str = "unknown", additional_headers: Optional[Dict[str, str]] = None):
53
53
  self.endpoint_url = endpoint_url
54
54
  self.name = name
55
- self.additional_headers = additional_headers
55
+ self.additional_headers = additional_headers or {}
56
56
  self.logger = logging.getLogger(repr(self))
57
57
 
58
58
  # Metric tracking
@@ -208,7 +208,7 @@ class OpenAICompletionsModel(LLMBase):
208
208
  *,
209
209
  endpoint_url: str,
210
210
  name: str = "OpenAICompletionsModel",
211
- additional_headers: Optional[dict] = {},
211
+ additional_headers: Optional[Dict[str, str]] = None,
212
212
  api_version: Optional[str] = "2023-03-15-preview",
213
213
  token_manager: APITokenManager,
214
214
  azureml_model_deployment: Optional[str] = None,
@@ -220,7 +220,7 @@ class OpenAICompletionsModel(LLMBase):
220
220
  frequency_penalty: Optional[float] = 0,
221
221
  presence_penalty: Optional[float] = 0,
222
222
  stop: Optional[Union[List[str], str]] = None,
223
- image_captions: Dict[str, str] = {},
223
+ image_captions: Optional[Dict[str, str]] = None,
224
224
  images_dir: Optional[str] = None, # Note: unused, kept for class compatibility
225
225
  ):
226
226
  super().__init__(endpoint_url=endpoint_url, name=name, additional_headers=additional_headers)
@@ -234,7 +234,7 @@ class OpenAICompletionsModel(LLMBase):
234
234
  self.n = n
235
235
  self.frequency_penalty = frequency_penalty
236
236
  self.presence_penalty = presence_penalty
237
- self.image_captions = image_captions
237
+ self.image_captions = image_captions or {}
238
238
 
239
239
  # Default stop to end token if not provided
240
240
  if not stop:
@@ -263,7 +263,7 @@ class OpenAICompletionsModel(LLMBase):
263
263
  def get_model_params(self):
264
264
  return {param: getattr(self, param) for param in self.model_param_names if getattr(self, param) is not None}
265
265
 
266
- def format_request_data(self, prompt: str, **request_params) -> Dict[str, str]:
266
+ def format_request_data(self, prompt: Dict[str, str], **request_params) -> Dict[str, str]: # type: ignore[override]
267
267
  """
268
268
  Format the request data for the OpenAI API.
269
269
  """
@@ -328,7 +328,7 @@ class OpenAICompletionsModel(LLMBase):
328
328
  # Format prompts and tag with index
329
329
  request_datas: List[Dict] = []
330
330
  for idx, prompt in enumerate(prompts):
331
- prompt: Dict[str, str] = self.format_request_data(prompt, **request_params)
331
+ prompt = self.format_request_data(prompt, **request_params)
332
332
  prompt[self.prompt_idx_key] = idx # type: ignore[assignment]
333
333
  request_datas.append(prompt)
334
334
 
@@ -447,7 +447,7 @@ class OpenAICompletionsModel(LLMBase):
447
447
 
448
448
  self._log_request(request_data)
449
449
 
450
- token = await self.token_manager.get_token()
450
+ token = self.token_manager.get_token()
451
451
 
452
452
  headers = {
453
453
  "Content-Type": "application/json",
@@ -522,8 +522,8 @@ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
522
522
  formats the prompt for chat completion.
523
523
  """
524
524
 
525
- def __init__(self, name="OpenAIChatCompletionsModel", *args, **kwargs):
526
- super().__init__(name=name, *args, **kwargs)
525
+ def __init__(self, name="OpenAIChatCompletionsModel", **kwargs):
526
+ super().__init__(name=name, **kwargs)
527
527
 
528
528
  def format_request_data(self, messages: List[dict], **request_params): # type: ignore[override]
529
529
  request_data = {"messages": messages, **self.get_model_params()}
@@ -3,11 +3,6 @@ name: TaskSimulatorQueryResponse
3
3
  description: Gets queries and responses from a blob of text
4
4
  model:
5
5
  api: chat
6
- configuration:
7
- type: azure_openai
8
- azure_deployment: ${env:AZURE_DEPLOYMENT}
9
- api_key: ${env:AZURE_OPENAI_API_KEY}
10
- azure_endpoint: ${env:AZURE_OPENAI_ENDPOINT}
11
6
  parameters:
12
7
  temperature: 0.0
13
8
  top_p: 1.0
@@ -33,7 +28,8 @@ Answer must not be more than 5 words
33
28
  Answer must be picked from Text as is
34
29
  Question should be as descriptive as possible and must include as much context as possible from Text
35
30
  Output must always have the provided number of QnAs
36
- Output must be in JSON format
31
+ Output must be in JSON format.
32
+ Output must have {{num_queries}} objects in the format specified below. Any other count is unacceptable.
37
33
  Text:
38
34
  <|text_start|>
39
35
  On January 24, 1984, former Apple CEO Steve Jobs introduced the first Macintosh. In late 2003, Apple had 2.06 percent of the desktop share in the United States.
@@ -3,10 +3,6 @@ name: TaskSimulatorWithPersona
3
3
  description: Simulates a user to complete a conversation
4
4
  model:
5
5
  api: chat
6
- configuration:
7
- type: azure_openai
8
- azure_deployment: ${env:AZURE_DEPLOYMENT}
9
- azure_endpoint: ${env:AZURE_OPENAI_ENDPOINT}
10
6
  parameters:
11
7
  temperature: 0.0
12
8
  top_p: 1.0