eval-framework 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161) hide show
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +170 -0
  5. eval_framework/context/eval.py +114 -0
  6. eval_framework/context/local.py +52 -0
  7. eval_framework/evaluation_generator.py +231 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +323 -0
  16. eval_framework/llm/base.py +58 -0
  17. eval_framework/llm/huggingface.py +332 -0
  18. eval_framework/llm/mistral.py +73 -0
  19. eval_framework/llm/models.py +16 -0
  20. eval_framework/llm/openai.py +205 -0
  21. eval_framework/llm/vllm.py +438 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +187 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/bleu.py +76 -0
  29. eval_framework/metrics/completion/chrf.py +62 -0
  30. eval_framework/metrics/completion/code_assertion.py +44 -0
  31. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  32. eval_framework/metrics/completion/comet.py +56 -0
  33. eval_framework/metrics/completion/concordance_index.py +38 -0
  34. eval_framework/metrics/completion/csv_format.py +102 -0
  35. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  36. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  37. eval_framework/metrics/completion/f1.py +42 -0
  38. eval_framework/metrics/completion/format_checker.py +56 -0
  39. eval_framework/metrics/completion/grid_difference.py +77 -0
  40. eval_framework/metrics/completion/ifeval.py +73 -0
  41. eval_framework/metrics/completion/json_format.py +171 -0
  42. eval_framework/metrics/completion/language_checker.py +74 -0
  43. eval_framework/metrics/completion/length_control.py +83 -0
  44. eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
  45. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  46. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  47. eval_framework/metrics/completion/repetition.py +88 -0
  48. eval_framework/metrics/completion/rouge_1.py +35 -0
  49. eval_framework/metrics/completion/rouge_2.py +45 -0
  50. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  51. eval_framework/metrics/completion/rouge_l.py +52 -0
  52. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  53. eval_framework/metrics/completion/ter.py +67 -0
  54. eval_framework/metrics/completion/text_counter.py +182 -0
  55. eval_framework/metrics/efficiency/__init__.py +0 -0
  56. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  57. eval_framework/metrics/llm/__init__.py +0 -0
  58. eval_framework/metrics/llm/base.py +8 -0
  59. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  60. eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
  61. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  62. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  63. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  64. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  65. eval_framework/metrics/llm/graders/language.py +56 -0
  66. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  67. eval_framework/metrics/llm/graders/models.py +74 -0
  68. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  69. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  70. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  71. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  72. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  73. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  74. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  75. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  76. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  77. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
  78. eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
  79. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  80. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  81. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  82. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  83. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  84. eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
  85. eval_framework/py.typed +0 -0
  86. eval_framework/response_generator.py +416 -0
  87. eval_framework/result_processors/__init__.py +0 -0
  88. eval_framework/result_processors/base.py +74 -0
  89. eval_framework/result_processors/hf_processor.py +87 -0
  90. eval_framework/result_processors/result_processor.py +129 -0
  91. eval_framework/run.py +314 -0
  92. eval_framework/run_direct.py +42 -0
  93. eval_framework/shared/types.py +227 -0
  94. eval_framework/tasks/__init__.py +6 -0
  95. eval_framework/tasks/base.py +314 -0
  96. eval_framework/tasks/benchmarks/__init__.py +0 -0
  97. eval_framework/tasks/benchmarks/arc.py +46 -0
  98. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  99. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  100. eval_framework/tasks/benchmarks/belebele.py +60 -0
  101. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  102. eval_framework/tasks/benchmarks/casehold.py +47 -0
  103. eval_framework/tasks/benchmarks/chembench.py +85 -0
  104. eval_framework/tasks/benchmarks/copa.py +39 -0
  105. eval_framework/tasks/benchmarks/duc.py +91 -0
  106. eval_framework/tasks/benchmarks/flores200.py +62 -0
  107. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  108. eval_framework/tasks/benchmarks/gpqa.py +177 -0
  109. eval_framework/tasks/benchmarks/gsm8k.py +148 -0
  110. eval_framework/tasks/benchmarks/hellaswag.py +44 -0
  111. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  112. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  113. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  114. eval_framework/tasks/benchmarks/include.py +119 -0
  115. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  116. eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
  117. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  118. eval_framework/tasks/benchmarks/mmlu.py +190 -0
  119. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  120. eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
  121. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  122. eval_framework/tasks/benchmarks/openbookqa.py +37 -0
  123. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  124. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  125. eval_framework/tasks/benchmarks/piqa.py +39 -0
  126. eval_framework/tasks/benchmarks/quality.py +56 -0
  127. eval_framework/tasks/benchmarks/sciq.py +44 -0
  128. eval_framework/tasks/benchmarks/sphyr.py +75 -0
  129. eval_framework/tasks/benchmarks/squad.py +89 -0
  130. eval_framework/tasks/benchmarks/struct_eval.py +110 -0
  131. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  132. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  133. eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
  134. eval_framework/tasks/benchmarks/winogender.py +39 -0
  135. eval_framework/tasks/benchmarks/winogrande.py +44 -0
  136. eval_framework/tasks/benchmarks/winox.py +57 -0
  137. eval_framework/tasks/benchmarks/wmt.py +160 -0
  138. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  139. eval_framework/tasks/eval_config.py +112 -0
  140. eval_framework/tasks/perturbation.py +83 -0
  141. eval_framework/tasks/registry.py +186 -0
  142. eval_framework/tasks/task_loader.py +80 -0
  143. eval_framework/tasks/task_names.py +138 -0
  144. eval_framework/tasks/utils.py +578 -0
  145. eval_framework/utils/constants.py +9 -0
  146. eval_framework/utils/generate_task_docs.py +229 -0
  147. eval_framework/utils/helpers.py +3 -0
  148. eval_framework/utils/logging.py +50 -0
  149. eval_framework/utils/packaging.py +52 -0
  150. eval_framework-0.2.0.dist-info/METADATA +514 -0
  151. eval_framework-0.2.0.dist-info/RECORD +161 -0
  152. eval_framework-0.2.0.dist-info/WHEEL +4 -0
  153. eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
  154. template_formatting/README.md +83 -0
  155. template_formatting/__init__.py +0 -0
  156. template_formatting/formatter.py +536 -0
  157. template_formatting/mistral_formatter.py +159 -0
  158. template_formatting/py.typed +0 -0
  159. template_formatting/tests/test_formatter_eval.py +408 -0
  160. template_formatting/tests/test_formatter_scaling.py +253 -0
  161. template_formatting/tests/test_mistral_formatter.py +136 -0
