eval-framework 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161) hide show
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +170 -0
  5. eval_framework/context/eval.py +114 -0
  6. eval_framework/context/local.py +52 -0
  7. eval_framework/evaluation_generator.py +231 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +323 -0
  16. eval_framework/llm/base.py +58 -0
  17. eval_framework/llm/huggingface.py +332 -0
  18. eval_framework/llm/mistral.py +73 -0
  19. eval_framework/llm/models.py +16 -0
  20. eval_framework/llm/openai.py +205 -0
  21. eval_framework/llm/vllm.py +438 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +187 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/bleu.py +76 -0
  29. eval_framework/metrics/completion/chrf.py +62 -0
  30. eval_framework/metrics/completion/code_assertion.py +44 -0
  31. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  32. eval_framework/metrics/completion/comet.py +56 -0
  33. eval_framework/metrics/completion/concordance_index.py +38 -0
  34. eval_framework/metrics/completion/csv_format.py +102 -0
  35. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  36. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  37. eval_framework/metrics/completion/f1.py +42 -0
  38. eval_framework/metrics/completion/format_checker.py +56 -0
  39. eval_framework/metrics/completion/grid_difference.py +77 -0
  40. eval_framework/metrics/completion/ifeval.py +73 -0
  41. eval_framework/metrics/completion/json_format.py +171 -0
  42. eval_framework/metrics/completion/language_checker.py +74 -0
  43. eval_framework/metrics/completion/length_control.py +83 -0
  44. eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
  45. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  46. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  47. eval_framework/metrics/completion/repetition.py +88 -0
  48. eval_framework/metrics/completion/rouge_1.py +35 -0
  49. eval_framework/metrics/completion/rouge_2.py +45 -0
  50. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  51. eval_framework/metrics/completion/rouge_l.py +52 -0
  52. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  53. eval_framework/metrics/completion/ter.py +67 -0
  54. eval_framework/metrics/completion/text_counter.py +182 -0
  55. eval_framework/metrics/efficiency/__init__.py +0 -0
  56. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  57. eval_framework/metrics/llm/__init__.py +0 -0
  58. eval_framework/metrics/llm/base.py +8 -0
  59. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  60. eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
  61. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  62. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  63. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  64. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  65. eval_framework/metrics/llm/graders/language.py +56 -0
  66. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  67. eval_framework/metrics/llm/graders/models.py +74 -0
  68. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  69. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  70. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  71. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  72. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  73. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  74. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  75. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  76. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  77. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
  78. eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
  79. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  80. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  81. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  82. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  83. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  84. eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
  85. eval_framework/py.typed +0 -0
  86. eval_framework/response_generator.py +416 -0
  87. eval_framework/result_processors/__init__.py +0 -0
  88. eval_framework/result_processors/base.py +74 -0
  89. eval_framework/result_processors/hf_processor.py +87 -0
  90. eval_framework/result_processors/result_processor.py +129 -0
  91. eval_framework/run.py +314 -0
  92. eval_framework/run_direct.py +42 -0
  93. eval_framework/shared/types.py +227 -0
  94. eval_framework/tasks/__init__.py +6 -0
  95. eval_framework/tasks/base.py +314 -0
  96. eval_framework/tasks/benchmarks/__init__.py +0 -0
  97. eval_framework/tasks/benchmarks/arc.py +46 -0
  98. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  99. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  100. eval_framework/tasks/benchmarks/belebele.py +60 -0
  101. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  102. eval_framework/tasks/benchmarks/casehold.py +47 -0
  103. eval_framework/tasks/benchmarks/chembench.py +85 -0
  104. eval_framework/tasks/benchmarks/copa.py +39 -0
  105. eval_framework/tasks/benchmarks/duc.py +91 -0
  106. eval_framework/tasks/benchmarks/flores200.py +62 -0
  107. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  108. eval_framework/tasks/benchmarks/gpqa.py +177 -0
  109. eval_framework/tasks/benchmarks/gsm8k.py +148 -0
  110. eval_framework/tasks/benchmarks/hellaswag.py +44 -0
  111. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  112. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  113. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  114. eval_framework/tasks/benchmarks/include.py +119 -0
  115. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  116. eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
  117. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  118. eval_framework/tasks/benchmarks/mmlu.py +190 -0
  119. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  120. eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
  121. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  122. eval_framework/tasks/benchmarks/openbookqa.py +37 -0
  123. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  124. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  125. eval_framework/tasks/benchmarks/piqa.py +39 -0
  126. eval_framework/tasks/benchmarks/quality.py +56 -0
  127. eval_framework/tasks/benchmarks/sciq.py +44 -0
  128. eval_framework/tasks/benchmarks/sphyr.py +75 -0
  129. eval_framework/tasks/benchmarks/squad.py +89 -0
  130. eval_framework/tasks/benchmarks/struct_eval.py +110 -0
  131. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  132. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  133. eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
  134. eval_framework/tasks/benchmarks/winogender.py +39 -0
  135. eval_framework/tasks/benchmarks/winogrande.py +44 -0
  136. eval_framework/tasks/benchmarks/winox.py +57 -0
  137. eval_framework/tasks/benchmarks/wmt.py +160 -0
  138. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  139. eval_framework/tasks/eval_config.py +112 -0
  140. eval_framework/tasks/perturbation.py +83 -0
  141. eval_framework/tasks/registry.py +186 -0
  142. eval_framework/tasks/task_loader.py +80 -0
  143. eval_framework/tasks/task_names.py +138 -0
  144. eval_framework/tasks/utils.py +578 -0
  145. eval_framework/utils/constants.py +9 -0
  146. eval_framework/utils/generate_task_docs.py +229 -0
  147. eval_framework/utils/helpers.py +3 -0
  148. eval_framework/utils/logging.py +50 -0
  149. eval_framework/utils/packaging.py +52 -0
  150. eval_framework-0.2.0.dist-info/METADATA +514 -0
  151. eval_framework-0.2.0.dist-info/RECORD +161 -0
  152. eval_framework-0.2.0.dist-info/WHEEL +4 -0
  153. eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
  154. template_formatting/README.md +83 -0
  155. template_formatting/__init__.py +0 -0
  156. template_formatting/formatter.py +536 -0
  157. template_formatting/mistral_formatter.py +159 -0
  158. template_formatting/py.typed +0 -0
  159. template_formatting/tests/test_formatter_eval.py +408 -0
  160. template_formatting/tests/test_formatter_scaling.py +253 -0
  161. template_formatting/tests/test_mistral_formatter.py +136 -0
