eval-framework 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +170 -0
  5. eval_framework/context/eval.py +114 -0
  6. eval_framework/context/local.py +52 -0
  7. eval_framework/evaluation_generator.py +231 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +323 -0
  16. eval_framework/llm/base.py +58 -0
  17. eval_framework/llm/huggingface.py +332 -0
  18. eval_framework/llm/mistral.py +73 -0
  19. eval_framework/llm/models.py +16 -0
  20. eval_framework/llm/openai.py +205 -0
  21. eval_framework/llm/vllm.py +438 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +187 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/bleu.py +76 -0
  29. eval_framework/metrics/completion/chrf.py +62 -0
  30. eval_framework/metrics/completion/code_assertion.py +44 -0
  31. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  32. eval_framework/metrics/completion/comet.py +56 -0
  33. eval_framework/metrics/completion/concordance_index.py +38 -0
  34. eval_framework/metrics/completion/csv_format.py +102 -0
  35. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  36. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  37. eval_framework/metrics/completion/f1.py +42 -0
  38. eval_framework/metrics/completion/format_checker.py +56 -0
  39. eval_framework/metrics/completion/grid_difference.py +77 -0
  40. eval_framework/metrics/completion/ifeval.py +73 -0
  41. eval_framework/metrics/completion/json_format.py +171 -0
  42. eval_framework/metrics/completion/language_checker.py +74 -0
  43. eval_framework/metrics/completion/length_control.py +83 -0
  44. eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
  45. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  46. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  47. eval_framework/metrics/completion/repetition.py +88 -0
  48. eval_framework/metrics/completion/rouge_1.py +35 -0
  49. eval_framework/metrics/completion/rouge_2.py +45 -0
  50. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  51. eval_framework/metrics/completion/rouge_l.py +52 -0
  52. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  53. eval_framework/metrics/completion/ter.py +67 -0
  54. eval_framework/metrics/completion/text_counter.py +182 -0
  55. eval_framework/metrics/efficiency/__init__.py +0 -0
  56. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  57. eval_framework/metrics/llm/__init__.py +0 -0
  58. eval_framework/metrics/llm/base.py +8 -0
  59. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  60. eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
  61. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  62. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  63. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  64. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  65. eval_framework/metrics/llm/graders/language.py +56 -0
  66. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  67. eval_framework/metrics/llm/graders/models.py +74 -0
  68. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  69. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  70. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  71. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  72. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  73. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  74. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  75. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  76. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  77. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
  78. eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
  79. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  80. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  81. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  82. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  83. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  84. eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
  85. eval_framework/py.typed +0 -0
  86. eval_framework/response_generator.py +416 -0
  87. eval_framework/result_processors/__init__.py +0 -0
  88. eval_framework/result_processors/base.py +74 -0
  89. eval_framework/result_processors/hf_processor.py +87 -0
  90. eval_framework/result_processors/result_processor.py +129 -0
  91. eval_framework/run.py +314 -0
  92. eval_framework/run_direct.py +42 -0
  93. eval_framework/shared/types.py +227 -0
  94. eval_framework/tasks/__init__.py +6 -0
  95. eval_framework/tasks/base.py +314 -0
  96. eval_framework/tasks/benchmarks/__init__.py +0 -0
  97. eval_framework/tasks/benchmarks/arc.py +46 -0
  98. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  99. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  100. eval_framework/tasks/benchmarks/belebele.py +60 -0
  101. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  102. eval_framework/tasks/benchmarks/casehold.py +47 -0
  103. eval_framework/tasks/benchmarks/chembench.py +85 -0
  104. eval_framework/tasks/benchmarks/copa.py +39 -0
  105. eval_framework/tasks/benchmarks/duc.py +91 -0
  106. eval_framework/tasks/benchmarks/flores200.py +62 -0
  107. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  108. eval_framework/tasks/benchmarks/gpqa.py +177 -0
  109. eval_framework/tasks/benchmarks/gsm8k.py +148 -0
  110. eval_framework/tasks/benchmarks/hellaswag.py +44 -0
  111. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  112. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  113. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  114. eval_framework/tasks/benchmarks/include.py +119 -0
  115. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  116. eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
  117. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  118. eval_framework/tasks/benchmarks/mmlu.py +190 -0
  119. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  120. eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
  121. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  122. eval_framework/tasks/benchmarks/openbookqa.py +37 -0
  123. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  124. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  125. eval_framework/tasks/benchmarks/piqa.py +39 -0
  126. eval_framework/tasks/benchmarks/quality.py +56 -0
  127. eval_framework/tasks/benchmarks/sciq.py +44 -0
  128. eval_framework/tasks/benchmarks/sphyr.py +75 -0
  129. eval_framework/tasks/benchmarks/squad.py +89 -0
  130. eval_framework/tasks/benchmarks/struct_eval.py +110 -0
  131. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  132. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  133. eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
  134. eval_framework/tasks/benchmarks/winogender.py +39 -0
  135. eval_framework/tasks/benchmarks/winogrande.py +44 -0
  136. eval_framework/tasks/benchmarks/winox.py +57 -0
  137. eval_framework/tasks/benchmarks/wmt.py +160 -0
  138. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  139. eval_framework/tasks/eval_config.py +112 -0
  140. eval_framework/tasks/perturbation.py +83 -0
  141. eval_framework/tasks/registry.py +186 -0
  142. eval_framework/tasks/task_loader.py +80 -0
  143. eval_framework/tasks/task_names.py +138 -0
  144. eval_framework/tasks/utils.py +578 -0
  145. eval_framework/utils/constants.py +9 -0
  146. eval_framework/utils/generate_task_docs.py +229 -0
  147. eval_framework/utils/helpers.py +3 -0
  148. eval_framework/utils/logging.py +50 -0
  149. eval_framework/utils/packaging.py +52 -0
  150. eval_framework-0.2.0.dist-info/METADATA +514 -0
  151. eval_framework-0.2.0.dist-info/RECORD +161 -0
  152. eval_framework-0.2.0.dist-info/WHEEL +4 -0
  153. eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
  154. template_formatting/README.md +83 -0
  155. template_formatting/__init__.py +0 -0
  156. template_formatting/formatter.py +536 -0
  157. template_formatting/mistral_formatter.py +159 -0
  158. template_formatting/py.typed +0 -0
  159. template_formatting/tests/test_formatter_eval.py +408 -0
  160. template_formatting/tests/test_formatter_scaling.py +253 -0
  161. template_formatting/tests/test_mistral_formatter.py +136 -0