@@ -0,0 +1,135 @@
1
+ # mypy: ignore-errors
2
+
3
+ import dataclasses
4
+
5
+ from eval_framework.external.ifeval_impl import instructions_registry
6
+
7
+
8
+ @dataclasses.dataclass
9
+ class InputExample:
10
+ key: int
11
+ instruction_id_list: list[str]
12
+ prompt: str
13
+ kwargs: list[dict[str, str | int | None]]
14
+
15
+
16
@dataclasses.dataclass()
class OutputExample:
    """Outcome of checking one response against the instructions of its prompt."""

    # Ids of the instructions that were checked.
    instruction_id_list: list[str]
    # The prompt the response belongs to.
    prompt: str
    # The response that was checked.
    response: str
    # True only if every entry of ``follow_instruction_list`` is True.
    follow_all_instructions: bool
    # Per-instruction verdicts, index-aligned with ``instruction_id_list``.
    follow_instruction_list: list[bool]
23
+
24
+
25
def test_instruction_following_strict(
    inp,
    response,
):
    """Tests response to see if instructions are followed.

    Args:
        inp: Example exposing ``instruction_id_list``, ``prompt`` and ``kwargs`` (one argument dict
            per instruction, index-aligned with ``instruction_id_list``).
        response: The model response to check, exactly as generated.

    Returns:
        OutputExample holding the per-instruction verdicts and the overall verdict. An empty or
        whitespace-only response fails every instruction.

    Raises:
        KeyError: If an instruction id is not present in the instruction registry.
    """
    # Evaluated once: a blank response can never satisfy an instruction.
    has_content = bool(response.strip())
    is_following_list = []

    for index, instruction_id in enumerate(inp.instruction_id_list):
        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
        instruction = instruction_cls(instruction_id)

        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
        # Only None is dropped: falsy but meaningful arguments (0, "", False) must still reach the instruction.
        kwargs = {k: v for k, v in inp.kwargs[index].items() if v is not None}
        instruction.build_description(**kwargs)
        args = instruction.get_instruction_args()
        if args and "prompt" in args:
            # Instructions that depend on the prompt are rebuilt with the prompt text.
            instruction.build_description(prompt=inp.prompt)

        is_following_list.append(bool(has_content and instruction.check_following(response)))

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(is_following_list),
        follow_instruction_list=is_following_list,
    )
