eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,64 @@
1
+ from typing import Any
2
+
3
+ from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
4
+ AccuracyLoglikelihood,
5
+ AccuracyNormLoglikelihood,
6
+ )
7
+ from eval_framework.metrics.loglikelihood.confidence_weighted_accuracy import ConfidenceWeightedAccuracy
8
+ from eval_framework.metrics.loglikelihood.dcs import DistributionalCorrectnessScore
9
+ from eval_framework.metrics.loglikelihood.ternary import TernaryScore
10
+ from eval_framework.tasks.base import BaseTask, Language, ResponseType
11
+
12
+
13
class WINOGENDER(BaseTask[str]):
    """Winogender pronoun-resolution task, scored by loglikelihood.

    Dataset: https://huggingface.co/datasets/oskarvanderwal/winogender

    The prompt is the item's ``sentence`` followed by "'<Pronoun>' refers to". The two
    candidate completions are built from the item's ``occupation`` and ``participant``
    fields, and ``label`` indexes the correct one.
    """

    NAME = "Winogender"
    DATASET_PATH = "oskarvanderwal/winogender"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "test"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["all"]
    LANGUAGE = Language.ENG

    def _extract_question(self, item: dict) -> str:
        """Format question according to Llama paper."""
        return f"{item['sentence']} '{item['pronoun'].capitalize()}' refers to"

    def _extract_choices(self, item: dict) -> list[str]:
        """Return the two candidate referents in (occupation, participant) order."""
        choices = item["occupation"], item["participant"]
        # add "the" to any choice that isn't "someone" (else it's ungrammatical)
        return [f"the {c}" if c.lower() != "someone" else c for c in choices]

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return self._extract_question(item)

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the correct choice, prefixed with a space to follow "refers to"."""
        choices = self._extract_choices(item)
        return f" {choices[item['label']]}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return both choices, each prefixed with a space to follow "refers to"."""
        return [f" {choice}" for choice in self._extract_choices(item)]
43
+
44
+
45
class WINOGENDER_IDK(WINOGENDER):
    """Winogender variant that lets the model abstain with " don't know".

    The prompt states that abstaining scores 0 points. The abstention string is appended
    to the regular candidates, and ConfidenceWeightedAccuracy,
    DistributionalCorrectnessScore and TernaryScore are added to the metric list.
    """

    NAME = "Winogender_IDK"
    METRICS = [
        AccuracyLoglikelihood,
        AccuracyNormLoglikelihood,
        ConfidenceWeightedAccuracy,
        DistributionalCorrectnessScore,
        TernaryScore,
    ]

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        # Regular candidates first (empty if the parent yields none), abstention last.
        base_options = super()._get_possible_completions(item) or []
        return [*base_options, " don't know"]

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        instructions = (
            "Complete the sentence only if you are confident, since mistakes may be penalised, while correct "
            "completions receive points. It is acceptable to answer with 'don't know' if you are unsure, and "
            "you will receive 0 points."
        )
        return instructions
