eval-framework 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161) hide show
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +170 -0
  5. eval_framework/context/eval.py +114 -0
  6. eval_framework/context/local.py +52 -0
  7. eval_framework/evaluation_generator.py +231 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +323 -0
  16. eval_framework/llm/base.py +58 -0
  17. eval_framework/llm/huggingface.py +332 -0
  18. eval_framework/llm/mistral.py +73 -0
  19. eval_framework/llm/models.py +16 -0
  20. eval_framework/llm/openai.py +205 -0
  21. eval_framework/llm/vllm.py +438 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +187 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/bleu.py +76 -0
  29. eval_framework/metrics/completion/chrf.py +62 -0
  30. eval_framework/metrics/completion/code_assertion.py +44 -0
  31. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  32. eval_framework/metrics/completion/comet.py +56 -0
  33. eval_framework/metrics/completion/concordance_index.py +38 -0
  34. eval_framework/metrics/completion/csv_format.py +102 -0
  35. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  36. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  37. eval_framework/metrics/completion/f1.py +42 -0
  38. eval_framework/metrics/completion/format_checker.py +56 -0
  39. eval_framework/metrics/completion/grid_difference.py +77 -0
  40. eval_framework/metrics/completion/ifeval.py +73 -0
  41. eval_framework/metrics/completion/json_format.py +171 -0
  42. eval_framework/metrics/completion/language_checker.py +74 -0
  43. eval_framework/metrics/completion/length_control.py +83 -0
  44. eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
  45. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  46. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  47. eval_framework/metrics/completion/repetition.py +88 -0
  48. eval_framework/metrics/completion/rouge_1.py +35 -0
  49. eval_framework/metrics/completion/rouge_2.py +45 -0
  50. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  51. eval_framework/metrics/completion/rouge_l.py +52 -0
  52. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  53. eval_framework/metrics/completion/ter.py +67 -0
  54. eval_framework/metrics/completion/text_counter.py +182 -0
  55. eval_framework/metrics/efficiency/__init__.py +0 -0
  56. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  57. eval_framework/metrics/llm/__init__.py +0 -0
  58. eval_framework/metrics/llm/base.py +8 -0
  59. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  60. eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
  61. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  62. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  63. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  64. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  65. eval_framework/metrics/llm/graders/language.py +56 -0
  66. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  67. eval_framework/metrics/llm/graders/models.py +74 -0
  68. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  69. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  70. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  71. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  72. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  73. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  74. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  75. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  76. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  77. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
  78. eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
  79. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  80. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  81. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  82. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  83. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  84. eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
  85. eval_framework/py.typed +0 -0
  86. eval_framework/response_generator.py +416 -0
  87. eval_framework/result_processors/__init__.py +0 -0
  88. eval_framework/result_processors/base.py +74 -0
  89. eval_framework/result_processors/hf_processor.py +87 -0
  90. eval_framework/result_processors/result_processor.py +129 -0
  91. eval_framework/run.py +314 -0
  92. eval_framework/run_direct.py +42 -0
  93. eval_framework/shared/types.py +227 -0
  94. eval_framework/tasks/__init__.py +6 -0
  95. eval_framework/tasks/base.py +314 -0
  96. eval_framework/tasks/benchmarks/__init__.py +0 -0
  97. eval_framework/tasks/benchmarks/arc.py +46 -0
  98. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  99. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  100. eval_framework/tasks/benchmarks/belebele.py +60 -0
  101. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  102. eval_framework/tasks/benchmarks/casehold.py +47 -0
  103. eval_framework/tasks/benchmarks/chembench.py +85 -0
  104. eval_framework/tasks/benchmarks/copa.py +39 -0
  105. eval_framework/tasks/benchmarks/duc.py +91 -0
  106. eval_framework/tasks/benchmarks/flores200.py +62 -0
  107. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  108. eval_framework/tasks/benchmarks/gpqa.py +177 -0
  109. eval_framework/tasks/benchmarks/gsm8k.py +148 -0
  110. eval_framework/tasks/benchmarks/hellaswag.py +44 -0
  111. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  112. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  113. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  114. eval_framework/tasks/benchmarks/include.py +119 -0
  115. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  116. eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
  117. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  118. eval_framework/tasks/benchmarks/mmlu.py +190 -0
  119. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  120. eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
  121. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  122. eval_framework/tasks/benchmarks/openbookqa.py +37 -0
  123. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  124. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  125. eval_framework/tasks/benchmarks/piqa.py +39 -0
  126. eval_framework/tasks/benchmarks/quality.py +56 -0
  127. eval_framework/tasks/benchmarks/sciq.py +44 -0
  128. eval_framework/tasks/benchmarks/sphyr.py +75 -0
  129. eval_framework/tasks/benchmarks/squad.py +89 -0
  130. eval_framework/tasks/benchmarks/struct_eval.py +110 -0
  131. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  132. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  133. eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
  134. eval_framework/tasks/benchmarks/winogender.py +39 -0
  135. eval_framework/tasks/benchmarks/winogrande.py +44 -0
  136. eval_framework/tasks/benchmarks/winox.py +57 -0
  137. eval_framework/tasks/benchmarks/wmt.py +160 -0
  138. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  139. eval_framework/tasks/eval_config.py +112 -0
  140. eval_framework/tasks/perturbation.py +83 -0
  141. eval_framework/tasks/registry.py +186 -0
  142. eval_framework/tasks/task_loader.py +80 -0
  143. eval_framework/tasks/task_names.py +138 -0
  144. eval_framework/tasks/utils.py +578 -0
  145. eval_framework/utils/constants.py +9 -0
  146. eval_framework/utils/generate_task_docs.py +229 -0
  147. eval_framework/utils/helpers.py +3 -0
  148. eval_framework/utils/logging.py +50 -0
  149. eval_framework/utils/packaging.py +52 -0
  150. eval_framework-0.2.0.dist-info/METADATA +514 -0
  151. eval_framework-0.2.0.dist-info/RECORD +161 -0
  152. eval_framework-0.2.0.dist-info/WHEEL +4 -0
  153. eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
  154. template_formatting/README.md +83 -0
  155. template_formatting/__init__.py +0 -0
  156. template_formatting/formatter.py +536 -0
  157. template_formatting/mistral_formatter.py +159 -0
  158. template_formatting/py.typed +0 -0
  159. template_formatting/tests/test_formatter_eval.py +408 -0
  160. template_formatting/tests/test_formatter_scaling.py +253 -0
  161. template_formatting/tests/test_mistral_formatter.py +136 -0