56
+
57
+
58
def test_instruction_following_loose(
    inp,
    response,
):
    """Tests response for an upper bound for following instructions.

    An instruction counts as followed if any relaxed variant of the response passes the check. The
    variants are the response itself, the response without its first line, without its last line,
    without both, and each of those with all ``*`` characters removed.

    Args:
        inp: Example exposing ``instruction_id_list``, ``prompt`` and ``kwargs`` (one argument dict
            per instruction, index-aligned with ``instruction_id_list``).
        response: The model response to check, exactly as generated.

    Returns:
        OutputExample holding the per-instruction verdicts and the overall verdict.

    Raises:
        KeyError: If an instruction id is not present in the instruction registry.
    """
    lines = response.split("\n")
    response_remove_first = "\n".join(lines[1:]).strip()
    response_remove_last = "\n".join(lines[:-1]).strip()
    response_remove_both = "\n".join(lines[1:-1]).strip()
    all_responses = [
        response,
        response.replace("*", ""),
        response_remove_first,
        response_remove_last,
        response_remove_both,
        response_remove_first.replace("*", ""),
        response_remove_last.replace("*", ""),
        response_remove_both.replace("*", ""),
    ]
    # Blank variants can never satisfy an instruction, so they are filtered out once up front.
    candidates = [candidate for candidate in all_responses if candidate.strip()]
    is_following_list = []

    for index, instruction_id in enumerate(inp.instruction_id_list):
        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
        instruction = instruction_cls(instruction_id)

        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
        # Only None is dropped: falsy but meaningful arguments (0, "", False) must still reach the instruction.
        kwargs = {k: v for k, v in inp.kwargs[index].items() if v is not None}
        instruction.build_description(**kwargs)
        args = instruction.get_instruction_args()
        if args and "prompt" in args:
            # Instructions that depend on the prompt are rebuilt with the prompt text.
            instruction.build_description(prompt=inp.prompt)

        # Stops at the first variant that passes, in the same order as ``all_responses``.
        is_following_list.append(any(instruction.check_following(candidate) for candidate in candidates))

    return OutputExample(
        instruction_id_list=inp.instruction_id_list,
        prompt=inp.prompt,
        response=response,
        follow_all_instructions=all(is_following_list),
        follow_instruction_list=is_following_list,
    )
110
+
111
+
112
def process_results(input, results):
    """Score the first result of ``results`` with the strict and the loose instruction checks.

    Args:
        input: Object exposing ``key``, ``instruction_id_list``, ``prompt`` and ``additional_kwargs``.
        results: Sequence of responses; only the first entry is evaluated.

    Returns:
        Dict with prompt-level (bool) and instruction-level (list of bool) verdicts for both modes.
    """
    completion = results[0]
    example = InputExample(
        key=input.key,
        instruction_id_list=input.instruction_id_list,
        prompt=input.prompt,
        kwargs=input.additional_kwargs,
    )

    strict_outcome = test_instruction_following_strict(example, completion)
    loose_outcome = test_instruction_following_loose(example, completion)

    metrics = {}
    metrics["prompt_level_strict_acc"] = strict_outcome.follow_all_instructions
    metrics["inst_level_strict_acc"] = strict_outcome.follow_instruction_list
    metrics["prompt_level_loose_acc"] = loose_outcome.follow_all_instructions
    metrics["inst_level_loose_acc"] = loose_outcome.follow_instruction_list
    return metrics
130
+
131
+
132
def agg_inst_level_acc(items):
    """Aggregate per-prompt instruction verdicts into one instruction-level accuracy.

    Args:
        items: Iterable of verdict lists (one list of booleans per prompt).

    Returns:
        Fraction of all verdicts that are truthy, or 0.0 when there are no verdicts at all
        (instead of raising ZeroDivisionError).
    """
    flat_items = [item for sublist in items for item in sublist]
    if not flat_items:
        return 0.0
    return sum(flat_items) / len(flat_items)