@@ -0,0 +1,438 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from collections.abc import Callable, Sequence
4
+ from dataclasses import dataclass
5
+ from functools import partial
6
+ from typing import Any, Literal, Protocol, cast, override
7
+
8
+ import torch
9
+ from vllm import LLM, SamplingParams
10
+ from vllm.inputs.data import TokensPrompt
11
+ from vllm.outputs import RequestOutput
12
+ from vllm.transformers_utils.tokenizer import get_tokenizer
13
+
14
+ from eval_framework.llm.base import BaseLLM
15
+ from eval_framework.shared.types import (
16
+ ConcatCompression,
17
+ Error,
18
+ PromptTooLongException,
19
+ RawCompletion,
20
+ RawLoglikelihood,
21
+ )
22
+ from eval_framework.tasks.base import Sample
23
+ from eval_framework.tasks.utils import raise_errors
24
+ from eval_framework.utils.constants import RED, RESET
25
+ from template_formatting.formatter import BaseFormatter, HFFormatter, Message
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
@dataclass
class TokenizedContainer:
    """
    Container object to store tokens and formatted prompt
    """

    # Token IDs produced by encoding `text`.
    tokens: list[int]
    # The text that was tokenized.
    text: str
38
+
39
+
40
class VLLMTokenizerAPI[prompt_type: (list[Message], str)](ABC):
    """
    Abstract base class for the tokenizer interface that defines required methods.
    Needed for type checking because of the vllm tokenizer.

    The type parameter is the prompt representation accepted by
    `encode_formatted_struct`: either a list of messages or a plain string.
    """

    @abstractmethod
    def encode_formatted_struct(self, struct: prompt_type) -> TokenizedContainer:
        """Encode prompt to token IDs."""
        pass

    @abstractmethod
    def encode_plain_text(self, text: str) -> TokenizedContainer:
        """Encode a plain text string to token IDs."""
        pass

    @property
    def chat_template(self) -> str | None:
        """Chat template of the tokenizer; this base implementation returns None."""
        return None
58
+
59
+
60
class HFTokenizerProtocol(Protocol):
    """Structural type for the tokenizer members this module uses: encode, decode and chat_template."""

    def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
        """Encode text to token IDs."""
        ...

    def decode(self, tokens: list[int]) -> str:
        """Decode token IDs to text."""
        ...

    @property
    def chat_template(self) -> str | None:
        """Chat template for the tokenizer."""
        ...
