azure-ai-evaluation 0.0.0b0__py3-none-any.whl → 1.0.0b1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

Potentially problematic release: this version of azure-ai-evaluation might be problematic.

Files changed (100)
  1. azure/ai/evaluation/__init__.py +60 -0
  2. azure/ai/evaluation/_common/__init__.py +16 -0
  3. azure/ai/evaluation/_common/constants.py +65 -0
  4. azure/ai/evaluation/_common/rai_service.py +452 -0
  5. azure/ai/evaluation/_common/utils.py +87 -0
  6. azure/ai/evaluation/_constants.py +50 -0
  7. azure/ai/evaluation/_evaluate/__init__.py +3 -0
  8. azure/ai/evaluation/_evaluate/_batch_run_client/__init__.py +8 -0
  9. azure/ai/evaluation/_evaluate/_batch_run_client/batch_run_context.py +72 -0
  10. azure/ai/evaluation/_evaluate/_batch_run_client/code_client.py +150 -0
  11. azure/ai/evaluation/_evaluate/_batch_run_client/proxy_client.py +61 -0
  12. azure/ai/evaluation/_evaluate/_eval_run.py +494 -0
  13. azure/ai/evaluation/_evaluate/_evaluate.py +689 -0
  14. azure/ai/evaluation/_evaluate/_telemetry/__init__.py +174 -0
  15. azure/ai/evaluation/_evaluate/_utils.py +237 -0
  16. azure/ai/evaluation/_evaluators/__init__.py +3 -0
  17. azure/ai/evaluation/_evaluators/_bleu/__init__.py +9 -0
  18. azure/ai/evaluation/_evaluators/_bleu/_bleu.py +73 -0
  19. azure/ai/evaluation/_evaluators/_chat/__init__.py +9 -0
  20. azure/ai/evaluation/_evaluators/_chat/_chat.py +350 -0
  21. azure/ai/evaluation/_evaluators/_chat/retrieval/__init__.py +9 -0
  22. azure/ai/evaluation/_evaluators/_chat/retrieval/_retrieval.py +163 -0
  23. azure/ai/evaluation/_evaluators/_chat/retrieval/retrieval.prompty +48 -0
  24. azure/ai/evaluation/_evaluators/_coherence/__init__.py +7 -0
  25. azure/ai/evaluation/_evaluators/_coherence/_coherence.py +122 -0
  26. azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +62 -0
  27. azure/ai/evaluation/_evaluators/_content_safety/__init__.py +21 -0
  28. azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +108 -0
  29. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_base.py +66 -0
  30. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +296 -0
  31. azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +78 -0
  32. azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +76 -0
  33. azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +76 -0
  34. azure/ai/evaluation/_evaluators/_content_safety/_violence.py +76 -0
  35. azure/ai/evaluation/_evaluators/_eci/__init__.py +0 -0
  36. azure/ai/evaluation/_evaluators/_eci/_eci.py +99 -0
  37. azure/ai/evaluation/_evaluators/_f1_score/__init__.py +9 -0
  38. azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +141 -0
  39. azure/ai/evaluation/_evaluators/_fluency/__init__.py +9 -0
  40. azure/ai/evaluation/_evaluators/_fluency/_fluency.py +122 -0
  41. azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +61 -0
  42. azure/ai/evaluation/_evaluators/_gleu/__init__.py +9 -0
  43. azure/ai/evaluation/_evaluators/_gleu/_gleu.py +71 -0
  44. azure/ai/evaluation/_evaluators/_groundedness/__init__.py +9 -0
  45. azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +123 -0
  46. azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +54 -0
  47. azure/ai/evaluation/_evaluators/_meteor/__init__.py +9 -0
  48. azure/ai/evaluation/_evaluators/_meteor/_meteor.py +96 -0
  49. azure/ai/evaluation/_evaluators/_protected_material/__init__.py +5 -0
  50. azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +104 -0
  51. azure/ai/evaluation/_evaluators/_protected_materials/__init__.py +5 -0
  52. azure/ai/evaluation/_evaluators/_protected_materials/_protected_materials.py +104 -0
  53. azure/ai/evaluation/_evaluators/_qa/__init__.py +9 -0
  54. azure/ai/evaluation/_evaluators/_qa/_qa.py +111 -0
  55. azure/ai/evaluation/_evaluators/_relevance/__init__.py +9 -0
  56. azure/ai/evaluation/_evaluators/_relevance/_relevance.py +131 -0
  57. azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +69 -0
  58. azure/ai/evaluation/_evaluators/_rouge/__init__.py +10 -0
  59. azure/ai/evaluation/_evaluators/_rouge/_rouge.py +98 -0
  60. azure/ai/evaluation/_evaluators/_similarity/__init__.py +9 -0
  61. azure/ai/evaluation/_evaluators/_similarity/_similarity.py +130 -0
  62. azure/ai/evaluation/_evaluators/_similarity/similarity.prompty +71 -0
  63. azure/ai/evaluation/_evaluators/_xpia/__init__.py +5 -0
  64. azure/ai/evaluation/_evaluators/_xpia/xpia.py +140 -0
  65. azure/ai/evaluation/_exceptions.py +107 -0
  66. azure/ai/evaluation/_http_utils.py +395 -0
  67. azure/ai/evaluation/_model_configurations.py +27 -0
  68. azure/ai/evaluation/_user_agent.py +6 -0
  69. azure/ai/evaluation/_version.py +5 -0
  70. azure/ai/evaluation/py.typed +0 -0
  71. azure/ai/evaluation/simulator/__init__.py +15 -0
  72. azure/ai/evaluation/simulator/_adversarial_scenario.py +27 -0
  73. azure/ai/evaluation/simulator/_adversarial_simulator.py +450 -0
  74. azure/ai/evaluation/simulator/_constants.py +17 -0
  75. azure/ai/evaluation/simulator/_conversation/__init__.py +315 -0
  76. azure/ai/evaluation/simulator/_conversation/_conversation.py +178 -0
  77. azure/ai/evaluation/simulator/_conversation/constants.py +30 -0
  78. azure/ai/evaluation/simulator/_direct_attack_simulator.py +252 -0
  79. azure/ai/evaluation/simulator/_helpers/__init__.py +4 -0
  80. azure/ai/evaluation/simulator/_helpers/_language_suffix_mapping.py +17 -0
  81. azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +93 -0
  82. azure/ai/evaluation/simulator/_indirect_attack_simulator.py +207 -0
  83. azure/ai/evaluation/simulator/_model_tools/__init__.py +23 -0
  84. azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +147 -0
  85. azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +228 -0
  86. azure/ai/evaluation/simulator/_model_tools/_rai_client.py +157 -0
  87. azure/ai/evaluation/simulator/_model_tools/_template_handler.py +157 -0
  88. azure/ai/evaluation/simulator/_model_tools/models.py +616 -0
  89. azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +69 -0
  90. azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +36 -0
  91. azure/ai/evaluation/simulator/_tracing.py +92 -0
  92. azure/ai/evaluation/simulator/_utils.py +111 -0
  93. azure/ai/evaluation/simulator/simulator.py +579 -0
  94. azure_ai_evaluation-1.0.0b1.dist-info/METADATA +377 -0
  95. azure_ai_evaluation-1.0.0b1.dist-info/RECORD +97 -0
  96. {azure_ai_evaluation-0.0.0b0.dist-info → azure_ai_evaluation-1.0.0b1.dist-info}/WHEEL +1 -1
  97. azure_ai_evaluation-1.0.0b1.dist-info/top_level.txt +1 -0
  98. azure_ai_evaluation-0.0.0b0.dist-info/METADATA +0 -7
  99. azure_ai_evaluation-0.0.0b0.dist-info/RECORD +0 -4
  100. azure_ai_evaluation-0.0.0b0.dist-info/top_level.txt +0 -1
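
The listing above sketches the shape of the new package: a top-level __init__.py exporting the public surface, individual evaluators under azure/ai/evaluation/_evaluators/, a batch _evaluate entry point, and an adversarial/task simulator under azure/ai/evaluation/simulator/. As a rough orientation only, a minimal usage sketch follows; the class name F1ScoreEvaluator and its keyword arguments are assumptions inferred from the module layout, not confirmed by this diff.

    # Minimal sketch, assuming the package re-exports an F1-score evaluator
    # from azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py. The class
    # name and keyword arguments are assumptions, not confirmed by this diff.
    from azure.ai.evaluation import F1ScoreEvaluator  # assumed public name

    f1_evaluator = F1ScoreEvaluator()
    result = f1_evaluator(
        response="Tokyo is the capital of Japan.",      # model output (kwarg name assumed)
        ground_truth="The capital of Japan is Tokyo.",  # reference text (kwarg name assumed)
    )
    print(result)  # expected: a dict containing an F1-style score

Two of the added files are reproduced in full below.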
azure/ai/evaluation/simulator/_model_tools/models.py (new file, +616 lines)
@@ -0,0 +1,616 @@
+ # ---------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # ---------------------------------------------------------
+ # pylint: skip-file
+ import ast
+ import asyncio
+ import copy
+ import logging
+ import time
+ import uuid
+ from abc import ABC, abstractmethod
+ from collections import deque
+ from typing import Deque, Dict, List, Optional, Union
+ from urllib.parse import urlparse
+ import ast
+
+ from azure.ai.evaluation._http_utils import AsyncHttpPipeline
+ from azure.ai.evaluation._exceptions import EvaluationException, ErrorBlame, ErrorCategory, ErrorTarget
+
+ from ._identity_manager import APITokenManager
+
+ MIN_ERRORS_TO_FAIL = 3
+ MAX_TIME_TAKEN_RECORDS = 20_000
+
+
+ def get_model_class_from_url(endpoint_url: str):
+     """Convert an endpoint URL to the appropriate model class."""
+     endpoint_path = urlparse(endpoint_url).path  # remove query params
+
+     if endpoint_path.endswith("chat/completions"):
+         return OpenAIChatCompletionsModel
+     elif endpoint_path.endswith("completions"):
+         return OpenAICompletionsModel
+     else:
+         raise EvaluationException(
+             message=f"Unknown API type for endpoint {endpoint_url}",
+             internal_message="Unknown API type",
+             error_category=ErrorCategory.UNKNOWN_FIELD,
+             error_blame=ErrorBlame.USER_ERROR,
+             error_target=ErrorTarget.MODELS,
+         )
+
+
+ # ===========================================================
+ # ===================== LLMBase Class =======================
+ # ===========================================================
+
+
+ class LLMBase(ABC):
+     """
+     Base class for all LLM models.
+     """
+
+     def __init__(self, endpoint_url: str, name: str = "unknown", additional_headers: Optional[dict] = {}):
+         self.endpoint_url = endpoint_url
+         self.name = name
+         self.additional_headers = additional_headers
+         self.logger = logging.getLogger(repr(self))
+
+         # Metric tracking
+         self._lock = None
+         self.response_times: Deque[Union[int, float]] = deque(maxlen=MAX_TIME_TAKEN_RECORDS)
+         self.step = 0
+         self.error_count = 0
+
+     @property
+     async def lock(self):
+         if self._lock is None:
+             self._lock = asyncio.Lock()
+         return self._lock
+
+     @abstractmethod
+     def get_model_params(self) -> dict:
+         pass
+
+     @abstractmethod
+     def format_request_data(self, prompt: str, **request_params) -> dict:
+         pass
+
+     async def get_completion(
+         self,
+         prompt: str,
+         session: AsyncHttpPipeline,
+         **request_params,
+     ) -> dict:
+         """
+         Query the model a single time with a prompt.
+
+         Parameters
+         ----------
+         prompt: Prompt str to query model with.
+         session: AsyncHttpPipeline object to use for the request.
+         **request_params: Additional parameters to pass to the request.
+         """
+         request_data = self.format_request_data(prompt, **request_params)
+         return await self.request_api(
+             session=session,
+             request_data=request_data,
+         )
+
+     @abstractmethod
+     async def get_all_completions(
+         self,
+         prompts: List[str],
+         session: AsyncHttpPipeline,
+         api_call_max_parallel_count: int,
+         api_call_delay_seconds: float,
+         request_error_rate_threshold: float,
+         **request_params,
+     ) -> List[dict]:
+         pass
+
+     @abstractmethod
+     async def request_api(
+         self,
+         session: AsyncHttpPipeline,
+         request_data: dict,
+     ) -> dict:
+         pass
+
+     @abstractmethod
+     async def get_conversation_completion(
+         self,
+         messages: List[dict],
+         session: AsyncHttpPipeline,
+         role: str,
+         **request_params,
+     ) -> dict:
+         pass
+
+     @abstractmethod
+     async def request_api_parallel(
+         self,
+         request_datas: List[dict],
+         output_collector: List,
+         session: AsyncHttpPipeline,
+         api_call_delay_seconds: float,
+         request_error_rate_threshold: float,
+     ) -> None:
+         pass
+
+     def _log_request(self, request: dict) -> None:
+         self.logger.info(f"Request: {request}")
+
+     async def _add_successful_response(self, time_taken: Union[int, float]) -> None:
+         async with self.lock:
+             self.response_times.append(time_taken)
+             self.step += 1
+
+     async def _add_error(self) -> None:
+         async with self.lock:
+             self.error_count += 1
+             self.step += 1
+
+     async def get_response_count(self) -> int:
+         async with self.lock:
+             return len(self.response_times)
+
+     async def get_response_times(self) -> List[float]:
+         async with self.lock:
+             return list(self.response_times)
+
+     async def get_average_response_time(self) -> float:
+         async with self.lock:
+             return sum(self.response_times) / len(self.response_times)
+
+     async def get_error_rate(self) -> float:
+         async with self.lock:
+             return self.error_count / self.step
+
+     async def get_error_count(self) -> int:
+         async with self.lock:
+             return self.error_count
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}(name={self.name})"
+
+
+ # ===========================================================
+ # ================== OpenAICompletions ======================
+ # ===========================================================
+
+
+ class OpenAICompletionsModel(LLMBase):
+     """
+     Object for calling a Completions-style API for OpenAI models.
+     """
+
+     prompt_idx_key = "__prompt_idx__"
+
+     max_stop_tokens = 4
+     stop_tokens = ["<|im_end|>", "<|endoftext|>"]
+
+     model_param_names = [
+         "model",
+         "temperature",
+         "max_tokens",
+         "top_p",
+         "n",
+         "frequency_penalty",
+         "presence_penalty",
+         "stop",
+     ]
+
+     CHAT_START_TOKEN = "<|im_start|>"
+     CHAT_END_TOKEN = "<|im_end|>"
+
+     def __init__(
+         self,
+         *,
+         endpoint_url: str,
+         name: str = "OpenAICompletionsModel",
+         additional_headers: Optional[dict] = {},
+         api_version: Optional[str] = "2023-03-15-preview",
+         token_manager: APITokenManager,
+         azureml_model_deployment: Optional[str] = None,
+         model: Optional[str] = None,
+         temperature: Optional[float] = 0.7,
+         max_tokens: Optional[int] = 300,
+         top_p: Optional[float] = None,  # Recommended to use top_p or temp, not both
+         n: Optional[int] = 1,
+         frequency_penalty: Optional[float] = 0,
+         presence_penalty: Optional[float] = 0,
+         stop: Optional[Union[List[str], str]] = None,
+         image_captions: Dict[str, str] = {},
+         images_dir: Optional[str] = None,  # Note: unused, kept for class compatibility
+     ):
+         super().__init__(endpoint_url=endpoint_url, name=name, additional_headers=additional_headers)
+         self.api_version = api_version
+         self.token_manager = token_manager
+         self.azureml_model_deployment = azureml_model_deployment
+         self.model = model
+         self.temperature = temperature
+         self.max_tokens = max_tokens
+         self.top_p = top_p
+         self.n = n
+         self.frequency_penalty = frequency_penalty
+         self.presence_penalty = presence_penalty
+         self.image_captions = image_captions
+
+         # Default stop to end token if not provided
+         if not stop:
+             stop = []
+         # Else if stop sequence is given as a string (Ex: "["\n", "<im_end>"]"), convert
+         elif type(stop) is str and stop.startswith("[") and stop.endswith("]"):
+             stop = ast.literal_eval(stop)
+         elif type(stop) is str:
+             stop = [stop]
+         self.stop: List = stop  # type: ignore[assignment]
+
+         # If stop tokens do not include default end tokens, add them
+         for token in self.stop_tokens:
+             if len(self.stop) >= self.max_stop_tokens:
+                 break
+             if token not in self.stop:
+                 self.stop.append(token)
+
+         if top_p not in [None, 1.0] and temperature is not None:
+             self.logger.warning(
+                 "Both top_p and temperature are set. OpenAI advises against using both at the same time."
+             )
+
+         self.logger.info(f"Default model settings: {self.get_model_params()}")
+
+     def get_model_params(self):
+         return {param: getattr(self, param) for param in self.model_param_names if getattr(self, param) is not None}
+
+     def format_request_data(self, prompt: str, **request_params) -> Dict[str, str]:
+         """
+         Format the request data for the OpenAI API.
+         """
+         request_data = {"prompt": prompt, **self.get_model_params()}
+         request_data.update(request_params)
+         return request_data
+
+     async def get_conversation_completion(
+         self,
+         messages: List[dict],
+         session: AsyncHttpPipeline,
+         role: str = "assistant",
+         **request_params,
+     ) -> dict:
+         """
+         Query the model a single time with a message.
+
+         Parameters
+         ----------
+         messages: List of messages to query the model with.
+             Expected format: [{"role": "user", "content": "Hello!"}, ...]
+         session: AsyncHttpPipeline object to query the model with.
+         role: Role of the user sending the message.
+         request_params: Additional parameters to pass to the model.
+         """
+         prompt = []
+         for message in messages:
+             prompt.append(f"{self.CHAT_START_TOKEN}{message['role']}\n{message['content']}\n{self.CHAT_END_TOKEN}\n")
+         prompt_string: str = "".join(prompt)
+         prompt_string += f"{self.CHAT_START_TOKEN}{role}\n"
+
+         return await self.get_completion(
+             prompt=prompt_string,
+             session=session,
+             **request_params,
+         )
+
+     async def get_all_completions(  # type: ignore[override]
+         self,
+         prompts: List[Dict[str, str]],
+         session: AsyncHttpPipeline,
+         api_call_max_parallel_count: int = 1,
+         api_call_delay_seconds: float = 0.1,
+         request_error_rate_threshold: float = 0.5,
+         **request_params,
+     ) -> List[dict]:
+         """
+         Run a batch of prompts through the model and return the results in the order given.
+
+         Parameters
+         ----------
+         prompts: List of prompts to query the model with.
+         session: AsyncHttpPipeline to use for the request.
+         api_call_max_parallel_count: Number of parallel requests to make to the API.
+         api_call_delay_seconds: Number of seconds to wait between API requests.
+         request_error_rate_threshold: Maximum error rate allowed before raising an error.
+         request_params: Additional parameters to pass to the API.
+         """
+         if api_call_max_parallel_count > 1:
+             self.logger.info(f"Using {api_call_max_parallel_count} parallel workers to query the API..")
+
+         # Format prompts and tag with index
+         request_datas: List[Dict] = []
+         for idx, prompt in enumerate(prompts):
+             prompt: Dict[str, str] = self.format_request_data(prompt, **request_params)
+             prompt[self.prompt_idx_key] = idx  # type: ignore[assignment]
+             request_datas.append(prompt)
+
+         # Perform inference
+         if len(prompts) == 0:
+             return []  # queue is empty
+
+         output_collector: List = []
+         tasks = [  # create a set of worker-tasks to query inference endpoint in parallel
+             asyncio.create_task(
+                 self.request_api_parallel(
+                     request_datas=request_datas,
+                     output_collector=output_collector,
+                     session=session,
+                     api_call_delay_seconds=api_call_delay_seconds,
+                     request_error_rate_threshold=request_error_rate_threshold,
+                 )
+             )
+             for _ in range(api_call_max_parallel_count)
+         ]
+
+         # Await the completion of all tasks, and propagate any exceptions
+         await asyncio.gather(*tasks, return_exceptions=False)
+         if len(request_datas):
+             msg = "All inference tasks were finished, but the queue is not empty"
+             raise EvaluationException(
+                 message=msg,
+                 internal_message=msg,
+                 target=ErrorTarget.MODELS,
+                 category=ErrorCategory.FAILED_EXECUTION,
+                 blame=ErrorBlame.UNKNOWN,
+             )
+
+         # Output results back to the caller
+         output_collector.sort(key=lambda x: x[self.prompt_idx_key])
+         for output in output_collector:
+             output.pop(self.prompt_idx_key)
+         return output_collector
+
+     async def request_api_parallel(
+         self,
+         request_datas: List[dict],
+         output_collector: List,
+         session: AsyncHttpPipeline,
+         api_call_delay_seconds: float = 0.1,
+         request_error_rate_threshold: float = 0.5,
+     ) -> None:
+         """
+         Query the model for all prompts given as a list and append the output to output_collector.
+         No return value, output_collector is modified in place.
+         """
+         logger_tasks: List = []  # to await for logging to finish
+
+         while True:  # process data from queue until it's empty
+             try:
+                 request_data = request_datas.pop()
+                 prompt_idx = request_data.pop(self.prompt_idx_key)
+
+                 try:
+                     response = await self.request_api(
+                         session=session,
+                         request_data=request_data,
+                     )
+                     await self._add_successful_response(response["time_taken"])
+                 except Exception as e:
+                     response = {
+                         "request": request_data,
+                         "response": {
+                             "finish_reason": "error",
+                             "error": str(e),
+                         },
+                     }
+                     await self._add_error()
+
+                     self.logger.exception(f"Errored on prompt #{prompt_idx}")
+
+                     # if we count too many errors, we stop and raise an exception
+                     response_count = await self.get_response_count()
+                     error_rate = await self.get_error_rate()
+                     if response_count >= MIN_ERRORS_TO_FAIL and error_rate >= request_error_rate_threshold:
+                         error_msg = (
+                             f"Error rate is more than {request_error_rate_threshold:.0%} -- something is broken!"
+                         )
+                         raise EvaluationException(
+                             message=error_msg,
+                             internal_message=error_msg,
+                             target=ErrorTarget.MODELS,
+                             category=ErrorCategory.FAILED_EXECUTION,
+                             blame=ErrorBlame.UNKNOWN,
+                         )
+
+                 response[self.prompt_idx_key] = prompt_idx
+                 output_collector.append(response)
+
+                 # Sleep between consecutive requests to avoid rate limit
+                 await asyncio.sleep(api_call_delay_seconds)
+
+             except IndexError:  # when the queue is empty, the worker is done
+                 # wait for logging tasks to finish
+                 await asyncio.gather(*logger_tasks)
+                 return
+
+     async def request_api(
+         self,
+         session: AsyncHttpPipeline,
+         request_data: dict,
+     ) -> dict:
+         """
+         Request the model with a body of data.
+
+         Parameters
+         ----------
+         session: HTTPS Session for invoking the endpoint.
+         request_data: Prompt dictionary to query the model with. (Pass {"prompt": prompt} instead of prompt.)
+         """
+
+         self._log_request(request_data)
+
+         token = await self.token_manager.get_token()
+
+         headers = {
+             "Content-Type": "application/json",
+             "X-CV": f"{uuid.uuid4()}",
+             "X-ModelType": self.model or "",
+         }
+
+         if self.token_manager.auth_header == "Bearer":
+             headers["Authorization"] = f"Bearer {token}"
+         elif self.token_manager.auth_header == "api-key":
+             headers["api-key"] = token
+             headers["Authorization"] = "api-key"
+
+         # Update timeout for proxy endpoint
+         if self.azureml_model_deployment:
+             headers["azureml-model-deployment"] = self.azureml_model_deployment
+
+         # add all additional headers
+         if self.additional_headers:
+             headers.update(self.additional_headers)
+
+         params = {}
+         if self.api_version:
+             params["api-version"] = self.api_version
+
+         time_start = time.time()
+         full_response = None
+
+         response = await session.post(url=self.endpoint_url, headers=headers, json=request_data, params=params)
+
+         response.raise_for_status()
+
+         response_data = response.json()
+
+         self.logger.info(f"Response: {response_data}")
+
+         # Copy the full response and return it to be saved in jsonl.
+         full_response = copy.copy(response_data)
+
+         time_taken = time.time() - time_start
+
+         parsed_response = self._parse_response(response_data, request_data=request_data)
+
+         return {
+             "request": request_data,
+             "response": parsed_response,
+             "time_taken": time_taken,
+             "full_response": full_response,
+         }
+
+     def _parse_response(self, response_data: dict, request_data: Optional[dict] = None) -> dict:
+         # https://platform.openai.com/docs/api-reference/completions
+         samples = []
+         finish_reason = []
+         for choice in response_data["choices"]:
+             if "text" in choice:
+                 samples.append(choice["text"])
+             if "finish_reason" in choice:
+                 finish_reason.append(choice["finish_reason"])
+
+         return {"samples": samples, "finish_reason": finish_reason, "id": response_data["id"]}
+
+
+ # ===========================================================
+ # ============== OpenAIChatCompletionsModel =================
+ # ===========================================================
+
+
+ class OpenAIChatCompletionsModel(OpenAICompletionsModel):
+     """
+     OpenAIChatCompletionsModel is a wrapper around OpenAICompletionsModel that
+     formats the prompt for chat completion.
+     """
+
+     def __init__(self, name="OpenAIChatCompletionsModel", *args, **kwargs):
+         super().__init__(name=name, *args, **kwargs)
+
+     def format_request_data(self, messages: List[dict], **request_params):  # type: ignore[override]
+         request_data = {"messages": messages, **self.get_model_params()}
+         request_data.update(request_params)
+         return request_data
+
+     async def get_conversation_completion(
+         self,
+         messages: List[dict],
+         session: AsyncHttpPipeline,
+         role: str = "assistant",
+         **request_params,
+     ) -> dict:
+         """
+         Query the model a single time with a message.
+
+         Parameters
+         ----------
+         messages: List of messages to query the model with.
+             Expected format: [{"role": "user", "content": "Hello!"}, ...]
+         session: AsyncHttpPipeline object to query the model with.
+         role: Not used for this model, since it is a chat model.
+         request_params: Additional parameters to pass to the model.
+         """
+         request_data = self.format_request_data(
+             messages=messages,
+             **request_params,
+         )
+         return await self.request_api(
+             session=session,
+             request_data=request_data,
+         )
+
+     async def get_completion(
+         self,
+         prompt: str,
+         session: AsyncHttpPipeline,
+         **request_params,
+     ) -> dict:
+         """
+         Query a ChatCompletions model with a single prompt. Note: entire message will be inserted into a "system" call.
+
+         Parameters
+         ----------
+         prompt: Prompt str to query model with.
+         session: AsyncHttpPipeline object to use for the request.
+         **request_params: Additional parameters to pass to the request.
+         """
+         messages = [{"role": "system", "content": prompt}]
+
+         request_data = self.format_request_data(messages=messages, **request_params)
+         return await self.request_api(
+             session=session,
+             request_data=request_data,
+         )
+
+     async def get_all_completions(
+         self,
+         prompts: List[str],  # type: ignore[override]
+         session: AsyncHttpPipeline,
+         api_call_max_parallel_count: int = 1,
+         api_call_delay_seconds: float = 0.1,
+         request_error_rate_threshold: float = 0.5,
+         **request_params,
+     ) -> List[dict]:
+         prompts_list = [{"role": "system", "content": prompt} for prompt in prompts]
+
+         return await super().get_all_completions(
+             prompts=prompts_list,
+             session=session,
+             api_call_max_parallel_count=api_call_max_parallel_count,
+             api_call_delay_seconds=api_call_delay_seconds,
+             request_error_rate_threshold=request_error_rate_threshold,
+             **request_params,
+         )
+
+     def _parse_response(self, response_data: dict, request_data: Optional[dict] = None) -> dict:
+         # https://platform.openai.com/docs/api-reference/chat
+         samples = []
+         finish_reason = []
+
+         for choice in response_data["choices"]:
+             if "message" in choice and "content" in choice["message"]:
+                 samples.append(choice["message"]["content"])
+             if "message" in choice and "finish_reason" in choice["message"]:
+                 finish_reason.append(choice["message"]["finish_reason"])
+
+         return {"samples": samples, "finish_reason": finish_reason, "id": response_data["id"]}
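
Taken together, get_model_class_from_url and the two model classes give the simulator's model tools one async interface over both completions and chat-completions endpoints: get_all_completions fans work out to request_api_parallel workers, which all funnel through request_api and share the error-rate accounting in LLMBase. The sketch below shows roughly how a caller might drive a single chat request; it assumes an APITokenManager and an AsyncHttpPipeline session are constructed elsewhere (their constructors are not part of this file), and the endpoint URL is a made-up example.

    # Rough usage sketch. The token manager and HTTP session are assumed to be
    # built elsewhere; APITokenManager and AsyncHttpPipeline internals are not
    # shown in this diff, and the endpoint URL below is hypothetical.
    from azure.ai.evaluation._http_utils import AsyncHttpPipeline
    from azure.ai.evaluation.simulator._model_tools._identity_manager import APITokenManager
    from azure.ai.evaluation.simulator._model_tools.models import get_model_class_from_url


    async def query_once(session: AsyncHttpPipeline, token_manager: APITokenManager) -> str:
        endpoint = "https://example.openai.azure.com/openai/deployments/gpt-4/chat/completions"
        # The ".../chat/completions" suffix routes this to OpenAIChatCompletionsModel.
        model_cls = get_model_class_from_url(endpoint)
        model = model_cls(endpoint_url=endpoint, token_manager=token_manager)

        result = await model.get_conversation_completion(
            messages=[{"role": "user", "content": "Hello!"}],
            session=session,
        )
        # request_api returns request / response / time_taken / full_response;
        # the parsed response carries the generated samples.
        return result["response"]["samples"][0]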
azure/ai/evaluation/simulator/_prompty/task_query_response.prompty (new file, +69 lines)
@@ -0,0 +1,69 @@
+ ---
+ name: TaskSimulatorQueryResponse
+ description: Gets queries and responses from a blob of text
+ model:
+   api: chat
+   configuration:
+     type: azure_openai
+     azure_deployment: ${env:AZURE_DEPLOYMENT}
+     api_key: ${env:AZURE_OPENAI_API_KEY}
+     azure_endpoint: ${env:AZURE_OPENAI_ENDPOINT}
+   parameters:
+     temperature: 0.0
+     top_p: 1.0
+     presence_penalty: 0
+     frequency_penalty: 0
+     response_format:
+       type: json_object
+
+ inputs:
+   text:
+     type: string
+   num_queries:
+     type: integer
+
+ ---
+ system:
+ You're an AI that helps in preparing a Question/Answer quiz from Text for "Who wants to be a millionaire" tv show
+ Both Questions and Answers MUST BE extracted from given Text
+ Frame Question in a way so that Answer is RELEVANT SHORT BITE-SIZED info from Text
+ RELEVANT info could be: NUMBER, DATE, STATISTIC, MONEY, NAME
+ A sentence should contribute multiple QnAs if it has more info in it
+ Answer must not be more than 5 words
+ Answer must be picked from Text as is
+ Question should be as descriptive as possible and must include as much context as possible from Text
+ Output must always have the provided number of QnAs
+ Output must be in JSON format
+ Text:
+ <|text_start|>
+ On January 24, 1984, former Apple CEO Steve Jobs introduced the first Macintosh. In late 2003, Apple had 2.06 percent of the desktop share in the United States.
+ Some years later, research firms IDC and Gartner reported that Apple's market share in the U.S. had increased to about 6%.
+ <|text_end|>
+ Output with 5 QnAs:
+ [
+   {
+     "q": "When did the former Apple CEO Steve Jobs introduced the first Macintosh?",
+     "r": "January 24, 1984"
+   },
+   {
+     "q": "Who was the former Apple CEO that introduced the first Macintosh on January 24, 1984?",
+     "r": "Steve Jobs"
+   },
+   {
+     "q": "What percent of the desktop share did Apple have in the United States in late 2003?",
+     "r": "2.06 percent"
+   },
+   {
+     "q": "What were the research firms that reported on Apple's market share in the U.S.?",
+     "r": "IDC and Gartner"
+   },
+   {
+     "q": "What was the percentage increase of Apple's market share in the U.S., as reported by research firms IDC and Gartner?",
+     "r": "6%"
+   }
+ ]
+ Text:
+ <|text_start|>
+ {{ text }}
+ <|text_end|>
+ Output with {{ num_queries }} QnAs:
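
The template asks the model to return its output as JSON: a list of objects with a "q" (query) and an "r" (response) field, exactly num_queries of them. How simulator.py consumes that output is not shown here; the snippet below only illustrates parsing text of the requested shape back into query/response pairs.

    import json

    # Example text in the shape the prompt above requests (values taken from
    # the few-shot example in the template).
    raw_model_output = """[
        {
            "q": "What percent of the desktop share did Apple have in the United States in late 2003?",
            "r": "2.06 percent"
        }
    ]"""

    for pair in json.loads(raw_model_output):
        print("query:", pair["q"])
        print("response:", pair["r"])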