@@ -0,0 +1,51 @@
1
+ from eval_framework.metrics.base import BaseMetric, MetricResult
2
+ from eval_framework.shared.types import Loglikelihood
3
+
4
+
5
class AccuracyLoglikelihood(BaseMetric[Loglikelihood]):
    """Accuracy of the candidate completion with the highest raw loglikelihood.

    The metric is 1.0 when the top-scoring candidate is contained in the response's
    ``ground_truth_list`` and 0.0 otherwise. When several candidates share the maximum
    score, the first one in the ``loglikelihoods`` mapping's iteration order wins.
    """

    NAME = "Accuracy Loglikelihood"

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        """Score one loglikelihood response.

        :param response: the model's loglikelihoods for every candidate completion
        :return: a single-element list with the accuracy (or ``None`` plus the error
            when the response already carries an error)
        """
        # Errors from generation are propagated unchanged instead of being scored.
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        accepted_answers = response.ground_truth_list
        # max() keeps the first maximal item, so ties resolve in mapping order.
        best_candidate, _best_score = max(response.loglikelihoods.items(), key=lambda item: item[1])
        is_correct = best_candidate in accepted_answers

        return [
            MetricResult(
                metric_name=self.NAME,
                value=1.0 if is_correct else 0.0,
                higher_is_better=True,
                error=response.error,
            )
        ]
23
+
24
+
25
class AccuracyNormLoglikelihood(BaseMetric[Loglikelihood]):
    """Accuracy of the top candidate after per-character length normalisation.

    Each candidate's loglikelihood is divided by the candidate's length in characters
    (empty candidates keep their raw score). The metric is 1.0 when the best-scoring
    candidate after normalisation is in ``ground_truth_list``, else 0.0. Ties resolve
    to the first candidate in the ``loglikelihoods`` mapping's iteration order.
    """

    NAME = "Accuracy Normalized Loglikelihood"

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        """Score one loglikelihood response using length-normalised scores.

        :param response: the model's loglikelihoods for every candidate completion
        :return: a single-element list with the accuracy (or ``None`` plus the error
            when the response already carries an error)
        """
        # Errors from generation are propagated unchanged instead of being scored.
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        accepted_answers = response.ground_truth_list

        # Per-character normalisation; the zero-length guard avoids a division by zero.
        per_char_scores = {
            candidate: (score / len(candidate) if len(candidate) != 0 else score)
            for candidate, score in response.loglikelihoods.items()
        }
        best_candidate = max(per_char_scores, key=per_char_scores.__getitem__)

        return [
            MetricResult(
                metric_name=self.NAME,
                value=float(best_candidate in accepted_answers),
                higher_is_better=True,
                error=response.error,
            )
        ]
