eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,42 @@
1
+ from eval_framework.metrics.base import MetricResult
2
+ from eval_framework.metrics.loglikelihood.base import BaseLoglikelihoodMetric
3
+ from eval_framework.shared.types import Loglikelihood
4
+
5
+
6
class TernaryScore(BaseLoglikelihoodMetric):
    """Ternary (reward / abstain / penalty) score for loglikelihood tasks.

    Based on Kalai et al. (2025), "Why language models hallucinate" (arXiv:2509.04664).
    The highest-scoring choice earns ``lc`` when it is a ground truth, ``0.0`` when it is
    the abstention option, and ``-lw`` otherwise.
    """

    NAME = "Ternary Score"

    def __init__(
        self,
        *,
        lc: float = 1.0,  # reward for picking a correct choice
        lw: float = 1.0,  # magnitude of the penalty for a wrong choice (the score uses -lw)
        len_normalised: bool = True,
    ) -> None:
        super().__init__(len_normalised=len_normalised)
        self._lc = float(lc)
        self._lw = float(lw)
        # Expressed as a negated conjunction so that NaN values are rejected too.
        both_non_negative = self._lc >= 0 and self._lw >= 0
        if not both_non_negative:
            raise ValueError(f"Invalid reward and penalty values: lc={self._lc}, lw={self._lw}. Require lc>=0, lw>=0.")

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        """Score one loglikelihood response with the ternary rule.

        :param response: the model's loglikelihoods over the answer choices.
        :return: a single-element list holding the score, or ``value=None`` plus the
            response's error when the response carries an error.
        """
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        choice_scores, _ = self._compute_probabilities(response.loglikelihoods)
        ground_truths = self._gather_ground_truths(response)

        best_choice = max(choice_scores, key=choice_scores.get)  # type: ignore[arg-type]
        chosen = self._normalise_text(best_choice)
        # NOTE: the abstention ("IDK") option is assumed to be the last key of the raw loglikelihoods.
        abstain_key = self._normalise_text(list(response.loglikelihoods.keys())[-1])

        value = self._lc if chosen in ground_truths else (0.0 if chosen == abstain_key else -self._lw)

        return [MetricResult(metric_name=self.NAME, value=value, higher_is_better=True, error=response.error)]
File without changes
@@ -0,0 +1,351 @@
1
+ import logging
2
+ import time
3
+ import traceback
4
+ from collections.abc import Callable
5
+ from datetime import UTC, datetime
6
+ from functools import partial
7
+ from typing import Any
8
+
9
+ from eval_framework.tasks.registry import get_task
10
+
11
+ try:
12
+ from determined._info import get_cluster_info
13
+ except ImportError:
14
+ get_cluster_info = None # type: ignore[assignment]
15
+
16
+
17
+ from tqdm import tqdm
18
+
19
+ from eval_framework import __version__ as eval_framework_version
20
+ from eval_framework.llm.base import BaseLLM
21
+ from eval_framework.result_processors.result_processor import ResultsFileProcessor
22
+ from eval_framework.shared.types import (
23
+ Completion,
24
+ Error,
25
+ Loglikelihood,
26
+ RawLoglikelihood,
27
+ )
28
+ from eval_framework.tasks.base import Language, ResponseType, Sample
29
+ from eval_framework.tasks.eval_config import EvalConfig
30
+ from eval_framework.tasks.perturbation import create_perturbation_class
31
+ from eval_framework.tasks.utils import raise_errors
32
+ from eval_framework.utils.constants import RED, RESET
33
+ from eval_framework.utils.tqdm_handler import get_disable_bar_flag, safe_tqdm_write
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
def map_language_to_value(
    language: Language | dict[str, Language] | dict[str, tuple[Language, Language]] | None,
) -> str | dict[str, str] | dict[str, tuple[str, str]] | None:
    """Convert a task's ``LANGUAGE`` declaration into plain strings for metadata.

    :param language: ``None``; a single ``Language``; a mapping of subject to ``Language``;
        or a mapping of subject to a pair of ``Language`` values.
    :return: the same shape with every ``Language`` replaced by its ``.value``.
        ``None`` stays ``None``, and an empty mapping yields an empty dict.
    :raises ValueError: if ``language`` is not one of the supported shapes.
    """
    if language is None:
        return None
    if isinstance(language, dict):
        if not language:
            # An empty mapping has no first value to inspect; return it as an empty dict.
            return {}
        # The type of the first value decides how the whole mapping is converted.
        first_value = next(iter(language.values()))
        if isinstance(first_value, Language):
            return {k: v.value for k, v in language.items()}  # type: ignore[union-attr]
        return {k: (v[0].value, v[1].value) for k, v in language.items()}  # type: ignore[index]
    if isinstance(language, Language):
        return language.value
    raise ValueError(f"Invalid language: {language}")