@@ -0,0 +1,69 @@
1
+ from typing import Any
2
+
3
+ from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
4
+ AccuracyLoglikelihood,
5
+ AccuracyNormLoglikelihood,
6
+ )
7
+ from eval_framework.metrics.loglikelihood.confidence_weighted_accuracy import ConfidenceWeightedAccuracy
8
+ from eval_framework.metrics.loglikelihood.dcs import DistributionalCorrectnessScore
9
+ from eval_framework.metrics.loglikelihood.ternary import TernaryScore
10
+ from eval_framework.tasks.base import BaseTask, Language, ResponseType
11
+
12
# Maps WinoGrande's 1-based answer label ("1" / "2") to a 0-based index into the choices list.
ANSWER_STR_TO_NUM = {label: index for index, label in enumerate(("1", "2"))}
13
+
14
+
15
class WINOGRANDE(BaseTask[str]):
    """WinoGrande fill-in-the-blank task, scored by loglikelihood.

    Dataset: https://huggingface.co/datasets/winogrande

    Each ``sentence`` contains one ``_`` blank. The text before the blank is the prompt.
    Each candidate completion is ``option1`` or ``option2`` followed by the text after
    the blank. ``answer`` ("1" or "2") selects the correct option.
    """

    NAME = "Winogrande"
    DATASET_PATH = "winogrande"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "train"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood, AccuracyNormLoglikelihood]
    SUBJECTS = ["winogrande_xl"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["1", "2"]
    LANGUAGE = Language.ENG

    @staticmethod
    def _split_at_blank(sentence: str) -> tuple[str, str]:
        """Split ``sentence`` at its "_" blank and collapse double spaces in both halves.

        Raises ValueError unless the sentence contains exactly one "_".
        """
        prefix, suffix = sentence.split("_")
        # The previous code called .replace(" ", " "), which is a no-op; collapsing
        # double spaces is the evident intent.
        return prefix.replace("  ", " "), suffix.replace("  ", " ")

    def _extract_question(self, item: dict) -> str:
        """Return the sentence text before the blank, stripped."""
        prefix, _ = self._split_at_blank(item["sentence"])
        return prefix.strip()

    def _extract_choices(self, item: dict) -> list[str]:
        """Return ``option1``/``option2`` each followed by the text after the blank."""
        _, suffix = self._split_at_blank(item["sentence"])
        return [option + suffix for option in (item["option1"], item["option2"])]

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        return self._extract_question(item)

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the correct completion (selected by ``answer``) with a leading space."""
        choices = self._extract_choices(item)
        return f" {choices[ANSWER_STR_TO_NUM[item['answer']]]}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return both completions, each with a leading space."""
        return [f" {choice}" for choice in self._extract_choices(item)]
48
+
49
+
50
class WINOGRANDE_IDK(WINOGRANDE):
    """WinoGrande variant that lets the model abstain with " I do not know.".

    The prompt states that abstaining scores 0 points. The abstention string is appended
    to the two regular candidates, and ConfidenceWeightedAccuracy,
    DistributionalCorrectnessScore and TernaryScore are added to the metric list.
    """

    NAME = "Winogrande_IDK"
    METRICS = [
        AccuracyLoglikelihood,
        AccuracyNormLoglikelihood,
        ConfidenceWeightedAccuracy,
        DistributionalCorrectnessScore,
        TernaryScore,
    ]

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        # Regular candidates first (empty if the parent yields none), abstention last.
        base_options = super()._get_possible_completions(item) or []
        return [*base_options, " I do not know."]

    def _get_initial_prompt_text(self, item: dict[str, Any]) -> str:
        instructions = (
            "Complete the sentence only if you are confident, since mistakes may be penalised, while correct "
            "answers receive points. It is acceptable to answer with 'I do not know' if you are unsure, and "
            "you will receive 0 points."
        )
        return instructions
@@ -0,0 +1,57 @@
1
+ from typing import Any
2
+
3
+ from eval_framework.tasks.base import Language
4
+ from eval_framework.tasks.benchmarks.winogrande import WINOGRANDE
5
+
6
# Maps the 1-based answer label, as a string ("1" / "2"), to a 0-based index into the two choices.
ANSWER_STR_TO_NUM = {label: index for index, label in enumerate(("1", "2"))}
7
+
8
+
9
class WINOX(WINOGRANDE):
    """
    Wino-X is a parallel dataset of German, French, and Russian Winograd schemas, aligned with their English
    counterparts, used to examine whether neural machine translation models can perform coreference resolution that
    requires commonsense knowledge, and whether multilingual language models are capable of commonsense reasoning
    across multiple languages.

    Winogrande: https://arxiv.org/abs/1907.10641
    Wino-X: https://github.com/demelin/Wino-X
    Wino-X: https://huggingface.co/datasets/demelin/wino_x

    Subclasses set ``LANGUAGE_SHORT_CODE``; it selects the ``context_<code>``,
    ``option1_<code>`` and ``option2_<code>`` columns.
    """

    DATASET_PATH = "demelin/wino_x"
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "test"
    # Column suffix (e.g. "de"); empty here, so this base class must be subclassed to be usable.
    LANGUAGE_SHORT_CODE = ""

    def _split_context(self, item: dict) -> tuple[str, str]:
        """Split the language-specific context at its "_" blank, collapsing double spaces.

        Raises ValueError unless the context contains exactly one "_".
        """
        prefix, suffix = item[f"context_{self.LANGUAGE_SHORT_CODE}"].split("_")
        # The previous code called .replace(" ", " "), a no-op; collapsing double spaces is the intent.
        return prefix.replace("  ", " "), suffix.replace("  ", " ")

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the correct completion with a leading space."""
        choices = self._extract_choices(item)
        # in winogrande answer is a string but in wino_x it is an int
        return f" {choices[ANSWER_STR_TO_NUM[str(item['answer'])]]}"

    def _extract_question(self, item: dict) -> str:
        """Return the context text before the blank, stripped."""
        prefix, _ = self._split_context(item)
        return prefix.strip()

    def _extract_choices(self, item: dict) -> list[str]:
        """Return the two language-specific options, each followed by the text after the blank."""
        _, suffix = self._split_context(item)
        code = self.LANGUAGE_SHORT_CODE
        return [item[f"option{number}_{code}"] + suffix for number in (1, 2)]
44
+
45
+
46
class WINOX_DE(WINOX):
    """Wino-X German: the ``lm_en_de`` subset, read from the ``*_de`` columns."""

    LANGUAGE_SHORT_CODE = "de"
    LANGUAGE = Language.DEU
    SUBJECTS = ["lm_en_de"]
    NAME = "WINOX_DE"
51
+
52
+
53
class WINOX_FR(WINOX):
    """Wino-X French: the ``lm_en_fr`` subset, read from the ``*_fr`` columns."""

    LANGUAGE_SHORT_CODE = "fr"
    LANGUAGE = Language.FRA
    SUBJECTS = ["lm_en_fr"]
    NAME = "WINOX_FR"
@@ -0,0 +1,160 @@
1
+ import random
2
+ from abc import ABC
3
+ from typing import Any
4
+
5
+ import pycountry
6
+ import sacrebleu
7
+
8
+ from eval_framework.metrics.completion.bleu import LINEWISE_BLEU
9
+ from eval_framework.metrics.completion.chrf import LINEWISE_CHRF
10
+ from eval_framework.metrics.completion.ter import LINEWISE_TER
11
+ from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Language, ResponseType, Sample
12
+
13
+
14
class WMT(BaseTask[str], ABC):
    """Base task for WMT translation test sets downloaded through sacrebleu.

    Subclasses set ``DATASET_PATH`` to a sacrebleu test-set name (e.g. ``"wmt14"``) and
    ``SUBJECTS`` to language pairs written ``"<src>-<tgt>"``. The prompt has the form
    ``"<Source> phrase: <text>\\n<Target> phrase:"``. Completions are scored with the
    line-wise BLEU, chrF and TER metrics.
    """

    NAME = "WMT"
    DATASET_PATH = ""
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "test"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [LINEWISE_BLEU, LINEWISE_CHRF, LINEWISE_TER]
    PERTURBATION_UNMODIFIABLE_WORDS = ["phrase"]

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Cut generation at a sentence-ending line break, at the next "... phrase:" cue,
        # or at a blank line.
        self.stop_sequences: list[str] = [".\n", " phrase: ", "phrase:", "phrase: ", " phrase:", "\n\n"]

    @staticmethod
    def _read_lines(path: str) -> list[str]:
        """Return the lines of ``path`` with trailing whitespace removed, closing the file afterwards."""
        with sacrebleu.smart_open(path) as handle:
            return [line.rstrip() for line in handle]

    def _load_dataset(self, subject: str | None) -> None:
        """Download the test set for language pair ``subject`` and store it, shuffled, as the "test" split.

        Each record holds a ``source`` line, its ``target`` reference line and the ``subject`` pair.
        """
        # Only the source and reference files (the first two) are used. sacrebleu may return
        # further files, so the count is not pinned.
        src_file, ref_file, *_ = sacrebleu.download_test_set(test_set=self.DATASET_PATH, langpair=subject)
        src_data = self._read_lines(src_file)
        ref_data = self._read_lines(ref_file)

        data_list = [{"source": src, "target": ref, "subject": subject} for src, ref in zip(src_data, ref_data)]
        # Seeded shuffle so the sample order is reproducible.
        self.rnd = random.Random(RANDOM_SEED)
        self.rnd.shuffle(data_list)
        self.dataset = {"test": data_list}

    def _code_to_language(self, code: str) -> str:
        """Return the pycountry language name for a 2- or 3-letter ISO 639 code.

        Raises:
            ValueError: if pycountry has no language for ``code``.
        """
        # key is alpha_2 or alpha_3 depending on the code length
        key = f"alpha_{len(code)}"
        language = pycountry.languages.get(**{key: code})
        if language is None:
            raise ValueError(f"Unknown language code: {code!r}")
        return language.name

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        language_codes = item["subject"].split("-")
        src_lang = self._code_to_language(language_codes[0])
        tar_lang = self._code_to_language(language_codes[1])
        cue = f"{tar_lang} phrase:"

        return f"{src_lang} phrase: {item['source']}\n{cue}"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        # Tolerate a list of references by using the first one.
        return item["target"] if isinstance(item["target"], str) else item["target"][0]

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        target = self._get_ground_truth(item)
        assert target is not None
        assert isinstance(target, str)
        return f" {target}"

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Truncate the completion at the first occurrence of each stop sequence, then strip it."""
        for stop_sequence in self.stop_sequences:
            if stop_sequence in completion_text:
                completion_text = completion_text.split(stop_sequence)[0]
        return completion_text.strip()
68
+
69
+
70
class WMT14(WMT):
    """WMT14 English <-> French, in the "<Language> phrase:" completion format."""

    NAME = "WMT14"
    DATASET_PATH = "wmt14"
    SUBJECTS = ["en-fr", "fr-en"]
    # (source, target) language for each subject pair
    LANGUAGE = {
        "en-fr": (Language.ENG, Language.FRA),
        "fr-en": (Language.FRA, Language.ENG),
    }
78
+
79
+
80
class WMT16(WMT):
    """WMT16 German <-> English, in the "<Language> phrase:" completion format."""

    NAME = "WMT16"
    DATASET_PATH = "wmt16"
    SUBJECTS = ["de-en", "en-de"]
    # (source, target) language for each subject pair
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "en-de": (Language.ENG, Language.DEU),
    }