File without changes
@@ -0,0 +1,323 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import os
5
+ import random
6
+ import re
7
+ import time
8
+ import traceback
9
+ from collections.abc import Callable, Sequence
10
+
11
+ import aiohttp
12
+ from aleph_alpha_client import (
13
+ AsyncClient,
14
+ BusyError,
15
+ Client,
16
+ CompletionRequest,
17
+ CompletionResponse,
18
+ EvaluationRequest,
19
+ EvaluationResponse,
20
+ Prompt,
21
+ )
22
+ from aleph_alpha_client.prompt import Text
23
+ from dotenv import load_dotenv
24
+
25
+ from eval_framework.llm.base import BaseLLM
26
+ from eval_framework.shared.types import Error, PromptTooLongException, RawCompletion, RawLoglikelihood
27
+ from eval_framework.tasks.base import Sample
28
+ from eval_framework.tasks.utils import raise_errors
29
+ from template_formatting.formatter import BaseFormatter, Llama3Formatter, Message
30
+
31
+ load_dotenv()
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
def safe_json_loads(s: str) -> dict:
    """Parse ``s`` as a JSON object, falling back to an empty dict.

    Args:
        s: Text expected to hold a JSON object. Other input types are tolerated.

    Returns:
        The decoded dict, or ``{}`` when ``s`` cannot be decoded or decodes to something that is
        not a JSON object (list, string, number, ...). Callers can therefore always use ``.get``.
    """
    try:
        parsed = json.loads(s)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError (and undecodable bytes); TypeError covers non-text input.
        return {}
    return parsed if isinstance(parsed, dict) else {}