52
+
53
+
54
+ class ResponseGenerator:
55
+ def __init__(self, llm: BaseLLM, config: EvalConfig, result_processor: ResultsFileProcessor) -> None:
56
+ self.few_shot = config.num_fewshot
57
+ self.task_name = config.task_name
58
+ self.llm = llm
59
+ self.config = config
60
+ self.result_processor = result_processor
61
+ self.num_samples = config.num_samples
62
+ self.save_intermediate_results = config.save_intermediate_results
63
+
64
+ task_class = get_task(config.task_name)
65
+
66
+ if config.perturbation_config is not None:
67
+ perturbation_task_class = create_perturbation_class(task_class, config.perturbation_config)
68
+ self.task = perturbation_task_class.with_overwrite(
69
+ self.few_shot, custom_subjects=self.config.task_subjects, custom_hf_revision=self.config.hf_revision
70
+ )
71
+ else:
72
+ self.task = task_class.with_overwrite(
73
+ self.few_shot, custom_subjects=self.config.task_subjects, custom_hf_revision=self.config.hf_revision
74
+ )
75
+
76
+ self.response_type = task_class.RESPONSE_TYPE
77
+
78
+ def _llm_task_param_precedence(self) -> tuple[list[str] | None, int | None]:
79
+ """
80
+ sets the stop_sequences and max_tokens values to be used in the completion generation.
81
+ Max token and stop sequence values have an order of precedence:
82
+
83
+ LLM attributes take precedence over task attributes, and therefore overload them.
84
+ :return: stop_sequences, max_tokens
85
+ """
86
+ llm_stop_sequences = getattr(self.llm, "stop_sequences", None)
87
+ llm_max_tokens = getattr(self.llm, "max_tokens", None)
88
+ task_stop_sequences = getattr(self.task, "stop_sequences", None)
89
+ task_max_tokens = self.config.max_tokens or getattr(self.task, "max_tokens", None)
90
+ # if both task and model define a max_token, the smaller value is used
91
+ max_tokens = min([x for x in [llm_max_tokens, task_max_tokens] if x is not None], default=None)
92
+ logger.info(f"Set max_tokens to {max_tokens}")
93
+ # if both task and model define stop sequences, those are merged into one list
94
+ stop_sequences_merged = (llm_stop_sequences or []) + (task_stop_sequences or [])
95
+ stop_sequences = sorted(list(set(stop_sequences_merged))) if stop_sequences_merged else None
96
+ logger.info(f"Set stop_sequences to {stop_sequences}")
97
+ return stop_sequences, max_tokens
98
+
99
+ def _generate_loglikelihoods(self, samples: list[Sample]) -> list[Loglikelihood]:
100
+ """
101
+ Generate log likelihoods when a sample is run against the model.
102
+ :param sample: sample to run the task against
103
+ :return: loglikelihoods
104
+ """
105
+ raw_loglikelihoods: list[RawLoglikelihood]
106
+ try:
107
+ raw_loglikelihoods = self.llm.logprobs(samples)
108
+ except Exception as e:
109
+ if raise_errors():
110
+ raise e
111
+ logger.info(f"Error: {e.__class__.__name__} {e}")
112
+ assert len(samples) == 1, "LLMs not handling errors are not supported in batch mode"
113
+ raw_loglikelihoods = [
114
+ RawLoglikelihood(
115
+ prompt="",
116
+ prompt_sequence_positions=0,
117
+ loglikelihoods={},
118
+ loglikelihoods_sequence_positions={},
119
+ raw_loglikelihood_error=Error(
120
+ error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc()
121
+ ),
122
+ )
123
+ for _ in range(len(samples))
124
+ ]
125
+
126
+ loglikelihood_list = []
127
+ for idx, sample in enumerate(samples):
128
+ raw_loglikelihood = raw_loglikelihoods[idx]
129
+ assert sample.ground_truth is not None
130
+ loglikelihood_list.append(
131
+ Loglikelihood(
132
+ id=sample.id,
133
+ subject=sample.subject,
134
+ ground_truth=sample.ground_truth,
135
+ prompt=raw_loglikelihood.prompt,
136
+ prompt_sequence_positions=raw_loglikelihood.prompt_sequence_positions,
137
+ concat_compression=raw_loglikelihood.concat_compression,
138
+ loglikelihoods=raw_loglikelihood.loglikelihoods,
139
+ loglikelihoods_sequence_positions=raw_loglikelihood.loglikelihoods_sequence_positions,
140
+ error=raw_loglikelihood.raw_loglikelihood_error,
141
+ )
142
+ )
143
+ return loglikelihood_list
144
+
145
+ def _generative_output_type_selector(self) -> Callable[[list[Sample]], list[Completion] | list[Loglikelihood]]:
146
+ """
147
+ Selects the generative output type based on the response type.
148
+ :return: function to generate responses
149
+ """
150
+ match self.response_type:
151
+ case ResponseType.COMPLETION:
152
+ stop_sequences, max_tokens = self._llm_task_param_precedence()
153
+ return partial(
154
+ self.task.generate_completions, self.llm, stop_sequences=stop_sequences, max_tokens=max_tokens
155
+ ) # type: ignore[call-arg]
156
+ case ResponseType.LOGLIKELIHOODS:
157
+ return self._generate_loglikelihoods
158
+ case _:
159
+ raise KeyError(f"Task type {self.task} not supported")
160
+
161
+ def _run_task_against_model(
162
+ self, should_preempt_callable: Callable[[], bool]
163
+ ) -> tuple[list[Completion | Loglikelihood], bool]:
164
+ """
165
+ Runs the task against the model and generates responses.
166
+ :param should_preempt_callable: function to check if preempt is called
167
+ :return: list of responses, preempted
168
+ """
169
+ logger.info(f"{RED}[ Running task {self.task.NAME} against model ------------ ]{RESET}")
170
+ self.start_time, monotonic_start = time.time(), time.monotonic()
171
+ run_fn = self._generative_output_type_selector()
172
+ self._verify_loaded_metadata_compatibility()
173
+ responses = self.result_processor.load_responses() # load responses if present
174
+ subject_response_id_mapping = self._map_subject_response_ids(responses)
175
+ self.result_processor.save_metadata(self._get_metadata())
176
+ responses, preempted = self._curate_responses(
177
+ responses, subject_response_id_mapping, run_fn, should_preempt_callable
178
+ )
179
+ self.end_time, monotonic_end = time.time(), time.monotonic()
180
+ self.total_time = monotonic_end - monotonic_start
181
+ self.result_processor.save_metadata(self._get_metadata()) # overwrite with updated timing
182
+
183
+ return responses, preempted
184
+
185
+ def _map_subject_response_ids(self, responses: list[Completion | Loglikelihood]) -> dict[str, set[int]]:
186
+ """
187
+ Maps subject to response id
188
+ :param responses: list of responses
189
+ :return: mapping of subject to response id
190
+ """
191
+ subject_response_id_mapping = {}
192
+ if responses:
193
+ response_subjects = {resp.subject for resp in responses}
194
+ subject_response_id_mapping = {
195
+ response_subject: set([resp.id for resp in responses if resp.subject == response_subject])
196
+ for response_subject in response_subjects
197
+ }
198
+
199
+ return subject_response_id_mapping
200
+
201
+ def _curate_responses(
202
+ self,
203
+ responses: list[Completion | Loglikelihood],
204
+ subject_response_id_mapping: dict[str, set[int]],
205
+ generative_output_function: Callable[[list[Sample]], list[Completion] | list[Loglikelihood]],
206
+ should_preempt_callable: Callable[[], bool],
207
+ ) -> tuple[list[Completion | Loglikelihood], bool]:
208
+ """
209
+ Generates responses for the task and saves them along with metadata.
210
+ :param responses: list of responses
211
+ :param subject_response_id_mapping: mapping of subject to response id
212
+ :param generative_output_function: function to generate responses
213
+ :param metadata: metadata dictionary
214
+ :param should_preempt_callable: function to check if preempt is called
215
+ :return: None
216
+ """
217
+
218
+ def _process_batch(samples_batch: list[Sample]) -> None:
219
+ if not samples_batch:
220
+ return
221
+ if len(samples_batch) > 1:
222
+ log_msg = "Processing batch..."
223
+ logger.info(log_msg) # For log files
224
+ safe_tqdm_write(log_msg) # For console display with tqdm
225
+
226
+ responses_batch = generative_output_function(samples_batch)
227
+ responses.extend(responses_batch)
228
+ if self.save_intermediate_results:
229
+ for response in responses_batch:
230
+ self.result_processor.save_response(response)
231
+
232
+ # In order to enable parallelism we group samples in batches and send them in parallel to the `run_fn`.
233
+ # The BaseLLM class is then in charge of managing the parallelism (eg, using AsyncClient in API models).
234
+ # If samples_batch_size = 1, samples are run sequentially; in any case, we return here after finishing each
235
+ # individual batch to honor preemption requests and save cached results.
236
+ samples_batch_size = self.config.batch_size
237
+
238
+ # Calculate total samples for progress bar - use num_samples or iterate to count
239
+ total_num_samples = self.num_samples
240
+ if total_num_samples is None:
241
+ # Count samples by iterating (this might be expensive for large datasets)
242
+ total_num_samples = sum(1 for _ in self.task.iterate_samples(None))
243
+
244
+ samples_batch: list[Sample] = []
245
+ with tqdm(
246
+ total=total_num_samples, desc=f"Processing {self.response_type.value}", disable=get_disable_bar_flag()
247
+ ) as pbar:
248
+ for i, sample in enumerate(self.task.iterate_samples(self.num_samples)):
249
+ subject = f" - Subject: {sample.subject}"
250
+ sample_index = i + 1
251
+
252
+ if sample.id in subject_response_id_mapping.get(sample.subject, []):
253
+ log_msg = (
254
+ f"Task: {self.response_type.value}{subject} - Sample: {sample_index} - skipping, already done."
255
+ )
256
+ logger.info(log_msg) # For log files
257
+ safe_tqdm_write(log_msg) # For console display with tqdm
258
+ pbar.update(1)
259
+ continue
260
+
261
+ log_msg = f"Task: {self.response_type.value}{subject} - Sample: {sample_index}/{total_num_samples}"
262
+ logger.info(log_msg) # For log files
263
+ safe_tqdm_write(log_msg) # For console display with tqdm
264
+ pbar.set_postfix_str(f"Sample {sample_index}/{total_num_samples}")
265
+ pbar.update(1)
266
+
267
+ samples_batch.append(sample)
268
+
269
+ if len(samples_batch) >= samples_batch_size:
270
+ _process_batch(samples_batch)
271
+ samples_batch = []
272
+
273
+ if should_preempt_callable():
274
+ log_msg = "Preempt"
275
+ logger.info(log_msg) # For log files
276
+ safe_tqdm_write(log_msg) # For console display with tqdm
277
+ if not self.save_intermediate_results:
278
+ self.result_processor.save_responses(responses)
279
+ return responses, True
280
+
281
+ _process_batch(samples_batch)
282
+
283
+ if not self.save_intermediate_results:
284
+ self.result_processor.save_responses(responses)
285
+ return responses, False
286
+
287
+ def _get_metadata(self) -> dict[str, Any]:
288
+ """Prepares metadata dictionary from the configuration."""
289
+ all_metrics = getattr(self.task, "METRICS", None)
290
+ metadata = self.config.model_dump(mode="json")
291
+ metadata["llm_name"] = self.llm.name
292
+ metadata["task_name"] = self.task_name
293
+ language = getattr(self.task, "LANGUAGE", None)
294
+ metadata["language"] = map_language_to_value(language)
295
+ metadata["metrics"] = [m.NAME for m in all_metrics] if all_metrics is not None else []
296
+ metadata["primary_metrics"] = getattr(self.task, "PRIMARY_METRICS", None)
297
+ metadata["eval_framework_version"] = eval_framework_version
298
+ metadata["task_output_dir"] = str(self.result_processor.output_dir)
299
+ if hasattr(self, "total_time"):
300
+ metadata["start_time"] = str(datetime.fromtimestamp(self.start_time, UTC))
301
+ metadata["end_time"] = str(datetime.fromtimestamp(self.end_time, UTC))
302
+ metadata["total_time"] = self.total_time
303
+
304
+ # add task specific metadata
305
+ metadata["task_metadata"] = self.task.get_metadata()
306
+
307
+ try:
308
+ assert get_cluster_info is not None, "Determined cluster info not available"
309
+ info = get_cluster_info()
310
+ if info is not None:
311
+ metadata["determined_agent_id"] = info.agent_id
312
+ if info.task_type == "TRIAL":
313
+ metadata["determined_experiment_id"] = info.trial.experiment_id
314
+ metadata["determined_trial_id"] = info.trial.trial_id
315
+ except Exception as e:
316
+ logger.info(f"{e}; cluster info not available in local context")
317
+
318
+ return metadata
319
+
320
+ def _verify_loaded_metadata_compatibility(self) -> None:
321
+ if not (loaded_metadata := self.result_processor.load_metadata()):
322
+ return
323
+ current_metadata = self._get_metadata()
324
+ # check if crucial keys in metadata are the same as in the previous run
325
+ keys = [
326
+ "task_name",
327
+ "task_subjects",
328
+ "num_fewshot",
329
+ "num_samples",
330
+ "llm_name",
331
+ "llm_args",
332
+ "perturbation_config",
333
+ ]
334
+ for key in keys:
335
+ if loaded_metadata[key] != current_metadata[key]:
336
+ raise ValueError(f"Existing metadata does not match current metadata for {key}.")
337
+
338
+ def __del__(self) -> None:
339
+ self.llm.__del__()
340
+
341
+ def generate(self, should_preempt_callable: Callable[[], bool]) -> tuple[list[Completion | Loglikelihood], bool]:
342
+ """Generates responses and saves them along with metadata.
343
+ :param should_preempt_callable: function to check if preempt is called
344
+ :return: list of responses, preempted: whether the process was preempted or not
345
+ """
346
+ logger.info(f"{RED}[ Running responses generation ---------- ]{RESET}")
347
+ logger.info(f"{RED}[ Will save into {self.result_processor.output_dir} ---------- ]{RESET}")
348
+ responses, preempted = self._run_task_against_model(should_preempt_callable)
349
+ logger.info("Completions generated and saved.")
350
+
351
+ return responses, preempted
File without changes
@@ -0,0 +1,88 @@
1
+ from abc import ABC, abstractmethod
2
+ from pathlib import Path
3
+
4
+ from dotenv import load_dotenv
5
+ from pydantic import BaseModel, ConfigDict
6
+
7
+ from eval_framework.shared.types import Completion, Error, Loglikelihood
8
+ from eval_framework.tasks.eval_config import EvalConfig
9
+
10
# Folder name constant; presumably the default root directory for evaluation results (not used
# elsewhere in this module) — confirm against its importers.
MAIN = "eval_framework_results"