88
+
89
+
90
class WMT20(WMT):
    """WMT20 German <-> English and German <-> French, in the "<Language> phrase:" completion format."""

    NAME = "WMT20"
    DATASET_PATH = "wmt20"
    SUBJECTS = ["de-en", "de-fr", "en-de", "fr-de"]
    # (source, target) language for each subject pair
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "de-fr": (Language.DEU, Language.FRA),
        "en-de": (Language.ENG, Language.DEU),
        "fr-de": (Language.FRA, Language.DEU),
    }
100
+
101
+
102
class WMT_INSTRUCT(WMT):
    """Instruction-style WMT variant: "Please translate from <Src> to <Tgt>: <text>".

    ``COMPLETION_PREFIX`` is returned as the cue text. It is also removed from the start
    of generated completions in case the model repeats it.
    """

    PERTURBATION_UNMODIFIABLE_WORDS = ["Please", "translate"]
    COMPLETION_PREFIX = "This is the translation:"

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Stop once the model starts writing the next instruction.
        self.stop_sequences: list[str] = ["Please translate"]

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        src_lang, tar_lang = map(self._code_to_language, item["subject"].split("-"))
        return f"Please translate from {src_lang} to {tar_lang}: {item['source']}"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        # Named to match the `_get_cue_text` hook overridden by other tasks in this package
        # (e.g. ZERO_SCROLLS_QUALITY); the old `_get_cue` name is kept below as an alias.
        return self.COMPLETION_PREFIX

    def _get_cue(self, item: dict[str, Any]) -> str:
        """Backward-compatible alias for `_get_cue_text`."""
        return self._get_cue_text(item)

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        target = self._get_ground_truth(item)
        assert target is not None
        return f" {target}"

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Drop a leading ``COMPLETION_PREFIX``, cut at stop sequences, and strip the result."""
        # Strip first so that leading whitespace does not prevent the prefix from being removed.
        completion_text = completion_text.strip().removeprefix(self.COMPLETION_PREFIX).strip()
        for stop_sequence in self.stop_sequences:
            if stop_sequence in completion_text:
                completion_text = completion_text.split(stop_sequence)[0]
        # Strip again: text before a stop sequence often ends in whitespace or blank lines,
        # which would otherwise reach the line-wise metrics.
        return completion_text.strip()
129
+
130
+
131
class WMT14_INSTRUCT(WMT_INSTRUCT):
    """WMT14 English <-> French with "Please translate ..." instructions."""

    NAME = "WMT14 Instruct"
    DATASET_PATH = "wmt14"
    SUBJECTS = ["en-fr", "fr-en"]
    # (source, target) language for each subject pair
    LANGUAGE = {
        "en-fr": (Language.ENG, Language.FRA),
        "fr-en": (Language.FRA, Language.ENG),
    }
139
+
140
+
141
class WMT16_INSTRUCT(WMT_INSTRUCT):
    """WMT16 German <-> English with "Please translate ..." instructions."""

    NAME = "WMT16 Instruct"
    DATASET_PATH = "wmt16"
    SUBJECTS = ["de-en", "en-de"]
    # (source, target) language for each subject pair
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "en-de": (Language.ENG, Language.DEU),
    }
149
+
150
+
151
class WMT20_INSTRUCT(WMT_INSTRUCT):
    """WMT20 German <-> English and German <-> French with "Please translate ..." instructions."""

    NAME = "WMT20 Instruct"
    DATASET_PATH = "wmt20"
    SUBJECTS = ["de-en", "de-fr", "en-de", "fr-de"]
    # (source, target) language for each subject pair
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "de-fr": (Language.DEU, Language.FRA),
        "en-de": (Language.ENG, Language.DEU),
        "fr-de": (Language.FRA, Language.DEU),
    }
@@ -0,0 +1,197 @@
1
+ import re
2
+ from typing import Any
3
+
4
+ from eval_framework.metrics.completion.exponential_similarity import ExponentialSimilarity
5
+ from eval_framework.metrics.completion.f1 import F1
6
+ from eval_framework.metrics.completion.rouge_geometric_mean import ROUGE_GEOMETRIC_MEAN
7
+ from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
8
+ AccuracyLoglikelihood,
9
+ )
10
+ from eval_framework.tasks.base import BaseTask, Language, ResponseType, Sample
11
+ from eval_framework.tasks.utils import get_n_letters
12
+
13
+
14
class ZERO_SCROLLS_QUALITY(BaseTask[str]):
    """ZeroSCROLLS QuALITY multiple-choice subset, scored by loglikelihood over answer letters.

    Dataset: https://huggingface.co/datasets/tau/zero_scrolls

    Zero-shot only. Candidates are the first four letters from ``get_n_letters``, each with
    a leading space; the reference is the item's ``output`` with a leading space.
    """

    NAME = "ZeroSCROLLS QuALITY"
    DATASET_PATH = "tau/zero_scrolls"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "validation"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood]
    SUBJECTS = ["quality"]

    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS QuALITY only supports zero fewshot examples"
        super().__init__(num_fewshot)
        self.keys = get_n_letters(4)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        # Keep the input up to query_end_index; the "Answer:" cue comes from _get_cue_text.
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\n"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        return f" {item['output']}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        return list(map(" {}".format, self.keys))
45
+
46
+
47
class ZERO_SCROLLS_COMPLETION(BaseTask[str]):
    """Shared base for the generative ZeroSCROLLS subsets.

    Dataset: https://huggingface.co/datasets/tau/zero_scrolls

    Subclasses provide NAME, METRICS, SUBJECTS and the prompt. The reference answer is the
    item's ``output`` field, unchanged.
    """

    RESPONSE_TYPE = ResponseType.COMPLETION
    DATASET_PATH = "tau/zero_scrolls"
    FEWSHOT_SPLIT = "validation"
    SAMPLE_SPLIT = "validation"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        reference = item["output"]
        return reference
57
+
58
+
59
class ZERO_SCROLLS_GOV_REPORT(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS GovReport summarisation, scored with ROUGE_GEOMETRIC_MEAN (zero-shot only)."""

    NAME = "ZeroSCROLLS GovReport"
    SUBJECTS = ["gov_report"]
    METRICS = [ROUGE_GEOMETRIC_MEAN]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Summary"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS GovReport only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        # Input truncated at query_end_index, followed directly by the "Summary:" cue.
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}Summary:"
72
+
73
+
74
class ZERO_SCROLLS_QMSUM(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS QMSum query-based summarisation, scored with ROUGE_GEOMETRIC_MEAN (zero-shot only)."""

    NAME = "ZeroSCROLLS QMSum"
    SUBJECTS = ["qmsum"]
    METRICS = [ROUGE_GEOMETRIC_MEAN]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS QMSum only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        # Input truncated at query_end_index, then a blank line and the "Answer:" cue.
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\nAnswer:"
87
+
88
+
89
class ZERO_SCROLLS_SQUALITY(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS SQuALITY question-focused summarisation, scored with ROUGE_GEOMETRIC_MEAN (zero-shot only)."""

    NAME = "ZeroSCROLLS SQuALITY"
    SUBJECTS = ["squality"]
    METRICS = [ROUGE_GEOMETRIC_MEAN]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS SQuALITY only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        # Input truncated at query_end_index, then a blank line and the "Answer:" cue.
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\nAnswer:"
102
+
103
+
104
class ZERO_SCROLLS_QASPER(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS Qasper question answering, scored with F1 (zero-shot only)."""

    NAME = "ZeroSCROLLS Qasper"
    SUBJECTS = ["qasper"]
    METRICS = [F1]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS Qasper only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        # Input truncated at query_end_index, then a blank line and the "Answer:" cue.
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\nAnswer:"
117
+
118
+
119
class ZERO_SCROLLS_NARRATIVEQA(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS NarrativeQA question answering, scored with F1 (zero-shot only)."""

    NAME = "ZeroSCROLLS NarrativeQA"
    SUBJECTS = ["narrative_qa"]
    METRICS = [F1]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS NarrativeQA only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        # Input truncated at query_end_index, then a blank line and the "Answer:" cue.
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\nAnswer:"
132
+
133
+
134
class ZERO_SCROLLS_MUSIQUE(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS MuSiQue multi-hop question answering, scored with F1 (zero-shot only)."""

    NAME = "ZeroSCROLLS MuSiQue"
    SUBJECTS = ["musique"]
    METRICS = [F1]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS MuSiQue only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        # Input truncated at query_end_index, then a blank line and the "Answer:" cue.
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\nAnswer:"
147
+
148
+
149
class ZERO_SCROLLS_SPACE_DIGEST(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS SpaceDigest: the answer is a number, scored with ExponentialSimilarity (zero-shot only).

    Both the model completion and the reference ``output`` go through
    `post_process_generated_completion`, which extracts the number.
    """

    NAME = "ZeroSCROLLS SpaceDigest"
    METRICS = [ExponentialSimilarity]
    SUBJECTS = ["space_digest"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    # Tried in order; the first pattern matching anywhere in the text wins. The former
    # "X percentage" and "is X percent" patterns were removed as unreachable: any text they
    # match is already matched by the "X percent" pattern tried before them.
    _PERCENTAGE_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"(\d+(?:\.\d+)?)%",  # Matches: 30%, 30.5%
            r"(\d+(?:\.\d+)?)\s*percent",  # Matches: 30 percent, 30.5 percent, 30 percentage
            r"percentage\s*(?:is|of|:)?\s*(\d+(?:\.\d+)?)",  # Matches: percentage is 30, percentage: 30.5
            # Matches: is 30 %, about 30 % (number and % separated by whitespace)
            r"(?:is|equals|equal to|about|approximately|around|roughly)\s*(\d+(?:\.\d+)?)\s*%",
            r"it'?s\s*(\d+(?:\.\d+)?)",  # Matches: it's 60, its 60
            r"that'?s\s*(\d+(?:\.\d+)?)",  # Matches: that's 60, thats 60
        )
    )

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS SpaceDigest only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Extract the numeric answer from ``completion_text``.

        Percentage-style phrasings are tried first. Next, a text that is only a number is
        returned stripped. Otherwise the first number anywhere in the text is returned.
        If no number is found, the stripped text is returned unchanged.
        """
        for pattern in self._PERCENTAGE_PATTERNS:
            match = pattern.search(completion_text)
            if match:
                return match.group(1)

        # If no percentage pattern is found, check if the entire text is just a number
        if re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*", completion_text):
            return completion_text.strip()

        # Fallback: first number anywhere in the text (less reliable)
        number_match = re.search(r"(\d+(?:\.\d+)?)", completion_text)
        if number_match:
            return number_match.group(1)

        # If no number is found, return the original text stripped
        return completion_text.strip()

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        query_end_index = item["query_end_index"]
        return f"{item['input'][:query_end_index]}Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        # Normalise the reference the same way as model completions.
        return self.post_process_generated_completion(item["output"])
@@ -0,0 +1,136 @@
1
+ import ast
2
+ import json
3
+ from pathlib import Path
4
+ from typing import Annotated, Any
5
+
6
+ from pydantic import AfterValidator, BeforeValidator, Field, field_serializer, field_validator, model_validator
7
+
8
+ from eval_framework.base_config import BaseConfig
9
+ from eval_framework.llm.base import BaseLLM
10
+ from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
11
+ from eval_framework.tasks.base import BaseTask
12
+ from eval_framework.tasks.perturbation import PerturbationConfig
13
+ from eval_framework.tasks.registry import get_task, validate_task_name
14
+ from eval_framework.utils.constants import ROOT_DIR
15
+
16
# EvalConfig fields that do not affect evaluation results. They are excluded from the config dump
# used for hashing (EvalConfig.model_json_robust_subset_dump), so changing them keeps the hash.
KEYS_UNRELATED_TO_RESULTS = {
    # Output location and its cleanup.
    "output_dir",
    "delete_output_dir_after_upload",
    # Weights & Biases settings.
    "wandb_project",
    "wandb_entity",
    "wandb_run_id",
    "wandb_upload_results",
    # Hugging Face upload settings.
    "hf_upload_dir",
    "hf_upload_repo",
    # Bookkeeping and what gets saved.
    "description",
    "save_intermediate_results",
    "save_logs",
}
30
+
31
+
32
class EvalConfig(BaseConfig):
    """Configuration for one evaluation run.

    Bundles several groups of settings:
    - output and upload settings (W&B, Hugging Face);
    - task selection and sampling parameters;
    - the model class and the optional LLM-judge class, each with an argument dict.

    ``model_json_robust_subset_dump`` yields a key-sorted JSON string without
    ``KEYS_UNRELATED_TO_RESULTS``, intended for hashing.
    """

    output_dir: Path = ROOT_DIR
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_run_id: str | None = None
    # An explicit None is coerced to the default (True).
    wandb_upload_results: Annotated[bool, BeforeValidator(lambda v: True if v is None else v)] = True
    hf_upload_dir: str | None = None
    hf_upload_repo: str | None = None
    num_fewshot: Annotated[int, Field(ge=0)] = 0
    num_samples: Annotated[int | None, Field(ge=1)] = 10  # Allows None or int
    max_tokens: int | None = None
    perturbation_config: PerturbationConfig | None = None
    task_name: Annotated[str, AfterValidator(validate_task_name)]
    task_subjects: list[str] | None = None
    hf_revision: str | None = None
    llm_class: type[BaseLLM]
    llm_args: dict[str, Any] = Field(default_factory=dict)
    # Required when the task uses an LLM-judge metric (see validate_llm_judge_defined).
    llm_judge_class: type[BaseLLM] | None = None
    judge_model_args: dict[str, Any] = Field(default_factory=dict)
    randomize_judge_order: bool = False
    batch_size: Annotated[int, Field(ge=1)] = 1
    description: str | None = None
    # For the three flags below, an explicit None is coerced to the field's default.
    save_intermediate_results: Annotated[bool, BeforeValidator(lambda v: True if v is None else v)] = True
    save_logs: Annotated[bool, BeforeValidator(lambda v: True if v is None else v)] = True
    delete_output_dir_after_upload: Annotated[bool, BeforeValidator(lambda v: False if v is None else v)] = False

    # Adding a new member? Remember to update KEYS_UNRELATED_TO_RESULTS if it doesn't impact eval results.

    @property
    def task_class(self) -> type[BaseTask]:
        """The task class registered under ``task_name``."""
        return get_task(self.task_name)

    @field_serializer("output_dir")
    def serialize_output_dir(self, value: Path) -> str:
        """Serialize the output directory as a plain string."""
        return str(value)

    @field_validator("output_dir", mode="before")
    @classmethod
    def validate_output_dir(cls, value: str | Path) -> Path:
        """Accept the output directory as a string and convert it to ``Path``."""
        if isinstance(value, str):
            return Path(value)
        return value

    @field_validator("llm_args", mode="before")
    @classmethod
    def validate_llm_args(cls, value: dict[str, Any]) -> dict[str, Any]:
        """Parse string values (e.g. from the CLI) as Python literals, recursing into nested dicts.

        Strings that are not valid literals stay strings. Other non-dict values are returned unchanged.
        """

        def convert_value(v: Any) -> Any:
            if isinstance(v, dict):
                # Recursively process nested dictionaries (like sampling_params)
                return {k: convert_value(nested_v) for k, nested_v in v.items()}
            if isinstance(v, str):
                try:
                    # Try to evaluate as a Python literal (int, float, bool, None, list, dict, etc.)
                    return ast.literal_eval(v)
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    # literal_eval documents all of these for malformed input; TypeError arises
                    # e.g. for an unhashable dict key such as "{[1]: 2}". Keep such values as strings.
                    return v
            return v  # already proper type

        return convert_value(value)

    @field_validator("judge_model_args", mode="before")
    @classmethod
    def validate_judge_model_args(cls, value: dict[str, Any]) -> dict[str, Any]:
        """Coerce numeric-looking string values (e.g. from the CLI) to ``float`` or ``int``.

        Strings containing "." are tried as ``float``; other strings are tried as ``int``.
        Strings that fail to parse are kept as-is. Non-string values pass through untouched: the
        previous ``str()``/``int()`` round-trip turned ``True`` into ``1`` and truncated ``1e-05``
        to ``0``, and it raised an uncaught ``TypeError`` for ``None``, lists and dicts. A non-dict
        ``value`` is returned unchanged so that pydantic's own type validation reports it.
        """
        if not isinstance(value, dict):
            return value

        def to_number(v: Any) -> Any:
            if not isinstance(v, str):
                return v
            try:  # maybe this llm argument is actually a number?
                return float(v) if "." in v else int(v)
            except ValueError:
                return v

        return {k: to_number(v) for k, v in value.items()}

    @model_validator(mode="after")
    def validate_llm_judge_defined(self) -> "EvalConfig":
        """Require ``llm_judge_class`` when any of the task's metrics is an LLM-judge metric.

        Raises:
            ValueError: If a judge metric is used but no judge class is configured. Pydantic
                surfaces this as a ``ValidationError``.
        """
        needs_judge = any(issubclass(metric_class, BaseLLMJudgeMetric) for metric_class in self.task_class.METRICS)
        if needs_judge and self.llm_judge_class is None:
            # An explicit raise instead of ``assert``: assertions are stripped under ``python -O``,
            # which would let a misconfigured run through silently.
            raise ValueError("The LLM Judge must be defined for this evaluation task.")
        return self

    @field_serializer("llm_class")
    def serialize_llm_class(self, value: type[BaseLLM] | None) -> str | None:
        """Serialize the class as its bare ``__name__`` (not module-qualified).

        This output feeds ``model_json_robust_subset_dump``, so changing the format would change config hashes.
        """
        if value:
            return value.__name__
        return None

    @field_serializer("llm_judge_class")
    def serialize_llm_judge_class(self, value: type[BaseLLM] | None) -> str | None:
        """Serialize the judge class as its bare ``__name__`` (not module-qualified), or ``None`` if unset."""
        if value:
            return value.__name__
        return None

    def model_json_dump(self) -> str:
        """Return the full config as a JSON string with sorted keys."""
        model_dump = self.model_dump(mode="json")
        return json.dumps(model_dump, sort_keys=True)

    def model_json_robust_subset_dump(self) -> str:
        """Return the config as sorted-key JSON without ``KEYS_UNRELATED_TO_RESULTS`` (for hashing)."""
        model_dump = self.model_dump(mode="json", exclude=KEYS_UNRELATED_TO_RESULTS)
        return json.dumps(model_dump, sort_keys=True)