41
+
42
+
43
class AlephAlphaAPIModel(BaseLLM):
    """LLM backend that obtains completions and log-likelihoods through the Aleph Alpha client.

    Subclasses provide ``LLM_NAME`` and, optionally, ``DEFAULT_FORMATTER``. The endpoint and token are read
    from the ``AA_INFERENCE_ENDPOINT`` and ``AA_TOKEN`` environment variables.
    """

    LLM_NAME: str
    DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None

    def __init__(
        self,
        formatter: BaseFormatter | None = None,
        checkpoint_name: str | None = None,
        # Please see README.md for tips if adapting the following parameters.
        max_retries: int = 100,
        max_async_concurrent_requests: int = 32,
        request_timeout_seconds: int = 30 * 60 + 5,
        queue_full_timeout_seconds: int = 30 * 60 + 5,
    ) -> None:
        """Configure the model and verify that it can be reached.

        Args:
            formatter: Formatter turning messages into a prompt string; defaults to ``DEFAULT_FORMATTER()``.
            checkpoint_name: Model name sent to the API; defaults to ``LLM_NAME``.
            max_retries: Upper bound on attempts for timeout-like failures of a single request.
            max_async_concurrent_requests: Number of requests allowed in flight at the same time.
            request_timeout_seconds: Client-side timeout passed to the async client.
            queue_full_timeout_seconds: How long consecutive ``QUEUE_FULL`` replies are retried.

        Raises:
            ValueError: If neither ``formatter`` nor ``DEFAULT_FORMATTER`` is available.
            RuntimeError: If the availability check request fails.
        """
        self._formatter: BaseFormatter
        if formatter is None:
            if self.DEFAULT_FORMATTER is None:
                raise ValueError("Either formatter or default formatter must be specified")
            self._formatter = self.DEFAULT_FORMATTER()
        else:
            self._formatter = formatter
        self._llm_name = checkpoint_name or self.LLM_NAME
        self.max_async_concurrent_requests = max_async_concurrent_requests
        self.max_retries = max_retries
        self.request_timeout_seconds = request_timeout_seconds
        self.queue_full_timeout_seconds = queue_full_timeout_seconds
        self._validate_model_availability()

    @staticmethod
    def _error_status_code(exc: BaseException) -> str:
        """Return the ``code`` field of the JSON payload in ``exc.args[1]``, or ``""`` if there is none."""
        if len(exc.args) < 2:
            return ""
        payload = safe_json_loads(exc.args[1])
        if not isinstance(payload, dict):
            # Valid JSON that is not an object (e.g. a list) carries no status code.
            return ""
        code = payload.get("code", "")
        return code if isinstance(code, str) else ""

    def _validate_model_availability(self) -> None:
        """
        Validate that the model name is available by making a test request.

        Raises:
            RuntimeError: If the test request fails for any reason; the original error is chained.
        """
        try:
            # 'Client' object does not support the context manager protocol
            client = Client(
                host=os.getenv("AA_INFERENCE_ENDPOINT", "dummy_endpoint"),
                token=os.getenv("AA_TOKEN", "dummy"),
            )

            request = CompletionRequest(
                prompt=Prompt.from_text(""),
                maximum_tokens=1,
            )
            client.complete(request, model=self._llm_name)
            logger.info(f"Model '{self._llm_name}' available and loaded.")
        except Exception as e:
            raise RuntimeError(f"Model '{self._llm_name}' is not available: {e}") from e

    async def _request_with_backoff(
        self, client: AsyncClient, request: CompletionRequest | EvaluationRequest, id: int
    ) -> CompletionResponse | EvaluationResponse:
        """
        Query Aleph-Alpha API with complete. Retry with back-off until it responds.

        ``QUEUE_FULL`` replies are retried for up to ``queue_full_timeout_seconds`` without counting as an
        attempt. Timeout-like failures are retried until ``max_retries`` attempts were made. Any other
        failure, or an exhausted budget, re-raises the last error.

        Raises:
            ValueError: If ``request`` is neither a completion nor an evaluation request.
        """
        num_attempts = 0
        start_time: float | None = None

        while True:
            try:
                if isinstance(request, CompletionRequest):
                    return await client.complete(request, model=self._llm_name)
                elif isinstance(request, EvaluationRequest):
                    return await client.evaluate(request, model=self._llm_name)
                else:
                    raise ValueError(f"Unsupported request type: {type(request)}")

            except (TimeoutError, BusyError, RuntimeError, aiohttp.ClientError) as e:
                status_code = self._error_status_code(e)
                str_e = str(e)
                if status_code == "QUEUE_FULL":
                    # Worker not available or missed a heartbeat (inference longer than scheduler's
                    # API_MODEL_AVAILABLE_TIMEOUT_DURATION_MILLIS) or the scheduler is overloaded.
                    if start_time is None:
                        start_time = time.time()
                    elapsed = time.time() - start_time
                    if elapsed <= self.queue_full_timeout_seconds:
                        logger.info(
                            f"Request {id}: {status_code or str_e[:256]} - retrying: attempt"
                            f" {num_attempts}/{self.max_retries}, elapsed {elapsed:.1f} sec"
                        )
                        # don't count as retry (request returns immediately, so just wait a bit not to DoS the server)
                        await asyncio.sleep(random.randint(5, 30))
                        continue

                elif (
                    status_code == "TIMEOUT_TASK"
                    or isinstance(e, TimeoutError)
                    or "502 Bad Gateway" in str_e
                    or "504 Gateway Time-out" in str_e
                    or isinstance(e, aiohttp.ClientError)
                ):
                    # client timeout, either because task too long in a queue or inference too long
                    # (scheduler's API_CLIENT_TIMEOUT_DURATION_MILLIS). Retrying for the "inference too long"
                    # case makes no sense but we unfortunately don't know which case has happened.
                    num_attempts += 1
                    # A non-QUEUE_FULL failure ends the current QUEUE_FULL streak.
                    start_time = None
                    if num_attempts < self.max_retries:
                        logger.info(f"Request {id}: TIMEOUT_TASK - retrying: attempt {num_attempts}/{self.max_retries}")
                        await asyncio.sleep(random.randint(5, 30))
                        continue

                # Not retryable (or retry budget exhausted): propagate with the original traceback.
                raise

    async def _process_request_with_client(
        self,
        client: AsyncClient,
        semaphore: asyncio.Semaphore,
        request: CompletionRequest | EvaluationRequest,
        id: int,
    ) -> RawCompletion | tuple[EvaluationRequest, EvaluationResponse | Error]:
        """Run one request under the concurrency limit and convert its outcome.

        Returns:
            For completion requests a ``RawCompletion`` (carrying an ``Error`` on failure). For
            evaluation requests a ``(request, response_or_error)`` tuple that is assembled later.

        Raises:
            Exception: The request failure itself, but only when ``raise_errors()`` is enabled.
        """
        async with semaphore:
            try:
                response = await self._request_with_backoff(client=client, request=request, id=id)
                logger.info(f"Request {id}: Success")
            except Exception as e:
                if raise_errors():
                    raise
                logger.info(f"Request {id}: Failure: {str(e)[:256]}")
                status_code = self._error_status_code(e)
                if status_code == "PROMPT_TOO_LONG":
                    error = Error(
                        error_class=PromptTooLongException.__name__,
                        message="Prompt exceeded context size.",
                        traceback=traceback.format_exc(),
                    )
                else:
                    # Fall back to the exception class name so the error class is never empty.
                    error = Error(
                        error_class=status_code or e.__class__.__name__,
                        message=str(e),
                        traceback=traceback.format_exc(),
                    )

                if isinstance(request, CompletionRequest):
                    assert isinstance(request.prompt.items[0], Text)
                    return RawCompletion(
                        prompt=request.prompt.items[0].text,
                        prompt_sequence_positions=None,
                        completion="",
                        completion_sequence_positions=0,
                        raw_completion_error=error,
                    )
                else:
                    return (request, error)

        # Completion responses can directly be converted to RawCompletion
        if isinstance(request, CompletionRequest):
            assert isinstance(request.prompt.items[0], Text) and isinstance(response, CompletionResponse)
            assert len(response.completions) == 1
            prompt = request.prompt.items[0].text
            completion = response.completions[0].completion or ""
            prompt_sequence_positions: int | None = None
            completion_sequence_positions: int | None = None

            # Support workaround in api-worker-transformer's scaling generator to return the correct number of tokens.
            # These are part of the completion string; those in CompletionResponse are invalid in this case.
            m = re.match(r"\uf8c9(\d+),(\d+)\uf8c9(.*)", completion, re.DOTALL)
            if m is not None:
                num_input_tokens, num_completion_tokens, completion = m.groups()
                prompt_sequence_positions = int(num_input_tokens)
                completion_sequence_positions = int(num_completion_tokens)
            else:
                prompt_sequence_positions = response.num_tokens_prompt_total if response else None
                completion_sequence_positions = response.num_tokens_generated if response else None

            return RawCompletion(
                prompt=prompt,
                prompt_sequence_positions=prompt_sequence_positions,
                completion=completion,
                completion_sequence_positions=completion_sequence_positions,
            )

        # Evaluation responses must be assembled from individual choice requests later
        else:
            assert isinstance(response, EvaluationResponse)
            return (request, response)

    async def _process_requests(
        self, requests: list[CompletionRequest] | list[EvaluationRequest]
    ) -> list[RawCompletion | tuple[EvaluationRequest, EvaluationResponse | Error]]:
        """Send all ``requests`` through one shared client and return the results in request order."""
        semaphore = asyncio.Semaphore(self.max_async_concurrent_requests)
        async with AsyncClient(
            host=os.getenv("AA_INFERENCE_ENDPOINT", "dummy_endpoint"),
            nice=True,
            request_timeout_seconds=self.request_timeout_seconds,
            token=os.getenv("AA_TOKEN", "dummy"),
            total_retries=0,  # we have a custom retry policy in _request_with_backoff()
        ) as client:
            tasks = [
                self._process_request_with_client(client, semaphore, request, i)
                for i, request in enumerate(requests)
            ]
            responses = await asyncio.gather(*tasks)  # guarantees order of responses
        return responses

    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Generate one completion per conversation in ``messages``.

        Args:
            messages: Conversations; each is formatted to a single prompt string by the formatter.
            stop_sequences: Stop sequences forwarded to the API.
            max_tokens: Maximum number of tokens to generate, forwarded to the API.
            temperature: Sampling temperature; ``None`` selects the default of 0.0.

        Returns:
            One ``RawCompletion`` per conversation, in input order.
        """
        if temperature is None:
            effective_temperature = 0.0  # Current default, TODO: refactor to use model's default
            logger.info(
                f"Using default temperature value: {effective_temperature} as no custom temperature value was provided"
            )
        else:
            effective_temperature = temperature

        requests = [
            CompletionRequest(
                prompt=Prompt.from_text(self._formatter.format(single_messages, output_mode="string")),
                maximum_tokens=max_tokens,
                stop_sequences=stop_sequences,
                temperature=effective_temperature,
            )
            for single_messages in messages
        ]

        responses = asyncio.run(self._process_requests(requests))
        return responses  # type: ignore

    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """Compute the log-likelihood of every possible completion of every sample.

        Each choice is sent as its own evaluation request. A failure of any choice marks the whole
        sample as failed and clears its partial results.

        Returns:
            One ``RawLoglikelihood`` per sample, in input order.
        """
        samples_prompt: list[str] = []
        evaluation_requests: list[EvaluationRequest] = []
        results: list[RawLoglikelihood] = []

        # evaluate all choices independently in flattened list
        for sample in samples:
            prompt: str = self._formatter.format(sample.messages, output_mode="string") if sample.messages else ""
            samples_prompt.append(prompt)
            for choice in sample.possible_completions or []:
                evaluation_requests.append(
                    EvaluationRequest(prompt=Prompt.from_text(prompt), completion_expected=choice)
                )

        evaluation_responses = asyncio.run(self._process_requests(evaluation_requests))
        evaluation_responses_iter = iter(evaluation_responses)

        # assemble choices to RawLoglikelihood from a flattened list for all possible choice replies
        for sample, prompt in zip(samples, samples_prompt, strict=True):
            choices_log_probs: dict[str, float] = {}
            choices_sequence_positions: dict[str, int] = {}
            prompt_sequence_positions: int | None = 0
            error = None

            for choice in sample.possible_completions or []:
                # Always consume the reply, even after an error, to stay aligned with the flattened list.
                request, response = next(evaluation_responses_iter)
                if error is not None:
                    continue
                if isinstance(response, Error):  # failure for one choice leads to failure of the whole sample
                    error = response
                    prompt_sequence_positions = None
                    choices_log_probs = {}
                    choices_sequence_positions = {}
                else:
                    assert isinstance(request, EvaluationRequest) and isinstance(response, EvaluationResponse)
                    assert isinstance(request.prompt.items[0], Text)
                    assert prompt == request.prompt.items[0].text, f"{prompt} != {request.prompt.items[0].text}"
                    assert choice == request.completion_expected, f"{choice} != {request.completion_expected}"
                    prompt_sequence_positions = response.num_tokens_prompt_total - response.result["token_count"]
                    choices_log_probs[choice] = response.result["log_probability"]
                    choices_sequence_positions[choice] = response.result["token_count"]

            results.append(
                RawLoglikelihood(
                    prompt=prompt,
                    prompt_sequence_positions=prompt_sequence_positions,
                    loglikelihoods=choices_log_probs,
                    loglikelihoods_sequence_positions=choices_sequence_positions,
                    raw_loglikelihood_error=error,
                )
            )

        return results
319
+
320
+
321
class Llama31_8B_Instruct_API(AlephAlphaAPIModel):
    """Aleph Alpha API model preset: ``llama-3.1-8b-instruct`` with the Llama 3 formatter as default."""

    DEFAULT_FORMATTER = Llama3Formatter
    LLM_NAME = "llama-3.1-8b-instruct"
@@ -0,0 +1,58 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import Sequence
3
+
4
+ from eval_framework.shared.types import RawCompletion, RawLoglikelihood
5
+ from eval_framework.tasks.base import Sample
6
+ from template_formatting.formatter import Message
7
+
8
+
9
class BaseLLM(ABC):
    """Abstract interface every evaluated model has to implement."""

    @property
    def name(self) -> str:
        """
        Name used for the results folder and to identify the eval results.

        Defaults to the class name. Subclasses should overwrite it with something more specific,
        e.g. the checkpoint name or the huggingface model name.
        """
        return type(self).__name__

    @abstractmethod
    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """
        Generate one completion per conversation in ``messages``.

        ``stop_sequences`` and ``max_tokens`` are injected by the task if they exist. Implementations
        should overwrite or extend them with the properties of the model. This includes, but is not
        limited to, the stop tokens of the evaluated checkpoint (e.g. <|eot_id|> for an instruction
        finetuned Llama3.1, <|endoftext|> for a pretrained Llama3.1).

        Implementations are expected to raise errors, which are caught and reported when running the
        eval. This explicitly includes sequence length issues: always raise an error if something
        impedes the expected completion of a task.

        Important! The completion is expected to be detokenized and to NOT contain special tokens.

        Returns: List[RawCompletion]
        """
        raise NotImplementedError

    @abstractmethod
    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """
        Compute log-likelihoods for the given samples.

        Implementations are expected to raise errors, which are caught and reported when running the
        eval. This explicitly includes sequence length issues: always raise an error if something
        prevents the expected completion of a task.
        """
        raise NotImplementedError

    def generate(
        self,
        samples: list[Sample],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Generate completions for ``samples`` by delegating their messages to ``generate_from_messages``."""
        conversations: list[Sequence[Message]] = []
        for sample in samples:
            conversations.append(sample.messages)
        return self.generate_from_messages(conversations, stop_sequences, max_tokens, temperature)
+ return self.generate_from_messages(messages, stop_sequences, max_tokens, temperature)