eval_framework-0.2.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
eval_framework/external/ifeval_impl/utils.py
@@ -0,0 +1,135 @@
+# mypy: ignore-errors
+
+import dataclasses
+
+from eval_framework.external.ifeval_impl import instructions_registry
+
+
+@dataclasses.dataclass
+class InputExample:
+    key: int
+    instruction_id_list: list[str]
+    prompt: str
+    kwargs: list[dict[str, str | int | None]]
+
+
+@dataclasses.dataclass
+class OutputExample:
+    instruction_id_list: list[str]
+    prompt: str
+    response: str
+    follow_all_instructions: bool
+    follow_instruction_list: list[bool]
+
+
+def test_instruction_following_strict(
+    inp,
+    response,
+):
+    """Tests response to see if instructions are followed."""
+    instruction_list = inp.instruction_id_list
+    is_following_list = []
+
+    for index, instruction_id in enumerate(instruction_list):
+        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
+        instruction = instruction_cls(instruction_id)
+
+        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
+        kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
+        instruction.build_description(**kwargs)
+        args = instruction.get_instruction_args()
+        if args and "prompt" in args:
+            instruction.build_description(prompt=inp.prompt)
+
+        if response.strip() and instruction.check_following(response):
+            is_following_list.append(True)
+        else:
+            is_following_list.append(False)
+
+    return OutputExample(
+        instruction_id_list=inp.instruction_id_list,
+        prompt=inp.prompt,
+        response=response,
+        follow_all_instructions=all(is_following_list),
+        follow_instruction_list=is_following_list,
+    )
+
+
+def test_instruction_following_loose(
+    inp,
+    response,
+):
+    """Tests response for an upper bound for following instructions."""
+    r = response.split("\n")
+    response_remove_first = "\n".join(r[1:]).strip()
+    response_remove_last = "\n".join(r[:-1]).strip()
+    response_remove_both = "\n".join(r[1:-1]).strip()
+    revised_response = response.replace("*", "")
+    revised_response_remove_first = response_remove_first.replace("*", "")
+    revised_response_remove_last = response_remove_last.replace("*", "")
+    revised_response_remove_both = response_remove_both.replace("*", "")
+    all_responses = [
+        response,
+        revised_response,
+        response_remove_first,
+        response_remove_last,
+        response_remove_both,
+        revised_response_remove_first,
+        revised_response_remove_last,
+        revised_response_remove_both,
+    ]
+    instruction_list = inp.instruction_id_list
+    is_following_list = []
+
+    for index, instruction_id in enumerate(instruction_list):
+        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
+        instruction = instruction_cls(instruction_id)
+
+        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
+        kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
+        instruction.build_description(**kwargs)
+        args = instruction.get_instruction_args()
+        if args and "prompt" in args:
+            instruction.build_description(prompt=inp.prompt)
+
+        is_following = False
+        for r in all_responses:
+            if r.strip() and instruction.check_following(r):
+                is_following = True
+                break
+
+        is_following_list.append(is_following)
+
+    return OutputExample(
+        instruction_id_list=inp.instruction_id_list,
+        prompt=inp.prompt,
+        response=response,
+        follow_all_instructions=all(is_following_list),
+        follow_instruction_list=is_following_list,
+    )
+
+
+def process_results(input, results):
+    response = results[0]
+    input_example = InputExample(
+        key=input.key,
+        instruction_id_list=input.instruction_id_list,
+        prompt=input.prompt,
+        kwargs=input.additional_kwargs,
+    )
+
+    out_strict = test_instruction_following_strict(input_example, response)
+    out_loose = test_instruction_following_loose(input_example, response)
+
+    return {
+        "prompt_level_strict_acc": out_strict.follow_all_instructions,
+        "inst_level_strict_acc": out_strict.follow_instruction_list,
+        "prompt_level_loose_acc": out_loose.follow_all_instructions,
+        "inst_level_loose_acc": out_loose.follow_instruction_list,
+    }
+
+
+def agg_inst_level_acc(items):
+    flat_items = [item for sublist in items for item in sublist]
+    inst_level_acc = sum(flat_items) / len(flat_items)
+    return inst_level_acc
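
Note on usage: `process_results` above is the IFEval scoring entry point. It takes the benchmark sample (with `key`, `instruction_id_list`, `prompt`, and `additional_kwargs`) plus the list of model completions, and returns strict and loose accuracy at both the prompt and the instruction level; the loose variant can only score higher, since it re-checks copies of the response with leading/trailing lines and asterisks stripped. A minimal sketch of driving it directly, assuming the upstream IFEval instruction id `detectable_format:number_bullet_lists` and its `num_bullets` argument are registered in `instructions_registry`; `_DemoInput` is a hypothetical stand-in for the framework's sample object, not a class from this package:

import dataclasses

from eval_framework.external.ifeval_impl.utils import process_results


@dataclasses.dataclass
class _DemoInput:
    # Hypothetical stand-in: field names mirror what process_results reads.
    key: int
    instruction_id_list: list[str]
    prompt: str
    additional_kwargs: list[dict[str, str | int | None]]


sample = _DemoInput(
    key=0,
    instruction_id_list=["detectable_format:number_bullet_lists"],
    prompt="List exactly three fruits as bullet points.",
    additional_kwargs=[{"num_bullets": 3}],
)
scores = process_results(sample, ["* apple\n* pear\n* plum"])
# scores["prompt_level_strict_acc"] is True only if every instruction is satisfied verbatim.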
eval_framework/llm/__init__.py
File without changes
eval_framework/llm/aleph_alpha.py
@@ -0,0 +1,432 @@
+import asyncio
+import json
+import logging
+import math
+import os
+import random
+import re
+import time
+import traceback
+from collections.abc import Callable, Sequence
+
+import aiohttp
+from aleph_alpha_client import (
+    AsyncClient,
+    BusyError,
+    Client,
+    CompletionRequest,
+    CompletionResponse,
+    Prompt,
+)
+from aleph_alpha_client.prompt import Text
+from dotenv import load_dotenv
+
+from eval_framework.llm.base import BaseLLM
+from eval_framework.shared.types import Error, PromptTooLongException, RawCompletion, RawLoglikelihood
+from eval_framework.tasks.base import Sample
+from eval_framework.tasks.utils import raise_errors
+from template_formatting.formatter import BaseFormatter, Llama3Formatter, Message
+
+load_dotenv()
+
+logger = logging.getLogger(__name__)
+
+
+def safe_json_loads(s: str) -> dict[str, str]:
+    try:
+        return json.loads(s)
+    except (json.JSONDecodeError, TypeError):
+        return {}
+
+
+class AlephAlphaAPIModel(BaseLLM):
+    LLM_NAME: str
+    DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
+    BYTES_PER_TOKEN: float = 4.0  # rule of thumb according to https://platform.openai.com/tokenizer
+
+    def __init__(
+        self,
+        formatter: BaseFormatter | None = None,
+        checkpoint_name: str | None = None,
+        temperature: float | None = None,
+        # Please see README.md for tips if adapting the following parameters.
+        max_retries: int = 100,
+        max_async_concurrent_requests: int = 32,
+        request_timeout_seconds: int = 30 * 60 + 5,
+        queue_full_timeout_seconds: int = 30 * 60 + 5,
+        bytes_per_token: float | None = None,
+        token: str = os.getenv("AA_TOKEN", "dummy"),
+        base_url: str = os.getenv("AA_INFERENCE_ENDPOINT", "dummy_endpoint"),
+    ) -> None:
+        self._formatter: BaseFormatter
+        if formatter is None:
+            if self.DEFAULT_FORMATTER is None:
+                raise ValueError("Either formatter or default formatter must be specified")
+            self._formatter = self.DEFAULT_FORMATTER()
+        else:
+            self._formatter = formatter
+        self._llm_name = checkpoint_name or self.LLM_NAME
+        self._temperature = temperature if temperature is not None else 0.0
+        self.max_async_concurrent_requests = max_async_concurrent_requests
+        self.max_retries = max_retries
+        self.request_timeout_seconds = request_timeout_seconds
+        self.queue_full_timeout_seconds = queue_full_timeout_seconds
+        self.token = token
+        self.base_url = base_url
+        self._validate_model_availability(base_url, token)
+        # set bytes_per_token_scalar for non-standard models
+        if bytes_per_token is not None and bytes_per_token <= 0:
+            raise ValueError("bytes_per_token must be positive")
+        self.bytes_per_token_scalar = (
+            4.0 / bytes_per_token if bytes_per_token is not None else 4.0 / self.BYTES_PER_TOKEN
+        )
+
+    def _validate_model_availability(self, base_url: str, token: str) -> None:
+        """
+        Validate that the model name is available by making a test request.
+        """
+        try:
+            # 'Client' object does not support the context manager protocol
+            client = Client(
+                host=base_url,
+                token=token,
+            )
+
+            request = CompletionRequest(
+                prompt=Prompt.from_text(""),
+                maximum_tokens=1,
+            )
+            client.complete(request, model=self._llm_name)
+            logger.info(f"Model '{self._llm_name}' available and loaded.")
+        except Exception as e:
+            raise RuntimeError(f"Model '{self._llm_name}' is not available: {e}")
+
+    async def _request_with_backoff(
+        self, client: AsyncClient, request: CompletionRequest, id: int
+    ) -> CompletionResponse:
+        """
+        Query Aleph-Alpha API with complete. Retry with back-off until it responds.
+        """
+        num_attempts = 0
+        start_time: float | None = None
+
+        while True:
+            try:
+                return await client.complete(request, model=self._llm_name)
+
+            except (TimeoutError, BusyError, RuntimeError, aiohttp.ClientError) as e:
+                status_code: str = safe_json_loads(e.args[1]).get("code", "") if len(e.args) >= 2 else ""
+                str_e = str(e)
+                if status_code == "QUEUE_FULL":
+                    # Worker not available or missed a heartbeat (inference longer than scheduler's
+                    # API_MODEL_AVAILABLE_TIMEOUT_DURATION_MILLIS) or the scheduler is overloaded.
+                    if start_time is None:
+                        start_time = time.time()
+                    elapsed = time.time() - start_time
+                    if elapsed <= self.queue_full_timeout_seconds:
+                        logger.info(
+                            f"Request {id}: {status_code or str_e[:256]} - retrying: attempt"
+                            f" {num_attempts}/{self.max_retries}, elapsed {elapsed:.1f} sec"
+                        )
+                        # don't count as retry (request returns immediately, so just wait a bit not to DoS the server)
+                        await asyncio.sleep(random.randint(5, 30))
+                        continue
+
+                elif (
+                    status_code == "TIMEOUT_TASK"
+                    or isinstance(e, TimeoutError)
+                    or "502 Bad Gateway" in str_e
+                    or "504 Gateway Time-out" in str_e
+                    or isinstance(e, aiohttp.ClientError)
+                ):
+                    # client timeout, either because task too long in a queue or inference too long
+                    # (scheduler's API_CLIENT_TIMEOUT_DURATION_MILLIS). Retrying for the "inference too long"
+                    # case makes no sense but we unfortunately don't know which case has happened.
+                    num_attempts += 1
+                    start_time = None
+                    if num_attempts < self.max_retries:
+                        logger.info(f"Request {id}: TIMEOUT_TASK - retrying: attempt {num_attempts}/{self.max_retries}")
+                        await asyncio.sleep(random.randint(5, 30))
+                        continue
+
+                raise e
+
+    def _error_from_exception(self, e: Exception) -> Error:
+        """Convert an exception to an Error object."""
+        if len(e.args) >= 2:
+            status_code: str = safe_json_loads(e.args[1]).get("code", "")
+            if status_code == "PROMPT_TOO_LONG":
+                return Error(
+                    error_class=PromptTooLongException.__name__,
+                    message="Prompt exceeded context size.",
+                    traceback=traceback.format_exc(),
+                )
+            else:
+                return Error(
+                    error_class=status_code or e.__class__.__name__, message=str(e), traceback=traceback.format_exc()
+                )
+        else:
+            return Error(error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc())
+
+    async def _process_request_with_client(
+        self,
+        client: AsyncClient,
+        semaphore: asyncio.Semaphore,
+        request: CompletionRequest,
+        id: int,
+    ) -> tuple[CompletionRequest, CompletionResponse | Error]:
+        """Process a single request, returning the request and either a response or error."""
+        async with semaphore:
+            try:
+                response = await self._request_with_backoff(client=client, request=request, id=id)
+                logger.info(f"Request {id}: Success")
+                return (request, response)
+            except Exception as e:
+                if raise_errors():
+                    raise e
+                logger.info(f"Request {id}: Failure: {str(e)[:256]}")
+                return (request, self._error_from_exception(e))
+
+    async def _process_requests(
+        self,
+        requests: list[CompletionRequest],
+    ) -> list[tuple[CompletionRequest, CompletionResponse | Error]]:
+        """Process multiple requests concurrently, returning request/response pairs."""
+        semaphore = asyncio.Semaphore(self.max_async_concurrent_requests)
+        async with AsyncClient(
+            host=self.base_url,
+            nice=True,
+            request_timeout_seconds=self.request_timeout_seconds,
+            token=self.token,
+            total_retries=0,  # we have a custom retry policy in _request_with_backoff()
+        ) as client:
+            tasks = (
+                self._process_request_with_client(
+                    client,
+                    semaphore,
+                    request,
+                    i,
+                )
+                for i, request in enumerate(requests)
+            )
+            responses = await asyncio.gather(*tasks)  # guarantees order of responses
+            return list(responses)
+
+    def _response_to_raw_completion(
+        self, request: CompletionRequest, response: CompletionResponse | Error
+    ) -> RawCompletion:
+        """Convert a request/response pair to a RawCompletion."""
+        assert isinstance(request.prompt.items[0], Text)
+        prompt = request.prompt.items[0].text
+
+        if isinstance(response, Error):
+            return RawCompletion(
+                prompt=prompt,
+                prompt_sequence_positions=None,
+                completion="",
+                completion_sequence_positions=0,
+                raw_completion_error=response,
+            )
+
+        assert len(response.completions) == 1
+        completion = response.completions[0].completion or ""
+        prompt_sequence_positions: int | None = None
+        completion_sequence_positions: int | None = None
+
+        # Support workaround in api-worker-transformer's scaling generator to return the correct number of tokens.
+        # These are part of the completion string; those in CompletionResponse are invalid in this case.
+        m = re.match(r"\uf8c9(\d+),(\d+)\uf8c9(.*)", completion, re.DOTALL)
+        if m is not None:
+            num_input_tokens, num_completion_tokens, completion = m.groups()
+            prompt_sequence_positions = int(num_input_tokens)
+            completion_sequence_positions = int(num_completion_tokens)
+        else:
+            prompt_sequence_positions = response.num_tokens_prompt_total if response else None
+            completion_sequence_positions = response.num_tokens_generated if response else None
+
+        return RawCompletion(
+            prompt=prompt,
+            prompt_sequence_positions=prompt_sequence_positions,
+            completion=completion,
+            completion_sequence_positions=completion_sequence_positions,
+        )
+
+    def generate_from_messages(
+        self,
+        messages: list[Sequence[Message]],
+        stop_sequences: list[str] | None = None,
+        max_tokens: int | None = None,
+        temperature: float | None = None,
+    ) -> list[RawCompletion]:
+        effective_temperature = temperature if temperature is not None else self._temperature
+
+        requests: list[CompletionRequest] = []
+
+        # Adjust max tokens based on bytes_per_token_scalar so that non-standard models generate full responses
+        scaled_max_tokens = math.ceil(max_tokens * self.bytes_per_token_scalar) if max_tokens is not None else None
+
+        for single_messages in messages:
+            requests.append(
+                CompletionRequest(
+                    prompt=Prompt.from_text(self._formatter.format(single_messages, output_mode="string")),
+                    maximum_tokens=scaled_max_tokens,
+                    stop_sequences=stop_sequences,
+                    temperature=effective_temperature,
+                )
+            )
+
+        responses = asyncio.run(self._process_requests(requests))
+        return [self._response_to_raw_completion(req, resp) for req, resp in responses]
+
+    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
+        prompts: list[str] = []
+        completion_requests: list[CompletionRequest] = []
+
+        for sample in samples:
+            prompt: str = self._formatter.format(sample.messages, output_mode="string") if sample.messages else ""
+            prompts.append(prompt)
+            for choice in sample.possible_completions or []:
+                completion_requests.append(
+                    CompletionRequest(
+                        prompt=Prompt.from_text(prompt + choice),
+                        maximum_tokens=0,
+                        temperature=0.0,
+                        log_probs=0,
+                        echo=True,
+                        tokens=True,
+                    )
+                )
+
+        completion_responses: list[tuple[CompletionRequest, CompletionResponse | Error]] = []
+        if completion_requests:
+            completion_responses = asyncio.run(self._process_requests(completion_requests))
+        completion_iter = iter(completion_responses)
+
+        results: list[RawLoglikelihood] = []
+        for sample_idx, (sample, prompt) in enumerate(zip(samples, prompts, strict=True)):
+            choices_log_probs: dict[str, float] = {}
+            choices_sequence_positions: dict[str, int] = {}
+            prompt_sequence_positions: int | None = 0
+            number_of_initial_choices_tokens: int | None = None
+            error: Error | None = None
+
+            for choice in sample.possible_completions or []:
+                request, response = next(completion_iter)
+                assert isinstance(request, CompletionRequest)
+                if error is not None:
+                    continue
+
+                if isinstance(response, Error):
+                    error = response
+                    prompt_sequence_positions = None
+                    choices_log_probs = {}
+                    choices_sequence_positions = {}
+                else:
+                    try:
+                        logprob, choice_token_count = self._extract_choice_logprob_from_completion(
+                            prompt=prompt,
+                            choice=choice,
+                            response=response,
+                        )
+                        choices_log_probs[choice] = logprob
+                        choices_sequence_positions[choice] = choice_token_count
+                        if number_of_initial_choices_tokens is None:
+                            number_of_initial_choices_tokens = choice_token_count
+
+                        self._check_choices_token_count(
+                            sample_idx, choice_token_count, number_of_initial_choices_tokens
+                        )
+
+                    except Exception as exc:
+                        if raise_errors():
+                            raise
+                        error = Error(
+                            error_class=exc.__class__.__name__,
+                            message=str(exc),
+                            traceback=traceback.format_exc(),
+                        )
+                        prompt_sequence_positions = None
+                        choices_log_probs = {}
+                        choices_sequence_positions = {}
+
+            results.append(
+                RawLoglikelihood(
+                    prompt=prompt,
+                    prompt_sequence_positions=prompt_sequence_positions,
+                    loglikelihoods=choices_log_probs,
+                    loglikelihoods_sequence_positions=choices_sequence_positions,
+                    raw_loglikelihood_error=error,
+                )
+            )
+
+        return results
+
+    @staticmethod
+    def _check_choices_token_count(
+        sample_idx: int, choice_token_count: int, number_of_initial_choices_tokens: int | None
+    ) -> None:
+        if number_of_initial_choices_tokens is not None:
+            if choice_token_count != number_of_initial_choices_tokens:
+                logger.warning(
+                    "Choice token count differed between choices for sample %s (%s vs %s). Using latest value.",
+                    sample_idx,
+                    choice_token_count,
+                    number_of_initial_choices_tokens,
+                )
+
+    @staticmethod
+    def _extract_choice_logprob_from_completion(
+        prompt: str, choice: str, response: CompletionResponse
+    ) -> tuple[float, int]:
+        if not response.completions:
+            raise ValueError("Completion response did not contain any choices.")
+        completion_result = response.completions[0]
+        if completion_result.log_probs is None:
+            raise ValueError("Completion result did not include log_probs.")
+        if completion_result.completion_tokens is None:
+            raise ValueError("Completion result did not include completion_tokens.")
+
+        tokens = list(completion_result.completion_tokens)
+        log_prob_entries = list(completion_result.log_probs)
+
+        if len(tokens) != len(log_prob_entries):
+            raise ValueError("Mismatch between completion tokens and log_prob entries.")
+
+        combined_text = "".join(tokens)
+        expected_text = prompt + choice
+        if combined_text != expected_text:
+            raise ValueError("Completion tokens differed from prompt + choice text.")
+
+        prompt_token_count = AlephAlphaAPIModel._count_prompt_tokens_from_sequence(tokens, prompt)
+        choice_token_count = len(tokens) - prompt_token_count
+        if choice_token_count < 0:
+            raise ValueError("Choice token count computed as negative.")
+
+        total_logprob = 0.0
+        for entry in log_prob_entries[prompt_token_count:]:
+            assert isinstance(entry, dict)
+            if len(entry) != 1:
+                raise ValueError("Log_probs entry did not contain exactly one key-value pair.")
+            _, value = entry.popitem()
+            assert isinstance(value, float)
+            total_logprob += value
+
+        return total_logprob, choice_token_count
+
+    @staticmethod
+    def _count_prompt_tokens_from_sequence(tokens: Sequence[str], prompt: str) -> int:
+        if not prompt:
+            return 0
+        current_text = ""
+        for idx, token in enumerate(tokens):
+            current_text += token
+            if current_text == prompt:
+                return idx + 1
+            if len(current_text) > len(prompt):
+                break
+        raise ValueError("Unable to align completion tokens with prompt text.")
+
+
+class Llama31_8B_Instruct_API(AlephAlphaAPIModel):
+    LLM_NAME = "llama-3.1-8b-instruct"
+    DEFAULT_FORMATTER = Llama3Formatter
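
Note on usage: `AlephAlphaAPIModel` wraps the Aleph Alpha completion API with bounded concurrency (an asyncio semaphore), a custom retry policy (QUEUE_FULL waits do not count against `max_retries`; timeouts do), and a bytes-per-token scalar that inflates `max_tokens` for models whose tokenizers deviate from the assumed 4 bytes per token. A minimal sketch of calling the concrete subclass, assuming `AA_TOKEN` and `AA_INFERENCE_ENDPOINT` are set in the environment; the `Message(role=..., content=...)` constructor and `Role` enum are assumptions to be checked against `template_formatting/formatter.py`:

from eval_framework.llm.aleph_alpha import Llama31_8B_Instruct_API
from template_formatting.formatter import Message, Role  # Role/constructor shape assumed

# Token and endpoint default to AA_TOKEN / AA_INFERENCE_ENDPOINT; construction
# issues a one-token test completion to verify the model is actually loaded.
model = Llama31_8B_Instruct_API(temperature=0.0)

# One conversation per inner list; responses come back in request order.
completions = model.generate_from_messages(
    [[Message(role=Role.user, content="Name the capital of Finland.")]],
    max_tokens=16,
)
print(completions[0].completion)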