@@ -0,0 +1,160 @@
1
+ import random
2
+ from abc import ABC
3
+ from typing import Any
4
+
5
+ import pycountry
6
+ import sacrebleu
7
+
8
+ from eval_framework.metrics.completion.bleu import LINEWISE_BLEU
9
+ from eval_framework.metrics.completion.chrf import LINEWISE_CHRF
10
+ from eval_framework.metrics.completion.ter import LINEWISE_TER
11
+ from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Language, ResponseType, Sample
12
+
13
+
14
class WMT(BaseTask[str], ABC):
    """Base task for WMT translation benchmarks fetched through sacrebleu.

    Subclasses set ``DATASET_PATH`` to a sacrebleu test-set name and ``SUBJECTS``
    to the language pairs (``"<src>-<tgt>"``) that should be evaluated.
    """

    NAME = "WMT"
    DATASET_PATH = ""
    SAMPLE_SPLIT = "test"
    FEWSHOT_SPLIT = "test"
    RESPONSE_TYPE = ResponseType.COMPLETION
    METRICS = [LINEWISE_BLEU, LINEWISE_CHRF, LINEWISE_TER]
    PERTURBATION_UNMODIFIABLE_WORDS = ["phrase"]

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # Generation is cut at the first of these markers during post-processing.
        self.stop_sequences: list[str] = [".\n", " phrase: ", "phrase:", "phrase: ", " phrase:", "\n\n"]

    @staticmethod
    def _read_lines(path: str) -> list[str]:
        """Return the right-stripped lines of ``path``, closing the file afterwards."""
        with sacrebleu.smart_open(path) as handle:
            return [line.rstrip() for line in handle]

    def _load_dataset(self, subject: str | None) -> None:
        """Download the test set for ``subject`` and store it, shuffled, as the "test" split.

        Args:
            subject: Language pair passed to sacrebleu as ``langpair``.
        """
        src_file, ref_file, _, _, _ = sacrebleu.download_test_set(test_set=self.DATASET_PATH, langpair=subject)
        src_data = self._read_lines(src_file)
        ref_data = self._read_lines(ref_file)

        data_list = [{"source": src, "target": ref, "subject": subject} for src, ref in zip(src_data, ref_data)]
        # Seeded shuffle so that the sample order is reproducible across runs.
        self.rnd = random.Random(RANDOM_SEED)
        self.rnd.shuffle(data_list)
        self.dataset = {"test": data_list}

    def _code_to_language(self, code: str) -> str:
        """Translate an ISO 639 language code into the language name known to pycountry.

        Raises:
            ValueError: If pycountry has no language for ``code``.
        """
        # key is alpha_2 or alpha_3 depending on the code length
        key = f"alpha_{len(code)}"
        language = pycountry.languages.get(**{key: code})
        if language is None:
            raise ValueError(f"Unknown language code: {code!r}")
        return language.name

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the "<source language> phrase: ...\\n<target language> phrase:" prompt."""
        language_codes = item["subject"].split("-")
        src_lang = self._code_to_language(language_codes[0])
        tar_lang = self._code_to_language(language_codes[1])
        cue = f"{tar_lang} phrase:"

        return f"{src_lang} phrase: {item['source']}\n{cue}"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the reference translation (the first one if ``target`` is not a plain string)."""
        return item["target"] if isinstance(item["target"], str) else item["target"][0]

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Return the reference translation prefixed with a single space."""
        target = self._get_ground_truth(item)
        assert target is not None
        assert isinstance(target, str)
        return f" {target}"

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Truncate the completion at every configured stop sequence and strip whitespace."""
        for stop_sequence in self.stop_sequences:
            if stop_sequence in completion_text:
                completion_text = completion_text.split(stop_sequence)[0]
        return completion_text.strip()