# Import-time side effect: load environment variables from a `.env` file via python-dotenv.
load_dotenv()
13
+
14
+
15
class Result(BaseModel):
    """A single metric value produced for one response of an evaluation run.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")
    id: int  # presumably the id of the evaluated sample/response — confirm
    subject: str
    num_fewshot: int
    llm_name: str
    task_name: str
    metric_class_name: str
    metric_name: str
    key: str | None  # optional key qualifying the metric; semantics not visible here
    value: float | None  # None when the metric could not be computed
    higher_is_better: bool
    prompt: str
    response: str
    llm_judge_prompt: str | None = None
    llm_judge_response: str | None = None
    code_execution_trace: str | None = None
    error: Error | None = None
33
+
34
+
35
class ResultProcessor(ABC):
    """Storage interface for everything an evaluation run produces.

    Covers the run metadata, the generated responses (bulk or one by one), the per-sample
    metric results, and the aggregated metric values.
    """

    @abstractmethod
    def save_metadata(self, metadata: dict) -> None:
        """Persist the run metadata dictionary."""
        ...

    @abstractmethod
    def load_metadata(self) -> dict:
        """Return the previously persisted run metadata."""
        ...

    @abstractmethod
    def save_responses(self, responses: list[Completion | Loglikelihood]) -> None:
        """Persist all given responses, replacing any stored file."""
        ...

    @abstractmethod
    def save_response(self, response: Completion | Loglikelihood) -> None:
        """Append one response to the stored responses."""
        ...

    @abstractmethod
    def load_responses(self) -> list[Completion | Loglikelihood]:
        """Return the previously stored responses."""
        ...

    @abstractmethod
    def save_metrics_results(self, results: list[Result]) -> None:
        """Persist all per-sample metric results, replacing any stored file."""
        ...

    @abstractmethod
    def save_metrics_result(self, result: Result) -> None:
        """Append one per-sample metric result to the stored results."""
        ...

    @abstractmethod
    def save_aggregated_results(self, result: dict[str, float | None]) -> None:
        """Persist the aggregated metric values."""
        ...

    @abstractmethod
    def load_metrics_results(self) -> list[Result]:
        """Return the previously stored per-sample metric results."""
        ...
80
+
81
+
82
class ResultsUploader(ABC):
    """Interface for publishing the results of a finished run to an external destination."""

    @abstractmethod
    def upload(self, llm_name: str, config: EvalConfig, output_dir: Path) -> bool:
        """Publish the relevant files found in ``output_dir``.

        :param llm_name: name of the evaluated model.
        :param config: evaluation configuration of the run.
        :param output_dir: directory holding the run's result files.
        :return: True when the upload succeeded, False otherwise.
        """
        ...
@@ -0,0 +1,75 @@
1
+ """
2
+ Module for writing result folder and its contents to HuggingFace
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ from pathlib import Path
8
+
9
+ import wandb
10
+ from huggingface_hub import HfApi, login
11
+
12
+ from eval_framework.result_processors.base import ResultsUploader
13
+ from eval_framework.tasks.eval_config import EvalConfig
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class HFUploader(ResultsUploader):
    """Uploads a run's summary files to a HuggingFace dataset repository.

    Only ``aggregated_results.json`` and ``metadata.json`` are uploaded; every other entry in the
    output directory is skipped. Uploading is disabled when ``hf_upload_dir`` is not configured
    or when logging into HuggingFace fails.
    """

    # Files copied to HuggingFace; everything else in the output directory is skipped.
    _UPLOADED_FILES = ("aggregated_results.json", "metadata.json")

    def __init__(self, config: EvalConfig):
        # Set up front so the attribute always exists, even when uploading is disabled.
        self.hf_api: HfApi | None = None
        if not config.hf_upload_dir:
            logger.warning("Results will not be persisted in HuggingFace (`hf_upload_dir` not configured).")
            return
        if config.output_dir is None:
            raise ValueError("Output directory is not set in the configuration.")
        if not config.hf_upload_repo:
            raise ValueError("HuggingFace upload repository is not set in the configuration.")

        self.hf_api = HFUploader._login_into_hf()
        if self.hf_api is None:
            logger.error("Could not login into HuggingFace (check HF_TOKEN). Results not persisted in HuggingFace.")

    def upload(self, llm_name: str, config: EvalConfig, output_dir: Path) -> bool:
        """Upload the summary files of ``output_dir`` to ``config.hf_upload_repo``.

        The files go below ``hf_upload_dir`` joined with the path of ``output_dir``
        relative to ``config.output_dir``.

        :return: True if all selected files were uploaded, False if uploading is disabled
            or any single upload failed (the first failure stops the upload).
        """
        hf_api = getattr(self, "hf_api", None)
        if hf_api is None:
            return False
        assert config.hf_upload_repo and config.hf_upload_dir

        rel_upload_dir = output_dir.relative_to(config.output_dir)
        # NOTE(review): every "/" is removed, so a nested value like "team/results" becomes the single
        # folder "teamresults". Kept to preserve existing upload locations; confirm this is intended.
        upload_dir = Path(config.hf_upload_dir.replace("/", "")) / rel_upload_dir
        logger.info(f"HuggingFace upload starting to: {upload_dir}")

        upload_counter = 0
        for fp in output_dir.iterdir():
            if fp.name not in self._UPLOADED_FILES:
                logger.info(f"Skipping {fp}.")
                continue
            try:
                hf_api.upload_file(
                    path_or_fileobj=str(fp),
                    # Repository paths must use forward slashes regardless of the local OS.
                    path_in_repo=(upload_dir / fp.name).as_posix(),
                    repo_id=config.hf_upload_repo,
                    repo_type="dataset",
                )
            except Exception as e:
                logger.error("Problem during HF file upload: " + str(e))
                return False
            upload_counter += 1

        hf_url = f"https://huggingface.co/datasets/{config.hf_upload_repo}/tree/main/{upload_dir.as_posix()}"
        logger.info(f"Uploaded {upload_counter} result files to {hf_url}.")

        if wandb.run is not None:
            try:
                wandb.run.notes = f"Results uploaded to HuggingFace: [{hf_url}]({hf_url})"
            except Exception as e:
                logger.warning(f"Failed to update wandb notes with HF URL: {e}")
        return True

    @classmethod
    def _login_into_hf(cls) -> HfApi | None:
        """Log into HuggingFace with the ``HF_TOKEN`` environment variable.

        :return: an ``HfApi`` client on success, or None if login failed (the cause is logged).
        """
        try:
            login(token=os.environ.get("HF_TOKEN", ""))
            logger.debug("logged into HF")
            return HfApi()
        except Exception as e:
            # Best effort: the caller reports the failure; record the underlying cause instead of hiding it.
            logger.warning(f"HuggingFace login failed: {e.__class__.__name__}: {e}")
            return None
+ return None