73
+
74
+
75
class VLLMTokenizer(VLLMTokenizerAPI[str]):
    """Tokenizer for string prompts, backed by the tokenizer `get_tokenizer` returns for a model."""

    def __init__(self, target_mdl: str) -> None:
        """Load the tokenizer identified by `target_mdl`."""
        loaded = get_tokenizer(target_mdl)
        self.tokenizer = cast(HFTokenizerProtocol, loaded)

    def _encode_text(self, text: str) -> TokenizedContainer:
        """Tokenize `text` without adding special tokens and keep the text alongside the IDs."""
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        return TokenizedContainer(text=text, tokens=token_ids)

    def encode_formatted_struct(self, struct: str) -> TokenizedContainer:
        """Encode an already formatted prompt string."""
        return self._encode_text(struct)

    def encode_plain_text(self, text: str) -> TokenizedContainer:
        """Encode a plain text string."""
        return self._encode_text(text)

    def decode(self, tokens: list[int]) -> str:
        """Turn token IDs back into text using the wrapped tokenizer."""
        decoded = self.tokenizer.decode(tokens)
        return decoded

    @override
    @property
    def chat_template(self) -> str | None:
        """Chat template exposed by the wrapped tokenizer, if any."""
        return self.tokenizer.chat_template
96
+
97
+
98
class VLLMModel(BaseLLM):
    """LLM wrapper that runs completions and log-likelihood scoring through a vLLM engine.

    Subclasses set `LLM_NAME` and may set `DEFAULT_FORMATTER` and `SEQ_LENGTH`.
    """

    LLM_NAME: str
    DEFAULT_FORMATTER: Callable[[], BaseFormatter] | None = None
    SEQ_LENGTH: int | None = None

    def __init__(
        self,
        formatter: BaseFormatter | None = None,
        max_model_len: int | None = None,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        batch_size: int = 1,
        checkpoint_path: str | None = None,
        checkpoint_name: str | None = None,
        sampling_params: SamplingParams | dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the vLLM engine, resolve the default sampling params and pick a formatter.

        Args:
            formatter: Prompt formatter; when None a default is chosen by `_set_formatter`.
            max_model_len: Context length passed to vLLM; falls back to `SEQ_LENGTH`.
            tensor_parallel_size: Forwarded to vLLM.
            gpu_memory_utilization: Forwarded to vLLM.
            batch_size: Forwarded to vLLM as `max_num_seqs`.
            checkpoint_path: Model location used instead of `LLM_NAME` when given.
            checkpoint_name: Suffix used in `name` to identify the checkpoint.
            sampling_params: Default sampling params, as an object or a kwargs dict.
            **kwargs: Extra arguments forwarded to the vLLM `LLM` constructor.
        """
        # Store the max_model_len for later use
        self._max_model_len = max_model_len
        self.checkpoint_name = checkpoint_name
        self.checkpoint_path = checkpoint_path

        model_args = {
            "model": self.checkpoint_path or self.LLM_NAME,
            "max_model_len": max_model_len or self.SEQ_LENGTH,
            "max_num_seqs": batch_size,
            "tensor_parallel_size": tensor_parallel_size,
            "gpu_memory_utilization": gpu_memory_utilization,
            **kwargs,
        }

        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        self.batch_size = batch_size
        self._tokenizer: None | VLLMTokenizerAPI = None

        self.model = LLM(**model_args, device=device)

        self.sampling_params: SamplingParams = self._process_sampling_params(sampling_params)

        logger.info(f"{RED}[ Model initialized --------------------- {RESET}{self.LLM_NAME} {RED}]{RESET}")
        self._set_formatter(formatter)

    def _process_sampling_params(self, sampling_params: SamplingParams | dict[str, Any] | None) -> SamplingParams:
        """Normalize the constructor's `sampling_params` argument into a `SamplingParams` object.

        A dict is expanded into `SamplingParams(**dict)`, an existing object is used as is,
        and None falls back to the engine's default sampling params.
        """
        if isinstance(sampling_params, dict):
            processed_sampling_params = SamplingParams(**sampling_params)
            logger.info(f"Converted sampling_params dict to SamplingParams: {processed_sampling_params}")
            return processed_sampling_params
        if sampling_params is not None:
            return sampling_params
        return self.model.get_default_sampling_params()

    def _set_formatter(self, formatter: BaseFormatter | None = None) -> None:
        """Select the prompt formatter.

        Priority: the explicit `formatter`, then `DEFAULT_FORMATTER`, then an `HFFormatter`
        when the tokenizer has a chat template.

        Raises:
            ValueError: If none of the three sources yields a formatter.
        """
        if formatter is not None:
            self._formatter = formatter
        elif self.DEFAULT_FORMATTER is not None:
            self._formatter = self.DEFAULT_FORMATTER()
        elif self.tokenizer.chat_template is not None:
            self._formatter = HFFormatter(self.LLM_NAME)
        else:
            raise ValueError("No formatter specified and no default formatter available.")

        logger.info(
            f"{RED}[ Using default formatter --------------------- {RESET}{self._formatter.__class__.__name__} {RED}]{RESET}"  # noqa: E501
        )

    @property
    def tokenizer(self) -> VLLMTokenizerAPI:
        """Tokenizer for `LLM_NAME`, created on first access and then reused."""
        if self._tokenizer is None:
            self._tokenizer = VLLMTokenizer(target_mdl=self.LLM_NAME)
        return self._tokenizer

    def count_tokens(self, text: str, /) -> int:
        """Return the number of tokens `text` encodes to as plain text."""
        return len(self.tokenizer.encode_plain_text(text).tokens)

    @property
    def formatter_output_mode(self) -> Literal["string", "list"]:
        """Output mode requested from the formatter; this class always asks for a string."""
        return "string"

    @property
    def name(self) -> str:
        """Class name, suffixed with the checkpoint name when one was given."""
        if self.checkpoint_name:
            return f"{self.__class__.__name__}_checkpoint_{self.checkpoint_name}"
        return self.__class__.__name__

    def build_redis_key_from_prompt_objs(
        self, prompt_objs: list[TokenizedContainer], sampling_params: SamplingParams
    ) -> Any:
        """
        Build a redis key from a list of prompt objects and sampling parameters.
        TokenizedContainers are not serializable so we just pass the tokens and sampling params.
        """
        return ([obj.tokens for obj in prompt_objs], sampling_params)

    def generate_from_messages(
        self,
        messages: list[Sequence[Message]],
        stop_sequences: list[str] | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
    ) -> list[RawCompletion]:
        """Generate one completion per conversation in `messages`.

        Prompts that leave no room for generation get an error completion (or raise, when
        `raise_errors()` is true); the rest are sent to the engine in one batch.

        Args:
            messages: One message sequence per requested completion.
            stop_sequences: Stop strings applied to this call.
            max_tokens: Upper bound on generated tokens for this call.
            temperature: Temperature override for this call; None keeps the default.

        Returns:
            Completions in the same order as `messages`.

        Raises:
            PromptTooLongException: If a prompt is too long and `raise_errors()` is true.
        """
        raw_completions: list[RawCompletion | None] = [None] * len(messages)
        prompt_objs = []
        prompt_info = []

        sampling_params = self._resolve_sampling_params(self.sampling_params, max_tokens, stop_sequences, temperature)
        output_mode = self.formatter_output_mode

        for i, single_messages in enumerate(messages):
            prompt: str | list[Message] = self._formatter.format(single_messages, output_mode=output_mode)
            prompt_obj: TokenizedContainer = self.tokenizer.encode_formatted_struct(prompt)
            prompt_token_count = len(prompt_obj.tokens)

            # Only used to detect prompts that leave no room for at least one generated token.
            max_tokens_to_generate = self.max_seq_length - prompt_token_count

            if max_tokens is not None:
                max_tokens_to_generate = min(max_tokens_to_generate, max_tokens)

            if max_tokens_to_generate < 1:
                if raise_errors():
                    raise PromptTooLongException("Prompt exceeded context size.")

                raw_completions[i] = RawCompletion(
                    prompt=prompt_obj.text,
                    prompt_sequence_positions=prompt_token_count,
                    completion="",
                    completion_sequence_positions=0,
                    raw_completion_error=Error(
                        error_class=PromptTooLongException.__name__,
                        message="Prompt exceeded context size.",
                        traceback="",
                    ),
                )
                continue

            prompt_objs.append(prompt_obj)
            prompt_info.append((i, single_messages))

        if prompt_objs:
            model_outputs = self._model_generate(prompt_objs=prompt_objs, sampling_params=sampling_params)

            for (original_index, single_messages), prompt_obj, output in zip(prompt_info, prompt_objs, model_outputs):
                first_output = output.outputs[0]
                raw_completions[original_index] = RawCompletion(
                    prompt=prompt_obj.text,
                    prompt_sequence_positions=len(output.prompt_token_ids) if output.prompt_token_ids else 0,
                    concat_compression=ConcatCompression.calculate(
                        single_messages, count_tokens=self.count_tokens, completion=first_output.text
                    ),
                    completion=first_output.text,
                    completion_sequence_positions=len(first_output.token_ids) if first_output.token_ids else 0,
                    raw_completion_error=None,
                )

        # Ensure all positions are filled (should never be None at this point)
        return cast(list[RawCompletion], raw_completions)

    @staticmethod
    def _resolve_sampling_params(
        sampling_params: SamplingParams,
        max_tokens: int | None,
        stop_sequences: list[str] | None,
        temperature: float | None,
    ) -> SamplingParams:
        """Return a per-call copy of `sampling_params` with the call's overrides applied.

        `max_tokens` and `stop_sequences` always replace the stored values; `temperature`
        only does so when it is not None. The input object is left untouched.
        """
        # Work on a copy: the caller passes the instance-wide defaults, and mutating them
        # would make a one-off temperature override stick for every later call.
        resolved = sampling_params.clone()
        resolved.max_tokens = max_tokens
        resolved.stop = stop_sequences
        if temperature is not None:
            logger.warning(
                f"Overriding sampling params temperature {resolved.temperature} with custom value {temperature}"
            )
            resolved.temperature = temperature
        else:
            logger.info(
                f"Using sampling params temperature value: {resolved.temperature} "
                f"as no custom temperature value was provided"
            )
        return resolved

    def _model_generate(
        self,
        prompt_objs: list[TokenizedContainer],
        sampling_params: SamplingParams,
    ) -> list[RequestOutput]:
        """Send the pre-tokenized prompts to the engine and return its outputs."""
        vllm_token_prompt = [TokensPrompt(prompt_token_ids=prompt_obj.tokens) for prompt_obj in prompt_objs]
        return self.model.generate(vllm_token_prompt, sampling_params)

    def logprobs(self, samples: list[Sample]) -> list[RawLoglikelihood]:
        """Batched version of logprobs for improved performance.

        For every sample, each possible completion is scored as the summed log-probability of
        its tokens following the formatted prompt. If any prompt/choice pair of a sample
        exceeds the context size, the sample gets an error result with empty scores (or
        `PromptTooLongException` is raised when `raise_errors()` is true).
        """
        results: list[RawLoglikelihood | None] = [None] * len(samples)

        # Collect all prompt-choice combinations
        batch_data: list[tuple[TokenizedContainer, TokenizedContainer]] = []
        sample_choice_indices: list[tuple[int, str]] = []  # Maps batch index back to (sample_index, choice)
        output_mode = self.formatter_output_mode

        for sample_idx, sample in enumerate(samples):
            prompt: str | list[Message] = self._formatter.format(sample.messages, output_mode=output_mode)
            prompt_obj: TokenizedContainer = self.tokenizer.encode_formatted_struct(prompt)

            choices_log_probs: dict[str, float] = {}
            choices_log_probs_sequence_positions: dict[str, int] = {}
            error: Error | None = None
            valid_choices: list[str] = []
            # Pairs are buffered per sample and only queued once the whole sample fits.
            # Otherwise choices seen before an over-length one would still be scored and
            # written into a result that is reported as failed.
            pending_pairs: list[tuple[TokenizedContainer, TokenizedContainer]] = []

            for choice in sample.possible_completions or []:
                choice_obj: TokenizedContainer = self.tokenizer.encode_plain_text(choice)
                total_tokens_count = len(prompt_obj.tokens) + len(choice_obj.tokens)

                if total_tokens_count > self.max_seq_length:
                    if raise_errors():
                        raise PromptTooLongException("Prompt exceeded context size.")
                    choices_log_probs_sequence_positions = {}
                    error = Error(
                        error_class=PromptTooLongException.__name__,
                        message="Prompt and choice exceeded context size.",
                        traceback="",
                    )
                    break

                pending_pairs.append((prompt_obj, choice_obj))
                valid_choices.append(choice)
                choices_log_probs_sequence_positions[choice] = len(choice_obj.tokens)

            # If we had an error, store the result immediately and queue nothing
            if error is not None:
                results[sample_idx] = RawLoglikelihood(
                    prompt=prompt_obj.text,
                    prompt_sequence_positions=len(prompt_obj.tokens),
                    loglikelihoods=choices_log_probs,
                    loglikelihoods_sequence_positions=choices_log_probs_sequence_positions,
                    raw_loglikelihood_error=error,
                )
                continue

            batch_data.extend(pending_pairs)
            sample_choice_indices.extend((sample_idx, choice) for choice in valid_choices)
            results[sample_idx] = RawLoglikelihood(
                prompt=prompt_obj.text,
                prompt_sequence_positions=len(prompt_obj.tokens),
                loglikelihoods=choices_log_probs,
                loglikelihoods_sequence_positions=choices_log_probs_sequence_positions,
                raw_loglikelihood_error=None,
                concat_compression=ConcatCompression.calculate(
                    sample.messages, count_tokens=self.count_tokens, choices=valid_choices
                ),
            )

        # Process batch if we have valid data
        if batch_data:
            batch_logprobs = self._model_log_probs(batch_data)

            # Distribute results back to samples
            for (sample_idx, choice), logprob in zip(sample_choice_indices, batch_logprobs):
                result = results[sample_idx]
                if result is not None:
                    result.loglikelihoods[choice] = logprob

        return cast(list[RawLoglikelihood], results)

    def _model_log_probs(self, batch_data: list[tuple[TokenizedContainer, TokenizedContainer]]) -> list[float]:
        """Batched version of _model_log_probs for processing multiple prompt-choice pairs at once.

        Each pair is submitted as prompt tokens followed by choice tokens with prompt logprobs
        enabled; the result per pair is the sum of the log-probabilities of the choice tokens.
        """
        sampling_params = SamplingParams(
            max_tokens=1,
            temperature=0.0,
            prompt_logprobs=1,
            detokenize=False,
        )

        vllm_token_prompts = [
            TokensPrompt(prompt_token_ids=prompt_obj.tokens + choice_obj.tokens)
            for prompt_obj, choice_obj in batch_data
        ]

        outputs = self.model.generate(vllm_token_prompts, sampling_params)

        results = []
        for (_prompt_obj, choice_obj), output in zip(batch_data, outputs):
            assert output.prompt_logprobs is not None

            # The choice tokens are the tail of the submitted prompt.
            choice_logprobs = output.prompt_logprobs[-len(choice_obj.tokens) :]
            total_logprob = 0.0

            # VLLM guarantees the actual token's logprob is included in the output
            for j, token_id in enumerate(choice_obj.tokens):
                logprob_obj = choice_logprobs[j]
                assert logprob_obj is not None, f"logprob_obj is None: {logprob_obj}"
                logprob_value = logprob_obj[token_id].logprob
                assert logprob_value is not None, f"logprob_value is None: {logprob_value}"
                total_logprob += logprob_value

            results.append(total_logprob)

        return results

    @property
    def max_seq_length(self) -> int:
        """
        Returns the maximum sequence length for this model.
        Priority order:
        1. max_model_len parameter passed to __init__
        2. SEQ_LENGTH class attribute
        3. Model's actual max_model_len from config
        4. Default fallback of 2048
        """
        if self._max_model_len is not None:
            return self._max_model_len

        if self.SEQ_LENGTH is not None:
            return self.SEQ_LENGTH

        if hasattr(self.model, "llm_engine") and hasattr(self.model.llm_engine, "model_config"):
            return self.model.llm_engine.model_config.max_model_len

        return 2048

    @property
    def seq_length(self) -> int | None:
        """
        Kept for backward compatibility.
        """
        return self.max_seq_length
429
+
430
+
431
class Qwen3_0_6B_VLLM(VLLMModel):
    """Qwen/Qwen3-0.6B served through vLLM, with `enable_thinking` set to True in the chat template kwargs."""

    LLM_NAME = "Qwen/Qwen3-0.6B"
    DEFAULT_FORMATTER = partial(HFFormatter, LLM_NAME, chat_template_kwargs={"enable_thinking": True})
434
+
435
+
436
class Qwen3_0_6B_VLLM_No_Thinking(VLLMModel):
    """Qwen/Qwen3-0.6B served through vLLM, with `enable_thinking` set to False in the chat template kwargs."""

    LLM_NAME = "Qwen/Qwen3-0.6B"
    DEFAULT_FORMATTER = partial(HFFormatter, LLM_NAME, chat_template_kwargs={"enable_thinking": False})
@@ -0,0 +1,3 @@
1
+ import logging
2
+
3
+ logger = logging.getLogger("eval_framework")
eval_framework/main.py ADDED
@@ -0,0 +1,187 @@
1
+ import json
2
+ import logging
3
+ from collections.abc import Callable
4
+ from pathlib import Path
5
+ from typing import Any, Literal
6
+
7
+ import wandb
8
+
9
+ from eval_framework.evaluation_generator import EvaluationGenerator, Result
10
+ from eval_framework.llm.base import BaseLLM
11
+ from eval_framework.response_generator import ResponseGenerator
12
+ from eval_framework.result_processors.hf_processor import HFProcessor
13
+ from eval_framework.result_processors.result_processor import ResultsFileProcessor, generate_output_dir
14
+ from eval_framework.tasks.eval_config import EvalConfig
15
+ from eval_framework.utils.constants import RED, RESET
16
+ from eval_framework.utils.logging import setup_logging
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
def main(
    llm: BaseLLM,
    config: EvalConfig,
    should_preempt_callable: Callable[[], bool] | None = None,
    trial_id: int | None = None,
) -> list[Result]:
    """Runs the entire evaluation process: responses generation and evaluation.

    Args:
        llm: Model under evaluation.
        config: Evaluation configuration (task, sampling sizes, wandb and HF settings).
        should_preempt_callable: Polled by the response generator to learn whether the run
            should stop early; defaults to a callable that always returns False.
        trial_id: Identifier used to store and find preemption state. When given, a
            previously preempted run with the same id is resumed.

    Returns:
        The evaluation results, or an empty list when response generation was preempted.
    """
    # Set up centralized logging early
    output_dir = generate_output_dir(llm.name, config)
    print(f"Output directory for evaluation: {output_dir}")
    setup_logging(output_dir=output_dir, log_level=logging.INFO, log_filename="evaluation.log")

    logger.info(f"{RED}[ Running full evaluation process ------- ]{RESET}")
    logger.info(f"Evaluating {llm.name} on {config.task_name}")
    logger.info(f"Configuration: num_fewshot={config.num_fewshot}, num_samples={config.num_samples}")
    logger.info(f"Output directory: {output_dir}")

    if not should_preempt_callable:
        should_preempt_callable = lambda: False  # noqa: E731

    # `is not None` rather than truthiness: trial id 0 is a valid id and is also what the
    # save/delete paths below check for.
    preemption_data = None
    if trial_id is not None:
        preemption_data = _read_preemption_data(config, trial_id)

    if preemption_data is None:
        # Fresh run: keep the directory computed above so logs and results share one location.
        # config.wandb_run_id defaults to none, if no run_id is provided then it starts a new one
        wandb_run_id = config.wandb_run_id
    else:
        logger.info("Found preempted run restarting ...")
        output_dir = preemption_data["output_dir"]
        # The preemption helpers use "" as the "no run id" placeholder; map it back to None.
        wandb_run_id = preemption_data.get("wandb_run_id") or None

    logger.info(f"Output directory: {output_dir}")
    assert output_dir is not None

    file_processor = ResultsFileProcessor(output_dir)
    response_generator = ResponseGenerator(llm, config, file_processor)
    # take care of init after preemption handling. If we have a run
    # id from preemption, then we resume the original wandb run
    with wandb.init(
        entity=config.wandb_entity,
        project=config.wandb_project,
        group=llm.name,
        job_type=config.task_name,
        id=wandb_run_id,
        config=response_generator._get_metadata(),
        resume="allow",
        mode=_wandb_mode(config.wandb_project),
    ) as run:
        _, preempted = response_generator.generate(should_preempt_callable)
        if preempted:
            logger.info("Response generation was preempted")
            assert trial_id is not None
            run.mark_preempting()
            _save_preemption_data(config, trial_id, output_dir, wandb_run_id=run.id)
            wandb.finish(exit_code=1)
            return []
        # update config from response generator with get metadata
        if trial_id is not None:
            _delete_preemption_file(config, trial_id)

        evaluator = EvaluationGenerator(config, file_processor)
        results = evaluator.run_eval()

        if config.hf_upload_dir:
            hf_processor = HFProcessor(config, output_dir)
            status, hf_url = hf_processor.upload_responses_to_HF()
            if not status:
                status_message = "*** Warning: Result upload to HF failed ***"
            else:
                status_message = "Successfully uploaded results to HuggingFace"
                if hf_url and run:
                    try:
                        run.notes = f"Results uploaded to HuggingFace: [{hf_url}]({hf_url})"
                    except Exception as e:
                        logger.warning(f"Failed to update wandb notes with HF URL: {e}")
        else:
            status_message = f"{RED}[ Results not persisted in a HuggingFace repo ------- ]{RESET}"

        logger.info(status_message)

    return results
104
+
105
+
106
def _read_preemption_data(config: EvalConfig, trial_id: int) -> dict[str, Any] | None:
    """Load the preemption record for `trial_id` from `config.output_dir`.

    Returns None when no such file exists. Otherwise returns the decoded JSON object with
    `output_dir` converted to a `Path` and `wandb_run_id` defaulting to an empty string.
    """
    marker = config.output_dir / f"preemption_trial_{trial_id}.json"
    if marker.is_file():
        with open(marker, "rb") as handle:
            record = json.load(handle)
        record["output_dir"] = Path(record["output_dir"])
        record["wandb_run_id"] = record.get("wandb_run_id", "")
        logger.info(f"Loaded preemption data from {marker}")
        return record
    return None
116
+
117
+
118
def _save_preemption_data(config: EvalConfig, trial_id: int, output_dir: Path, wandb_run_id: str = "") -> None:
    """Persist what is needed to resume a preempted trial.

    Writes `preemption_trial_<trial_id>.json` into `config.output_dir`, containing the run's
    output directory (as a string) and the wandb run id.
    """
    preemption_file = config.output_dir / f"preemption_trial_{trial_id}.json"
    # NOTE(review): presumably config.output_dir already exists by now; creating it here is
    # a cheap safeguard so saving the resume marker cannot fail on a missing directory.
    preemption_file.parent.mkdir(parents=True, exist_ok=True)
    with open(preemption_file, "w", encoding="utf-8") as f:
        json.dump({"output_dir": str(output_dir), "wandb_run_id": wandb_run_id}, f)
    logger.info(f"Saved preemption data to {preemption_file}")
122
+
123
+
124
def _delete_preemption_file(config: EvalConfig, trial_id: int) -> None:
    """Remove the preemption record of `trial_id` from `config.output_dir`, if it exists.

    Logs whether a file was deleted or none was found.
    """
    preemption_file = config.output_dir / f"preemption_trial_{trial_id}.json"
    if preemption_file.is_file():
        preemption_file.unlink()
        logger.info(f"Deleted preemption file: {preemption_file}")
    else:
        logger.info(f"No preemption file found to delete: {preemption_file}")
132
+
133
+
134
def _configure_logging(output_dir: Path) -> None:
    """Configure logging to save logs to a file in the output directory.

    Creates `output_dir` if needed, attaches a fresh `evaluation.log` file handler (INFO
    level, truncating any existing file) to the root logger, and replaces any file handlers
    that were attached before. The root logger level is set to INFO only if it was unset.
    """
    # Ensure the output directory exists
    output_dir.mkdir(parents=True, exist_ok=True)

    # Set up log file path
    log_file = output_dir / "evaluation.log"

    # Create file handler
    file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
    file_handler.setLevel(logging.INFO)

    # Create formatter
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    file_handler.setFormatter(formatter)

    # Get the root logger and add the file handler
    root_logger = logging.getLogger()

    # Remove existing file handlers to avoid duplicates. Iterate over a copy because the
    # handler list is modified, and close each one so its file descriptor is released.
    for handler in root_logger.handlers[:]:
        if isinstance(handler, logging.FileHandler):
            root_logger.removeHandler(handler)
            handler.close()

    root_logger.addHandler(file_handler)

    # Set logging level if not already set
    if root_logger.level == logging.NOTSET:
        root_logger.setLevel(logging.INFO)
164
+
165
+
166
def _wandb_mode(project: str | None) -> Literal["online", "disabled"] | None:
    """
    Decide the wandb mode for a run.

    Returns "disabled" when no project is given, when no WandB API key is found, or when
    the API key lookup raises; returns "online" otherwise.
    """
    if project is None:
        logger.warning("No WandB project specified, disabling logging.")
        return "disabled"

    try:
        if wandb.api.api_key is None:
            logger.warning(
                "No wandb API key found. Disabling Wandb logging.\n"
                "If you have a WandB account set the environment variable 'WANDB_API_KEY'"
            )
            return "disabled"
        logger.info("wandb login detected. Using online mode.")
    except Exception as e:
        logger.warning(f"wandb login check failed: {e}. Disabling Wandb logging.")
        return "disabled"

    return "online"
File without changes
@@ -0,0 +1,40 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any
3
+
4
+ from pydantic import BaseModel, ConfigDict
5
+
6
+ from eval_framework.shared.types import Error
7
+
8
+
9
class MetricResult(BaseModel):
    """Outcome of one metric computed for one response."""

    # Unknown fields are rejected instead of silently accepted.
    model_config = ConfigDict(extra="forbid")
    # Name under which the metric is reported.
    metric_name: str
    # Metric value; may be None (e.g. when an error is attached).
    value: float | None
    # Whether larger values indicate a better result.
    higher_is_better: bool
    # Optional judge prompt/response text (named for LLM-judge metrics).
    llm_judge_prompt: str | None = None
    llm_judge_response: str | None = None
    # Optional trace text (named for code-execution metrics).
    code_execution_trace: str | None = None
    # Error associated with this result, if any.
    error: Error | None = None
18
+
19
+
20
class classproperty:
    """Read-only descriptor that computes an attribute from the owning class.

    The wrapped callable receives the class, whether the attribute is looked up on the
    class itself or on one of its instances.
    """

    def __init__(self, method: Any) -> None:
        # Keep the wrapped callable; it is invoked on every attribute access.
        self.method = method

    def __get__(self, instance: Any, cls: Any) -> Any:
        compute = self.method
        return compute(cls)
26
+
27
+
28
+ class BaseMetric[Response](ABC):
29
+ NAME: str
30
+ KEYS: list[str] | None = None
31
+
32
+ @classproperty
33
+ def NAMES(cls) -> list[str]:
34
+ if cls.KEYS is None:
35
+ return [cls.NAME]
36
+ return [f"{cls.NAME}/{k}" for k in cls.KEYS]
37
+
38
+ @abstractmethod
39
+ def calculate(self, response: Response) -> list[MetricResult]:
40
+ raise NotImplementedError
@@ -0,0 +1 @@
1
+ from .accuracy_completion import AccuracyCompletion
@@ -0,0 +1,16 @@
1
+ from eval_framework.metrics.base import BaseMetric, MetricResult
2
+ from eval_framework.shared.types import Completion
3
+
4
+
5
class AccuracyCompletion(BaseMetric[Completion]):
    """Exact-match accuracy: 1.0 when the completion equals any ground truth, else 0.0."""

    NAME = "Accuracy Completion"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Score one completion; a response that carries an error yields a valueless result."""
        if response.error is not None:
            failed = MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)
            return [failed]

        # Stop at the first ground truth that matches the completion exactly.
        matched = False
        for candidate in response.ground_truth_list:
            if response.completion == candidate:
                matched = True
                break

        scored = MetricResult(
            metric_name=self.NAME,
            value=float(matched),
            higher_is_better=True,
            error=response.error,
        )
        return [scored]