68
+
69
+
70
class WMT14(WMT):
    """WMT14 translation task covering the English/French pairs."""

    NAME = "WMT14"
    DATASET_PATH = "wmt14"
    SUBJECTS = ["en-fr", "fr-en"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "en-fr": (Language.ENG, Language.FRA),
        "fr-en": (Language.FRA, Language.ENG),
    }
78
+
79
+
80
class WMT16(WMT):
    """WMT16 translation task covering the German/English pairs."""

    NAME = "WMT16"
    DATASET_PATH = "wmt16"
    SUBJECTS = ["de-en", "en-de"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "en-de": (Language.ENG, Language.DEU),
    }
88
+
89
+
90
class WMT20(WMT):
    """WMT20 translation task covering German/English and German/French pairs."""

    NAME = "WMT20"
    DATASET_PATH = "wmt20"
    SUBJECTS = ["de-en", "de-fr", "en-de", "fr-de"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "de-fr": (Language.DEU, Language.FRA),
        "en-de": (Language.ENG, Language.DEU),
        "fr-de": (Language.FRA, Language.DEU),
    }
100
+
101
+
102
class WMT_INSTRUCT(WMT):
    """Instruction-style variant of the WMT tasks ("Please translate from X to Y: ...")."""

    PERTURBATION_UNMODIFIABLE_WORDS = ["Please", "translate"]
    COMPLETION_PREFIX = "This is the translation:"

    def __init__(self, num_fewshot: int = 0) -> None:
        super().__init__(num_fewshot)
        # A repeated instruction means the model started a new example; cut there.
        self.stop_sequences: list[str] = ["Please translate"]

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Build the translation instruction for one source sentence."""
        src_lang, tar_lang = map(self._code_to_language, item["subject"].split("-"))
        return f"Please translate from {src_lang} to {tar_lang}: {item['source']}"

    # NOTE(review): other tasks in this package (e.g. the ZeroSCROLLS tasks) override
    # `_get_cue_text`; confirm that BaseTask really calls `_get_cue`, otherwise this
    # prefix is never added to the prompt.
    def _get_cue(self, item: dict[str, Any]) -> str:
        """Return the fixed prefix the completion is expected to start with."""
        return self.COMPLETION_PREFIX

    def _get_fewshot_target_text(self, item: dict[str, Any]) -> str:
        """Return the reference translation prefixed with a single space."""
        target = self._get_ground_truth(item)
        assert target is not None
        return f" {target}"

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Drop the completion prefix, cut at the stop sequences and strip whitespace.

        The final strip removes whitespace left in front of a removed stop
        sequence, matching the base-class post-processing.
        """
        completion_text = completion_text.removeprefix(self.COMPLETION_PREFIX)
        completion_text = completion_text.strip()
        for stop_sequence in self.stop_sequences:
            if stop_sequence in completion_text:
                completion_text = completion_text.split(stop_sequence)[0]
        return completion_text.strip()