@@ -0,0 +1,56 @@
1
+ import numpy as np
2
+
3
+ from eval_framework.metrics.base import BaseMetric, MetricResult
4
+ from eval_framework.shared.types import Loglikelihood
5
+
6
+
7
+ class ProbabilityMass(BaseMetric[Loglikelihood]):
8
+ NAME = "Probability Mass"
9
+
10
+ def calculate(self, response: Loglikelihood) -> list[MetricResult]:
11
+ if response.error is not None:
12
+ return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
13
+
14
+ assert isinstance(response.ground_truth, str)
15
+ # https://docs.python.org/3.10/library/stdtypes.html?highlight=dictview#dictionary-view-objects
16
+ possible_completions = list(response.loglikelihoods.keys())
17
+ ground_truth_index = possible_completions.index(response.ground_truth)
18
+ split_idx = ground_truth_index + 1
19
+
20
+ log_probs = list(response.loglikelihoods.values())
21
+ probs = np.exp(log_probs) / np.sum(np.exp(log_probs))
22
+ prob_mass = np.sum(probs[:split_idx])
23
+
24
+ return [
25
+ MetricResult(metric_name=self.NAME, value=float(prob_mass), higher_is_better=True, error=response.error)
26
+ ]
27
+
28
+
29
+ class ProbabilityMassNorm(BaseMetric[Loglikelihood]):
30
+ NAME = "Probability Mass Normalized"
31
+
32
+ def calculate(self, response: Loglikelihood) -> list[MetricResult]:
33
+ if response.error is not None:
34
+ return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
35
+
36
+ assert isinstance(response.ground_truth, str)
37
+ # len normalized
38
+
39
+ output_len_normalized = {}
40
+ for k, v in response.loglikelihoods.items():
41
+ completion_length = len(k)
42
+
43
+ if completion_length != 0:
44
+ output_len_normalized[k] = v / completion_length
45
+ else:
46
+ output_len_normalized[k] = v
47
+
48
+ possible_completions = list(response.loglikelihoods.keys())
49
+ ground_truth_index = possible_completions.index(response.ground_truth)
50
+ split_idx = ground_truth_index + 1
51
+
52
+ log_probs = list(output_len_normalized.values())
53
+ probs = np.exp(log_probs) / np.sum(np.exp(log_probs))
54
+ prob_mass_norm = np.sum(probs[:split_idx])
55
+
56
+ return [MetricResult(metric_name=self.NAME, value=prob_mass_norm, higher_is_better=True, error=response.error)]
File without changes
@@ -0,0 +1,416 @@
1
+ import logging
2
+ import time
3
+ import traceback
4
+ from collections.abc import Callable
5
+ from datetime import UTC, datetime
6
+ from functools import partial
7
+ from typing import Any
8
+
9
+ from eval_framework.tasks.registry import get_task
10
+
11
+ try:
12
+ from determined._info import get_cluster_info
13
+ except ImportError:
14
+ get_cluster_info = None # type: ignore[assignment]
15
+
16
+
17
+ from tqdm import tqdm
18
+
19
+ from eval_framework import __version__ as eval_framework_version
20
+ from eval_framework.llm.base import BaseLLM
21
+ from eval_framework.result_processors.result_processor import ResultsFileProcessor
22
+ from eval_framework.shared.types import (
23
+ Completion,
24
+ Error,
25
+ Loglikelihood,
26
+ RawCompletion,
27
+ RawLoglikelihood,
28
+ )
29
+ from eval_framework.tasks.base import Language, ResponseType, Sample
30
+ from eval_framework.tasks.eval_config import EvalConfig
31
+ from eval_framework.tasks.perturbation import create_perturbation_class
32
+ from eval_framework.tasks.utils import raise_errors
33
+ from eval_framework.utils.constants import RED, RESET
34
+ from template_formatting.formatter import Message, Role
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ def map_language_to_value(
40
+ language: Language | dict[str, Language] | dict[str, tuple[Language, Language]] | None,
41
+ ) -> str | dict[str, str] | dict[str, tuple[str, str]] | None:
42
+ if language is None:
43
+ return None
44
+ elif isinstance(language, Language):
45
+ return language.value
46
+ elif isinstance(language, dict):
47
+ if isinstance(list(language.values())[0], Language):
48
+ return {k: v.value for k, v in language.items()} # type: ignore[union-attr]
49
+ else:
50
+ return {k: (v[0].value, v[1].value) for k, v in language.items()} # type: ignore[index]
51
+ else:
52
+ raise ValueError(f"Invalid language: {language}")
53
+
54
+
55
+ class ResponseGenerator:
56
+ def __init__(self, llm: BaseLLM, config: EvalConfig, result_processor: ResultsFileProcessor) -> None:
57
+ self.few_shot = config.num_fewshot
58
+ self.task_name = config.task_name
59
+ self.llm = llm
60
+ self.config = config
61
+ self.result_processor = result_processor
62
+ self.num_samples = config.num_samples
63
+ self.save_intermediate_results = config.save_intermediate_results
64
+
65
+ task_class = get_task(config.task_name)
66
+
67
+ if config.perturbation_config is not None:
68
+ perturbation_task_class = create_perturbation_class(task_class, config.perturbation_config)
69
+ self.task = perturbation_task_class.with_overwrite(
70
+ self.few_shot, custom_subjects=self.config.task_subjects, custom_hf_revision=self.config.hf_revision
71
+ )
72
+ else:
73
+ self.task = task_class.with_overwrite(
74
+ self.few_shot, custom_subjects=self.config.task_subjects, custom_hf_revision=self.config.hf_revision
75
+ )
76
+
77
+ self.response_type = task_class.RESPONSE_TYPE
78
+
79
+ def _llm_task_param_precedence(self) -> tuple[list[str] | None, int | None]:
80
+ """
81
+ sets the stop_sequences and max_tokens values to be used in the completion generation.
82
+ Max token and stop sequence values have an order of precedence:
83
+
84
+ LLM attributes take precedence over task attributes, and therefore overload them.
85
+ :return: stop_sequences, max_tokens
86
+ """
87
+ llm_stop_sequences = getattr(self.llm, "stop_sequences", None)
88
+ llm_max_tokens = getattr(self.llm, "max_tokens", None)
89
+ task_stop_sequences = getattr(self.task, "stop_sequences", None)
90
+ task_max_tokens = self.config.max_tokens or getattr(self.task, "max_tokens", None)
91
+ # if both task and model define a max_token, the smaller value is used
92
+ max_tokens = min([x for x in [llm_max_tokens, task_max_tokens] if x is not None], default=None)
93
+ logger.info(f"Set max_tokens to {max_tokens}")
94
+ # if both task and model define stop sequences, those are merged into one list
95
+ stop_sequences_merged = (llm_stop_sequences or []) + (task_stop_sequences or [])
96
+ stop_sequences = sorted(list(set(stop_sequences_merged))) if stop_sequences_merged else None
97
+ logger.info(f"Set stop_sequences to {stop_sequences}")
98
+ return stop_sequences, max_tokens
99
+
100
+ def _generate_completions(
101
+ self,
102
+ samples: list[Sample],
103
+ stop_sequences: list[str] | None = None,
104
+ max_tokens: int | None = None,
105
+ ) -> list[Completion]:
106
+ """
107
+ Generates completions for the sample.
108
+ :param sample: sample to generate completions for
109
+ :param stop_sequences: stop sequences to use in completion generation
110
+ :param max_tokens: maximum tokens to use in completion generation
111
+ :return: completion
112
+ """
113
+ if stop_sequences is None:
114
+ stop_sequences = []
115
+
116
+ raw_completions: list[RawCompletion]
117
+ try:
118
+ raw_completions = self.llm.generate(samples=samples, stop_sequences=stop_sequences, max_tokens=max_tokens)
119
+ except Exception as e:
120
+ if raise_errors():
121
+ raise e
122
+ logger.info(f"Error: {e.__class__.__name__} {e}")
123
+ assert len(samples) == 1, "LLMs not handling errors are not supported in batch mode"
124
+ raw_completions = [
125
+ RawCompletion(
126
+ prompt="",
127
+ prompt_sequence_positions=0,
128
+ completion="",
129
+ completion_sequence_positions=0,
130
+ raw_completion_error=Error(
131
+ error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc()
132
+ ),
133
+ )
134
+ for _ in range(len(samples))
135
+ ]
136
+
137
+ completion_list = []
138
+ for idx, sample in enumerate(samples):
139
+ raw_completion = raw_completions[idx]
140
+
141
+ if sample.messages and sample.messages[-1].role == Role.ASSISTANT:
142
+ messages = sample.messages[:-1] + [
143
+ Message(role=Role.ASSISTANT, content=sample.messages[-1].content + raw_completion.completion)
144
+ ]
145
+ else:
146
+ messages = sample.messages + [Message(role=Role.ASSISTANT, content=raw_completion.completion)]
147
+
148
+ try:
149
+ error = None
150
+ completion = self.task.post_process_generated_completion(raw_completion.completion, sample)
151
+ except Exception as e:
152
+ error = Error(error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc())
153
+ completion = ""
154
+
155
+ completion_list.append(
156
+ Completion(
157
+ id=sample.id,
158
+ subject=sample.subject,
159
+ ground_truth=sample.ground_truth,
160
+ prompt=raw_completion.prompt,
161
+ prompt_sequence_positions=raw_completion.prompt_sequence_positions,
162
+ concat_compression=raw_completion.concat_compression,
163
+ messages=messages,
164
+ completion=completion,
165
+ raw_completion=raw_completion.completion,
166
+ raw_completion_sequence_positions=raw_completion.completion_sequence_positions,
167
+ context=sample.context,
168
+ error=raw_completion.raw_completion_error or error,
169
+ )
170
+ )
171
+
172
+ return completion_list
173
+
174
+ def _generate_loglikelihoods(self, samples: list[Sample]) -> list[Loglikelihood]:
175
+ """
176
+ Generate log likelihoods when a sample is run against the model.
177
+ :param sample: sample to run the task against
178
+ :return: loglikelihoods
179
+ """
180
+ raw_loglikelihoods: list[RawLoglikelihood]
181
+ try:
182
+ raw_loglikelihoods = self.llm.logprobs(samples)
183
+ except Exception as e:
184
+ if raise_errors():
185
+ raise e
186
+ logger.info(f"Error: {e.__class__.__name__} {e}")
187
+ assert len(samples) == 1, "LLMs not handling errors are not supported in batch mode"
188
+ raw_loglikelihoods = [
189
+ RawLoglikelihood(
190
+ prompt="",
191
+ prompt_sequence_positions=0,
192
+ loglikelihoods={},
193
+ loglikelihoods_sequence_positions={},
194
+ raw_loglikelihood_error=Error(
195
+ error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc()
196
+ ),
197
+ )
198
+ for _ in range(len(samples))
199
+ ]
200
+
201
+ loglikelihood_list = []
202
+ for idx, sample in enumerate(samples):
203
+ raw_loglikelihood = raw_loglikelihoods[idx]
204
+ assert sample.ground_truth is not None
205
+ loglikelihood_list.append(
206
+ Loglikelihood(
207
+ id=sample.id,
208
+ subject=sample.subject,
209
+ ground_truth=sample.ground_truth,
210
+ prompt=raw_loglikelihood.prompt,
211
+ prompt_sequence_positions=raw_loglikelihood.prompt_sequence_positions,
212
+ concat_compression=raw_loglikelihood.concat_compression,
213
+ loglikelihoods=raw_loglikelihood.loglikelihoods,
214
+ loglikelihoods_sequence_positions=raw_loglikelihood.loglikelihoods_sequence_positions,
215
+ error=raw_loglikelihood.raw_loglikelihood_error,
216
+ )
217
+ )
218
+ return loglikelihood_list
219
+
220
+ def _generative_output_type_selector(self) -> Callable[[list[Sample]], list[Completion] | list[Loglikelihood]]:
221
+ """
222
+ Selects the generative output type based on the response type.
223
+ :return: function to generate responses
224
+ """
225
+ match self.response_type:
226
+ case ResponseType.COMPLETION:
227
+ stop_sequences, max_tokens = self._llm_task_param_precedence()
228
+ return partial(self._generate_completions, stop_sequences=stop_sequences, max_tokens=max_tokens) # type: ignore[call-arg]
229
+ case ResponseType.LOGLIKELIHOODS:
230
+ return self._generate_loglikelihoods
231
+ case _:
232
+ raise KeyError(f"Task type {self.task} not supported")
233
+
234
+ def _run_task_against_model(
235
+ self, should_preempt_callable: Callable[[], bool]
236
+ ) -> tuple[list[Completion | Loglikelihood], bool]:
237
+ """
238
+ Runs the task against the model and generates responses.
239
+ :param should_preempt_callable: function to check if preempt is called
240
+ :return: list of responses, preempted
241
+ """
242
+ logger.info(f"{RED}[ Running task {self.task.NAME} against model ------------ ]{RESET}")
243
+ self.start_time, monotonic_start = time.time(), time.monotonic()
244
+ run_fn = self._generative_output_type_selector()
245
+ self._verify_loaded_metadata_compatibility()
246
+ responses = self.result_processor.load_responses() # load responses if present
247
+ subject_response_id_mapping = self._map_subject_response_ids(responses)
248
+ self.result_processor.save_metadata(self._get_metadata())
249
+ responses, preempted = self._curate_responses(
250
+ responses, subject_response_id_mapping, run_fn, should_preempt_callable
251
+ )
252
+ self.end_time, monotonic_end = time.time(), time.monotonic()
253
+ self.total_time = monotonic_end - monotonic_start
254
+ self.result_processor.save_metadata(self._get_metadata()) # overwrite with updated timing
255
+
256
+ return responses, preempted
257
+
258
+ def _map_subject_response_ids(self, responses: list[Completion | Loglikelihood]) -> dict[str, set[int]]:
259
+ """
260
+ Maps subject to response id
261
+ :param responses: list of responses
262
+ :return: mapping of subject to response id
263
+ """
264
+ subject_response_id_mapping = {}
265
+ if responses:
266
+ response_subjects = {resp.subject for resp in responses}
267
+ subject_response_id_mapping = {
268
+ response_subject: set([resp.id for resp in responses if resp.subject == response_subject])
269
+ for response_subject in response_subjects
270
+ }
271
+
272
+ return subject_response_id_mapping
273
+
274
+ def _curate_responses(
275
+ self,
276
+ responses: list[Completion | Loglikelihood],
277
+ subject_response_id_mapping: dict[str, set[int]],
278
+ generative_output_function: Callable[[list[Sample]], list[Completion] | list[Loglikelihood]],
279
+ should_preempt_callable: Callable[[], bool],
280
+ ) -> tuple[list[Completion | Loglikelihood], bool]:
281
+ """
282
+ Generates responses for the task and saves them along with metadata.
283
+ :param responses: list of responses
284
+ :param subject_response_id_mapping: mapping of subject to response id
285
+ :param generative_output_function: function to generate responses
286
+ :param metadata: metadata dictionary
287
+ :param should_preempt_callable: function to check if preempt is called
288
+ :return: None
289
+ """
290
+
291
+ def _process_batch(samples_batch: list[Sample]) -> None:
292
+ if not samples_batch:
293
+ return
294
+ if len(samples_batch) > 1:
295
+ log_msg = "Processing batch..."
296
+ logger.info(log_msg) # For log files
297
+ tqdm.write(log_msg) # For console display with tqdm
298
+
299
+ responses_batch = generative_output_function(samples_batch)
300
+ responses.extend(responses_batch)
301
+ if self.save_intermediate_results:
302
+ for response in responses_batch:
303
+ self.result_processor.save_response(response)
304
+
305
+ # In order to enable parallelism we group samples in batches and send them in parallel to the `run_fn`.
306
+ # The BaseLLM class is then in charge of managing the parallelism (eg, using AsyncClient in API models).
307
+ # If samples_batch_size = 1, samples are run sequentially; in any case, we return here after finishing each
308
+ # individual batch to honor preemption requests and save cached results.
309
+ samples_batch_size = self.config.batch_size
310
+
311
+ # Calculate total samples for progress bar - use num_samples or iterate to count
312
+ total_num_samples = self.num_samples
313
+ if total_num_samples is None:
314
+ # Count samples by iterating (this might be expensive for large datasets)
315
+ total_num_samples = sum(1 for _ in self.task.iterate_samples(None))
316
+
317
+ samples_batch: list[Sample] = []
318
+ with tqdm(total=total_num_samples, desc=f"Processing {self.response_type.value}") as pbar:
319
+ for i, sample in enumerate(self.task.iterate_samples(self.num_samples)):
320
+ subject = f" - Subject: {sample.subject}"
321
+ sample_index = i + 1
322
+
323
+ if sample.id in subject_response_id_mapping.get(sample.subject, []):
324
+ log_msg = (
325
+ f"Task: {self.response_type.value}{subject} - Sample: {sample_index} - skipping, already done."
326
+ )
327
+ logger.info(log_msg) # For log files
328
+ tqdm.write(log_msg) # For console display with tqdm
329
+ pbar.update(1)
330
+ continue
331
+
332
+ log_msg = f"Task: {self.response_type.value}{subject} - Sample: {sample_index}/{total_num_samples}"
333
+ logger.info(log_msg) # For log files
334
+ tqdm.write(log_msg) # For console display with tqdm
335
+ pbar.set_postfix_str(f"Sample {sample_index}/{total_num_samples}")
336
+ pbar.update(1)
337
+
338
+ samples_batch.append(sample)
339
+
340
+ if len(samples_batch) >= samples_batch_size:
341
+ _process_batch(samples_batch)
342
+ samples_batch = []
343
+
344
+ if should_preempt_callable():
345
+ log_msg = "Preempt"
346
+ logger.info(log_msg) # For log files
347
+ tqdm.write(log_msg) # For console display with tqdm
348
+ if not self.save_intermediate_results:
349
+ self.result_processor.save_responses(responses)
350
+ return responses, True
351
+
352
+ _process_batch(samples_batch)
353
+
354
+ if not self.save_intermediate_results:
355
+ self.result_processor.save_responses(responses)
356
+ return responses, False
357
+
358
+ def _get_metadata(self) -> dict[str, Any]:
359
+ """Prepares metadata dictionary from the configuration."""
360
+ all_metrics = getattr(self.task, "METRICS", None)
361
+ metadata = self.config.model_dump()
362
+ metadata["llm_name"] = self.llm.name
363
+ metadata["task_name"] = self.task_name
364
+ language = getattr(self.task, "LANGUAGE", None)
365
+ metadata["language"] = map_language_to_value(language)
366
+ metadata["metrics"] = [m.NAME for m in all_metrics] if all_metrics is not None else []
367
+ metadata["primary_metrics"] = getattr(self.task, "PRIMARY_METRICS", None)
368
+ metadata["eval_framework_version"] = eval_framework_version
369
+ metadata["task_output_dir"] = str(self.result_processor.output_dir)
370
+ if hasattr(self, "total_time"):
371
+ metadata["start_time"] = str(datetime.fromtimestamp(self.start_time, UTC))
372
+ metadata["end_time"] = str(datetime.fromtimestamp(self.end_time, UTC))
373
+ metadata["total_time"] = self.total_time
374
+
375
+ try:
376
+ assert get_cluster_info is not None, "Determined cluster info not available"
377
+ info = get_cluster_info()
378
+ if info is not None:
379
+ metadata["determined_agent_id"] = info.agent_id
380
+ if info.task_type == "TRIAL":
381
+ metadata["determined_experiment_id"] = info.trial.experiment_id
382
+ metadata["determined_trial_id"] = info.trial.trial_id
383
+ except Exception as e:
384
+ logger.info(f"{e}; cluster info not available in local context")
385
+
386
+ return metadata
387
+
388
+ def _verify_loaded_metadata_compatibility(self) -> None:
389
+ if not (loaded_metadata := self.result_processor.load_metadata()):
390
+ return
391
+ current_metadata = self._get_metadata()
392
+ # check if crucial keys in metadata are the same as in the previous run
393
+ keys = [
394
+ "task_name",
395
+ "task_subjects",
396
+ "num_fewshot",
397
+ "num_samples",
398
+ "llm_name",
399
+ "llm_args",
400
+ "perturbation_config",
401
+ ]
402
+ for key in keys:
403
+ if loaded_metadata[key] != current_metadata[key]:
404
+ raise ValueError(f"Existing metadata does not match current metadata for {key}.")
405
+
406
+ def generate(self, should_preempt_callable: Callable[[], bool]) -> tuple[list[Completion | Loglikelihood], bool]:
407
+ """Generates responses and saves them along with metadata.
408
+ :param should_preempt_callable: function to check if preempt is called
409
+ :return: list of responses, preempted: whether the process was preempted or not
410
+ """
411
+ logger.info(f"{RED}[ Running responses generation ---------- ]{RESET}")
412
+ logger.info(f"{RED}[ Will save into {self.result_processor.output_dir} ---------- ]{RESET}")
413
+ responses, preempted = self._run_task_against_model(should_preempt_callable)
414
+ logger.info("Completions generated and saved.")
415
+
416
+ return responses, preempted
File without changes
@@ -0,0 +1,74 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from pydantic import BaseModel, ConfigDict
4
+
5
+ from eval_framework.shared.types import Completion, Error, Loglikelihood
6
+
7
+ MAIN = "eval_framework_results"
8
+
9
+
10
+ class Result(BaseModel):
11
+ model_config = ConfigDict(extra="forbid")
12
+ id: int
13
+ subject: str
14
+ num_fewshot: int
15
+ llm_name: str
16
+ task_name: str
17
+ metric_class_name: str
18
+ metric_name: str
19
+ key: str | None
20
+ value: float | None
21
+ higher_is_better: bool
22
+ prompt: str
23
+ response: str
24
+ llm_judge_prompt: str | None = None
25
+ llm_judge_response: str | None = None
26
+ code_execution_trace: str | None = None
27
+ error: Error | None = None
28
+
29
+
30
+ class ResultProcessor(ABC):
31
+ @abstractmethod
32
+ def save_metadata(self, metadata: dict) -> None:
33
+ """Save metadata."""
34
+ pass
35
+
36
+ @abstractmethod
37
+ def load_metadata(self) -> dict:
38
+ """Load metadata."""
39
+ pass
40
+
41
+ @abstractmethod
42
+ def save_responses(self, responses: list[Completion | Loglikelihood]) -> None:
43
+ """Save a list of response objects (overwrite a file)."""
44
+ pass
45
+
46
+ @abstractmethod
47
+ def save_response(self, response: Completion | Loglikelihood) -> None:
48
+ """Save a single response object (append into a file)."""
49
+ pass
50
+
51
+ @abstractmethod
52
+ def load_responses(self) -> list[Completion | Loglikelihood]:
53
+ """Load a list of response objects."""
54
+ pass
55
+
56
+ @abstractmethod
57
+ def save_metrics_results(self, results: list[Result]) -> None:
58
+ """Save the results of the metrics (overwrite a file)."""
59
+ pass
60
+
61
+ @abstractmethod
62
+ def save_metrics_result(self, result: Result) -> None:
63
+ """Save a single metric result (append into a file)."""
64
+ pass
65
+
66
+ @abstractmethod
67
+ def save_aggregated_results(self, result: dict[str, float | None]) -> None:
68
+ """Save the aggregated results."""
69
+ pass
70
+
71
+ @abstractmethod
72
+ def load_metrics_results(self) -> list[Result]:
73
+ """Load the aggregated results."""
74
+ pass
@@ -0,0 +1,87 @@
1
+ """
2
+ Module for writing result folder and its contents to HuggingFace
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ from pathlib import Path
8
+
9
+ from dotenv import load_dotenv
10
+ from huggingface_hub import HfApi, login
11
+ from tqdm import tqdm
12
+
13
+ from eval_framework.tasks.eval_config import EvalConfig
14
+ from eval_framework.utils.constants import RED, RESET
15
+
16
+ load_dotenv()
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class HFProcessor:
22
+ def __init__(self, config: EvalConfig, current_dir: Path) -> None:
23
+ self.output_dir = config.output_dir
24
+ self.current_dir = current_dir
25
+ self.hf_upload_dir = config.hf_upload_dir
26
+ self.hf_upload_repo = config.hf_upload_repo
27
+ assert self.output_dir is not None
28
+ assert self.current_dir is not None
29
+ assert self.hf_upload_dir is not None
30
+ self.hf_upload_dir = self.hf_upload_dir.replace("/", "")
31
+ self.hf_api = HFProcessor._login_into_hf()
32
+
33
+ @classmethod
34
+ def _login_into_hf(cls) -> HfApi | None:
35
+ try:
36
+ login(token=os.environ.get("HF_TOKEN", ""))
37
+ logger.info("logged into HF")
38
+ return HfApi()
39
+
40
+ except Exception:
41
+ logger.info("Could not login into HuggingFace. Check credentials")
42
+ return None
43
+
44
+ def upload_responses_to_HF(self) -> tuple[bool, str | None]:
45
+ hf_repo_name = self.hf_upload_repo
46
+ assert hf_repo_name is not None, "No HF upload repository configured (hf_upload_repo)!"
47
+
48
+ if self.hf_api is None:
49
+ logger.info("Not logged into HuggingFace")
50
+ return False, None
51
+
52
+ try:
53
+ self.upload_dir = Path(self.current_dir).relative_to(Path(self.output_dir))
54
+ self.upload_dir = Path(str(self.hf_upload_dir)) / self.upload_dir # type ignore
55
+ logger.info(f"{RED}[ HF upload to {self.upload_dir} ------- ]{RESET}")
56
+
57
+ except Exception as e:
58
+ logger.info(f"Upload path not properly defined: {e}")
59
+ return False, None
60
+
61
+ upload_counter = 0
62
+ for filename in tqdm(os.listdir(self.current_dir)):
63
+ if filename not in ["results.jsonl", "output.jsonl"]:
64
+ upload_counter += 1
65
+ source_filename = str(Path(self.current_dir) / filename)
66
+ dest_filename = str(Path(self.upload_dir) / filename)
67
+ else:
68
+ logger.info(f"Skipping {filename}; file too large")
69
+
70
+ try:
71
+ self.hf_api.upload_file(
72
+ path_or_fileobj=source_filename,
73
+ path_in_repo=dest_filename,
74
+ repo_id=hf_repo_name,
75
+ repo_type="dataset",
76
+ )
77
+ except Exception as e:
78
+ self.status = "Problem during HF file upload: " + str(e)
79
+ logger.info(self.status)
80
+ return False, None
81
+
82
+ logger.info(f"uploaded {upload_counter} files")
83
+
84
+ hf_url = f"https://huggingface.co/datasets/{hf_repo_name}/tree/main/{self.upload_dir}"
85
+ logger.info(f"Results uploaded to: {hf_url}")
86
+
87
+ return True, hf_url