129
+
130
+
131
class WMT14_INSTRUCT(WMT_INSTRUCT):
    """Instruction-style WMT14 task covering the English/French pairs."""

    NAME = "WMT14 Instruct"
    DATASET_PATH = "wmt14"
    SUBJECTS = ["en-fr", "fr-en"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "en-fr": (Language.ENG, Language.FRA),
        "fr-en": (Language.FRA, Language.ENG),
    }
139
+
140
+
141
class WMT16_INSTRUCT(WMT_INSTRUCT):
    """Instruction-style WMT16 task covering the German/English pairs."""

    NAME = "WMT16 Instruct"
    DATASET_PATH = "wmt16"
    SUBJECTS = ["de-en", "en-de"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "en-de": (Language.ENG, Language.DEU),
    }
149
+
150
+
151
class WMT20_INSTRUCT(WMT_INSTRUCT):
    """Instruction-style WMT20 task covering German/English and German/French pairs."""

    NAME = "WMT20 Instruct"
    DATASET_PATH = "wmt20"
    SUBJECTS = ["de-en", "de-fr", "en-de", "fr-de"]
    # Maps each language pair to its (source, target) languages.
    LANGUAGE = {
        "de-en": (Language.DEU, Language.ENG),
        "de-fr": (Language.DEU, Language.FRA),
        "en-de": (Language.ENG, Language.DEU),
        "fr-de": (Language.FRA, Language.DEU),
    }
@@ -0,0 +1,197 @@
1
+ import re
2
+ from typing import Any
3
+
4
+ from eval_framework.metrics.completion.exponential_similarity import ExponentialSimilarity
5
+ from eval_framework.metrics.completion.f1 import F1
6
+ from eval_framework.metrics.completion.rouge_geometric_mean import ROUGE_GEOMETRIC_MEAN
7
+ from eval_framework.metrics.loglikelihood.accuracy_loglikelihood import (
8
+ AccuracyLoglikelihood,
9
+ )
10
+ from eval_framework.tasks.base import BaseTask, Language, ResponseType, Sample
11
+ from eval_framework.tasks.utils import get_n_letters
12
+
13
+
14
class ZERO_SCROLLS_QUALITY(BaseTask[str]):
    """ZeroSCROLLS dataset: https://huggingface.co/datasets/tau/zero_scrolls

    Multiple-choice QuALITY subset, scored by log-likelihood over answer letters.
    """

    NAME = "ZeroSCROLLS QuALITY"
    DATASET_PATH = "tau/zero_scrolls"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "validation"
    RESPONSE_TYPE = ResponseType.LOGLIKELIHOODS
    METRICS = [AccuracyLoglikelihood]
    SUBJECTS = ["quality"]

    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]
    LANGUAGE = Language.ENG

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS QuALITY only supports zero fewshot examples"
        super().__init__(num_fewshot)
        # Answer labels offered to the model (four of them).
        self.keys = get_n_letters(4)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index`` followed by a blank line."""
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\n"

    def _get_cue_text(self, item: dict[str, Any]) -> str:
        """Return the fixed answer cue."""
        return "Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the expected output prefixed with a single space."""
        expected = item["output"]
        return f" {expected}"

    def _get_possible_completions(self, item: dict[str, Any]) -> list[str] | None:
        """Return every answer label prefixed with a single space."""
        choices = []
        for label in self.keys:
            choices.append(f" {label}")
        return choices
45
+
46
+
47
class ZERO_SCROLLS_COMPLETION(BaseTask[str]):
    """ZeroSCROLLS dataset: https://huggingface.co/datasets/tau/zero_scrolls

    Shared settings for the ZeroSCROLLS subsets that are evaluated on generated text.
    """

    DATASET_PATH = "tau/zero_scrolls"
    SAMPLE_SPLIT = "validation"
    FEWSHOT_SPLIT = "validation"
    RESPONSE_TYPE = ResponseType.COMPLETION

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Return the item's ``output`` field unchanged."""
        expected = item["output"]
        return expected
57
+
58
+
59
class ZERO_SCROLLS_GOV_REPORT(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS GovReport subset, scored with the ROUGE geometric mean."""

    NAME = "ZeroSCROLLS GovReport"
    METRICS = [ROUGE_GEOMETRIC_MEAN]
    SUBJECTS = ["gov_report"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Summary"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS GovReport only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index`` with the summary cue appended."""
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}Summary:"
72
+
73
+
74
class ZERO_SCROLLS_QMSUM(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS QMSum subset, scored with the ROUGE geometric mean."""

    NAME = "ZeroSCROLLS QMSum"
    METRICS = [ROUGE_GEOMETRIC_MEAN]
    SUBJECTS = ["qmsum"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS QMSum only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index``, a blank line and the answer cue."""
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\nAnswer:"
87
+
88
+
89
class ZERO_SCROLLS_SQUALITY(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS SQuALITY subset, scored with the ROUGE geometric mean."""

    NAME = "ZeroSCROLLS SQuALITY"
    METRICS = [ROUGE_GEOMETRIC_MEAN]
    SUBJECTS = ["squality"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS SQuALITY only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index``, a blank line and the answer cue."""
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\nAnswer:"
102
+
103
+
104
class ZERO_SCROLLS_QASPER(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS Qasper subset, scored with F1."""

    NAME = "ZeroSCROLLS Qasper"
    METRICS = [F1]
    SUBJECTS = ["qasper"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS Qasper only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index``, a blank line and the answer cue."""
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\nAnswer:"
117
+
118
+
119
class ZERO_SCROLLS_NARRATIVEQA(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS NarrativeQA subset, scored with F1."""

    NAME = "ZeroSCROLLS NarrativeQA"
    METRICS = [F1]
    SUBJECTS = ["narrative_qa"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS NarrativeQA only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index``, a blank line and the answer cue."""
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\nAnswer:"
132
+
133
+
134
class ZERO_SCROLLS_MUSIQUE(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS MuSiQue subset, scored with F1."""

    NAME = "ZeroSCROLLS MuSiQue"
    METRICS = [F1]
    SUBJECTS = ["musique"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS MuSiQue only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index``, a blank line and the answer cue."""
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}\n\nAnswer:"
147
+
148
+
149
class ZERO_SCROLLS_SPACE_DIGEST(ZERO_SCROLLS_COMPLETION):
    """ZeroSCROLLS SpaceDigest subset; answers are reduced to a bare number before scoring."""

    NAME = "ZeroSCROLLS SpaceDigest"
    METRICS = [ExponentialSimilarity]
    SUBJECTS = ["space_digest"]
    PERTURBATION_UNMODIFIABLE_WORDS = ["Answer"]

    def __init__(self, num_fewshot: int = 0) -> None:
        assert num_fewshot == 0, "ZeroSCROLLS SpaceDigest only supports zero fewshot examples"
        super().__init__(num_fewshot)

    def post_process_generated_completion(self, completion_text: str, sample: Sample | None = None) -> str:
        """Extract the number the completion states as its answer.

        Percentage-like phrasings are tried first, in order. If none matches, the
        first number anywhere in the text is used. If the text contains no number
        at all, the stripped text is returned.
        """
        # Ordered from most to least specific phrasing; the first hit wins.
        percentage_patterns = (
            r"(\d+(?:\.\d+)?)%",  # Matches: 30%, 30.5%
            r"(\d+(?:\.\d+)?)\s*percent",  # Matches: 30 percent, 30.5 percent
            r"(\d+(?:\.\d+)?)\s*percentage",  # Matches: 30 percentage, 30.5 percentage
            r"percentage\s*(?:is|of|:)?\s*(\d+(?:\.\d+)?)",  # Matches: percentage is 30, percentage: 30.5
            # Matches: is 30%, equals 30.5%
            r"(?:is|equals|equal to|about|approximately|around|roughly)\s*(\d+(?:\.\d+)?)\s*%",
            # Matches: is 30 percent
            r"(?:is|equals|equal to|about|approximately|around|roughly)\s*(\d+(?:\.\d+)?)\s*percent",
            r"it'?s\s*(\d+(?:\.\d+)?)",  # Matches: it's 60, its 60
            r"that'?s\s*(\d+(?:\.\d+)?)",  # Matches: that's 60, thats 60
        )

        # Lazy generator: patterns after the first hit are never evaluated.
        attempts = (re.search(regex, completion_text, re.IGNORECASE) for regex in percentage_patterns)
        first_hit = next((found for found in attempts if found), None)
        if first_hit is not None:
            return first_hit.group(1).strip()

        stripped_text = completion_text.strip()

        # The whole completion is just a number (optionally surrounded by whitespace).
        if re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*", completion_text):
            return stripped_text

        # Fallback, possibly less accurate: the first number found anywhere in the text.
        any_number = re.search(r"(\d+(?:\.\d+)?)", completion_text)
        if any_number is None:
            return stripped_text
        return any_number.group(1).strip()

    def _get_instruction_text(self, item: dict[str, Any]) -> str:
        """Return the input up to ``query_end_index`` with the answer cue appended."""
        prompt = item["input"][: item["query_end_index"]]
        return f"{prompt}Answer:"

    def _get_ground_truth(self, item: dict[str, Any]) -> str | None:
        """Normalise the reference output with the same extraction used for completions."""
        reference = item["output"]
        return self.post_process_generated_completion(reference)
@@ -0,0 +1,112 @@
1
+ import ast
2
+ import json
3
+ from pathlib import Path
4
+ from typing import Annotated, Any
5
+
6
+ from pydantic import AfterValidator, Field, field_serializer, field_validator, model_validator
7
+
8
+ from eval_framework.base_config import BaseConfig
9
+ from eval_framework.llm.base import BaseLLM
10
+ from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
11
+ from eval_framework.tasks.base import BaseTask
12
+ from eval_framework.tasks.perturbation import PerturbationConfig
13
+ from eval_framework.tasks.registry import get_task, validate_task_name
14
+ from eval_framework.utils.constants import ROOT_DIR
15
+
16
+
17
class EvalConfig(BaseConfig):
    """Settings for one evaluation run: task, model under test, optional judge and outputs."""

    output_dir: Path = ROOT_DIR
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_run_id: str | None = None
    hf_upload_dir: str | None = None
    hf_upload_repo: str | None = None
    num_fewshot: Annotated[int, Field(ge=0)] = 0
    num_samples: Annotated[int | None, Field(ge=1)] = 10  # Allows None or int
    max_tokens: int | None = None
    perturbation_config: PerturbationConfig | None = None
    task_name: Annotated[str, AfterValidator(validate_task_name)]
    task_subjects: list[str] | None = None
    hf_revision: str | None = None
    llm_class: type[BaseLLM]
    llm_args: dict[str, Any] = Field(default_factory=dict)
    llm_judge_class: type[BaseLLM] | None = None
    judge_model_args: dict[str, Any] = Field(default_factory=dict)
    batch_size: Annotated[int, Field(ge=1)] = 1
    description: str | None = None
    save_intermediate_results: bool = True
    save_logs: bool = True

    @property
    def task_class(self) -> type[BaseTask]:
        """Task class registered under ``task_name``."""
        return get_task(self.task_name)

    @field_serializer("output_dir")
    def serialize_output_dir(self, value: Path) -> str:
        """Serialize the output directory as a plain string."""
        return str(value)

    @field_validator("output_dir", mode="before")
    @classmethod
    def validate_output_dir(cls, value: str | Path) -> Path:
        """Accept the output directory as either a string or a ``Path``."""
        if isinstance(value, str):
            return Path(value)
        return value

    @field_validator("llm_args", mode="before")
    @classmethod
    def validate_llm_args(cls, value: dict[str, Any]) -> dict[str, Any]:
        """Turn string arguments that spell Python literals into the corresponding values."""

        def convert_value(v: Any) -> Any:
            if isinstance(v, dict):
                # Recursively process nested dictionaries (like sampling_params)
                return {k: convert_value(nested_v) for k, nested_v in v.items()}
            elif isinstance(v, str):
                try:
                    # Try to evaluate as a Python literal (int, float, bool, None, list, dict, etc.)
                    return ast.literal_eval(v)
                except (ValueError, SyntaxError):
                    return v  # keep as string if not a valid literal
            else:
                return v  # already proper type

        return convert_value(value)

    @field_validator("judge_model_args", mode="before")
    @classmethod
    def validate_judge_model_args(cls, value: dict[str, Any]) -> dict[str, Any]:
        """Convert numeric-looking string arguments to ``int``/``float``.

        Only strings are converted. Values that already have a proper type are
        passed through untouched: round-tripping them through ``str()`` would
        turn e.g. ``1e-05`` into ``0`` and ``True`` into ``1``, and would raise
        ``TypeError`` for ``None`` or containers.
        """

        def coerce(v: Any) -> Any:
            if not isinstance(v, str):
                return v  # already proper type
            try:  # maybe this llm argument is actually a number?
                return float(v) if "." in v else int(v)
            except ValueError:
                return v  # keep as string if it is not a number

        return {k: coerce(v) for k, v in value.items()}

    @model_validator(mode="after")
    def validate_llm_judge_defined(self) -> "EvalConfig":
        """Require ``llm_judge_class`` whenever the task uses an LLM-judge metric."""
        task = get_task(self.task_name)
        for metric_class in task.METRICS:
            if issubclass(metric_class, BaseLLMJudgeMetric):
                assert self.llm_judge_class is not None, "The LLM Judge must be defined for this evaluation task."
        return self

    @field_serializer("llm_class")
    def serialize_llm_class(self, value: type[BaseLLM] | None) -> str | None:
        """Serialize the class as its (unqualified) class name."""
        if value:
            return value.__name__
        return None

    @field_serializer("llm_judge_class")
    def serialize_llm_judge_class(self, value: type[BaseLLM] | None) -> str | None:
        """Serialize the class as its (unqualified) class name."""
        if value:
            return value.__name__
        return None

    def model_json_dump(self) -> str:
        """Return the dumped model as a JSON string with sorted keys."""
        model_dump = self.model_dump()
        return json.dumps(model_dump, sort_keys=True)
@@ -0,0 +1,83 @@
1
+ import threading
2
+ from enum import Enum
3
+ from typing import Annotated, Any, TypeVar
4
+
5
+ from pydantic import BaseModel, ConfigDict, Field
6
+
7
+ from eval_framework.logger import logger
8
+ from eval_framework.tasks.base import RANDOM_SEED, BaseTask, Sample
9
+ from eval_framework.tasks.utils import Editor, HatPaperEditor
10
+
11
+
12
class PerturbationType(str, Enum):
    """Names of the supported text perturbations."""

    # Handled by the Editor-based perturbation
    EDITOR = "editor"

    # Handled by the HatPaperEditor-based perturbation
    PERMUTE = "permute"
    REPLACE = "replace"
    DELETE = "delete"
    UPPERCASE = "uppercase"
20
+
21
+
22
class PerturbationConfig(BaseModel):
    """Settings for perturbing instruction texts; unknown fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Which perturbation to apply.
    type: PerturbationType = PerturbationType.EDITOR
    # Passed to the editor as the perturbation probability; must lie in [0, 1].
    probability: Annotated[float, Field(ge=0.0, le=1.0)] = 0.1
    # Seed handed to the editor.
    seed: int = RANDOM_SEED
    # When set, the text is logged before and after perturbation.
    verbose: bool = False
28
+
29
+
30
# NOTE(review): neither the lock nor the port is referenced in the visible part of
# this module; presumably they belong to an augmenter container launch path - confirm.
_DOCKER_LAUNCH_LOCK = threading.Lock()
_AUGMENTER_PORT = 0

# Type variable restricted to BaseTask subclasses.
SomeBaseTask = TypeVar("SomeBaseTask", bound=BaseTask[Any])
34
+
35
+
36
+ def create_perturbation_class[T: BaseTask](base_class: type[T], perturbation_config: PerturbationConfig) -> type[T]:
37
+ # mypy seems to have trouble inferring the type
38
+ class EditorPerturbation(base_class): # type: ignore
39
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
40
+ super().__init__(*args, **kwargs)
41
+ self.perturbation_config = perturbation_config
42
+ self.editor = Editor(
43
+ language="de" if base_class.LANGUAGE == "German" else "en", seed=perturbation_config.seed
44
+ )
45
+
46
+ def _get_instruction_text(self, sample: Sample) -> str:
47
+ text = super()._get_instruction_text(sample)
48
+ if self.perturbation_config.verbose:
49
+ logger.info(f"Perturbating text: {text}")
50
+ result = self.editor(
51
+ text, self.perturbation_config.probability, getattr(self, "PERTURBATION_UNMODIFIABLE_WORDS", [])
52
+ )
53
+ if self.perturbation_config.verbose:
54
+ logger.info(f"Perturbed text: {result}")
55
+ return result
56
+
57
+ class HatPaperPerturbation(base_class): # type: ignore
58
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
59
+ super().__init__(*args, **kwargs)
60
+ self.perturbation_config = perturbation_config
61
+ self.editor = HatPaperEditor(seed=perturbation_config.seed)
62
+
63
+ def _get_instruction_text(self, sample: Sample) -> str:
64
+ text = super()._get_instruction_text(sample)
65
+ if self.perturbation_config.verbose:
66
+ logger.info(f"Perturbating text: {text}")
67
+ words = getattr(self, "PERTURBATION_UNMODIFIABLE_WORDS", [])
68
+ if self.perturbation_config.type == PerturbationType.PERMUTE:
69
+ result = self.editor.permute_chars_in_string(text, self.perturbation_config.probability, words)
70
+ elif self.perturbation_config.type == PerturbationType.REPLACE:
71
+ result = self.editor.replace_chars_in_string(text, self.perturbation_config.probability, words)
72
+ elif self.perturbation_config.type == PerturbationType.DELETE:
73
+ result = self.editor.delete_chars_in_string(text, self.perturbation_config.probability, words)
74
+ elif self.perturbation_config.type == PerturbationType.UPPERCASE:
75
+ result = self.editor.upper_case_string(text)
76
+ if self.perturbation_config.verbose:
77
+ logger.info(f"Perturbed text: {result}")
78
+ return result
79
+
80
+ if perturbation_config.type == PerturbationType.EDITOR:
81
+ return EditorPerturbation
82
+ else:
83
+ return HatPaperPerturbation