eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,7 @@
1
"""Package root: exposes the installed distribution version as ``__version__``."""

from importlib.metadata import PackageNotFoundError, version

try:
    # Resolved from the installed distribution metadata at import time.
    __version__ = version("eval-framework")
except PackageNotFoundError:
    # Fall back when the package is imported from a plain source tree or a
    # vendored copy where no dist-info metadata is available; previously this
    # raised and made the whole package unimportable.
    __version__ = "0.0.0+unknown"

# Keep the module namespace clean; only ``__version__`` is public.
del version

__all__ = ["__version__"]
@@ -0,0 +1,36 @@
1
+ from enum import Enum
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ import yaml # type: ignore[import-untyped]
6
+ from pydantic import BaseModel, ConfigDict
7
+
8
+
9
class BaseConfig(BaseModel):
    """Base class for framework configs: immutable, strict pydantic models with
    YAML (de)serialization helpers."""

    # Reject unknown keys, freeze instances after creation, and disable the
    # ``model_`` protected-namespace check so field names may start with "model_".
    model_config = ConfigDict(extra="forbid", frozen=True, protected_namespaces=())

    def as_dict(self) -> dict[str, Any]:
        """Return the config as plain built-ins, recursing through dicts and
        lists: ``Path`` becomes ``str`` and ``Enum`` becomes its ``value``."""

        def simplify_recursive(obj: Any) -> Any:
            if isinstance(obj, dict):
                return {key: simplify_recursive(value) for key, value in obj.items()}
            elif isinstance(obj, list):
                return [simplify_recursive(item) for item in obj]
            elif isinstance(obj, Path):
                return str(obj)
            elif isinstance(obj, Enum):
                return obj.value
            else:
                return obj

        return simplify_recursive(self.model_dump())

    @classmethod
    def from_yaml(cls, yml_filename: str | Path) -> "BaseConfig":
        """Build a config instance from a YAML file.

        Raises:
            pydantic.ValidationError: If the file contains unknown or
                ill-typed keys (``extra="forbid"``).
        """
        with open(yml_filename) as conf_file:
            # safe_load only builds plain Python objects; the previous
            # FullLoader could instantiate arbitrary tagged Python types,
            # which a config file never needs.
            config_dict = yaml.safe_load(conf_file)

        return cls(**config_dict)

    def save(self, out_file: Path) -> None:
        """Serialize the config to ``out_file`` as YAML (JSON-mode dump, so
        values are already plain built-ins)."""
        with open(out_file, "w", encoding="UTF-8") as f:
            yaml.safe_dump(self.model_dump(mode="json"), f)
File without changes
@@ -0,0 +1,177 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Annotated, Any
4
+
5
+ from determined._info import get_cluster_info
6
+ from determined.core._context import Context
7
+ from determined.core._context import init as determined_core_init
8
+ from determined.core._distributed import DummyDistributedContext
9
+ from pydantic import AfterValidator, BaseModel, ConfigDict
10
+
11
+ from eval_framework.context.eval import EvalContext
12
+ from eval_framework.context.local import _load_model
13
+ from eval_framework.llm.base import BaseLLM
14
+ from eval_framework.tasks.eval_config import EvalConfig
15
+ from eval_framework.tasks.perturbation import PerturbationConfig
16
+ from eval_framework.tasks.registry import validate_task_name
17
+ from eval_framework.tasks.task_loader import load_extra_tasks
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class TaskArgs(BaseModel):
    """Per-task settings nested under ``Hyperparameters.task_args``."""

    # Unknown keys are a configuration error.
    model_config = ConfigDict(extra="forbid")

    # Checked against the task registry while the model is parsed.
    task_name: Annotated[str, AfterValidator(validate_task_name)]
    num_fewshot: int
    num_samples: int | None = None
    max_tokens: int | None = None
    batch_size: int | None = None
    judge_model_name: str | None = None
    # NOTE: pydantic deep-copies this default per instance, so the shared
    # ``{}`` literal is safe here (unlike in a plain function signature).
    judge_model_args: dict[str, Any] = {}
    task_subjects: list[str] | None = None
    hf_revision: str | None = None
    perturbation_config: PerturbationConfig | None = None
34
+
35
+
36
class Hyperparameters(BaseModel):
    """Trial hyperparameters as submitted to Determined (mirrors the CLI
    arguments; hyperparameters win when both are set)."""

    # Unknown keys are a configuration error.
    model_config = ConfigDict(extra="forbid")

    llm_name: str
    output_dir: Path
    hf_upload_dir: str | None = None
    hf_upload_repo: str | None = None
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_run_id: str | None = None
    wandb_upload_results: bool | None = None
    description: str | None = None
    task_args: TaskArgs
    # NOTE: pydantic deep-copies the ``{}`` default per instance, so it is safe.
    llm_args: dict[str, Any] | None = {}
    extra_task_modules: list[str] | None = None
    delete_output_dir_after_upload: bool | None = None
51
+
52
+
53
class DeterminedContext(EvalContext):
    """Evaluation context for runs inside a Determined AI trial.

    On ``__enter__`` it starts the Determined core context, reads the trial
    hyperparameters, and builds the ``EvalConfig``. Trial hyperparameters take
    precedence over CLI arguments; any overridden CLI value is logged.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Set while the context is entered; None outside the with-block.
        self._core_context: Context | None = None

    @staticmethod
    def _log_override(name: str, val_cli: Any, val_hparams: Any) -> None:
        """Log when a CLI argument is superseded by a differing hyperparameter."""
        # Only log when both sides are set (truthy) and actually differ.
        if val_cli and val_hparams and val_cli != val_hparams:
            logger.info(f"CLI argument {name} ({val_cli}) is being overridden by hyperparameters: ({val_hparams}).")

    def __enter__(self) -> "DeterminedContext":
        """Start the core context and derive the eval config from trial
        hyperparameters, falling back to CLI arguments for unset values.

        Raises:
            RuntimeError: If Determined cluster info cannot be retrieved.
        """
        distributed_context = DummyDistributedContext()
        self._core_context = determined_core_init(distributed=distributed_context)
        self._core_context.start()
        info = get_cluster_info()

        if info is None:
            raise RuntimeError("Failed to retrieve cluster info.")

        # Load extra tasks first so that task-name validation in
        # Hyperparameters/TaskArgs can see them.
        extra_task_modules = info.trial.hparams.get("extra_task_modules", None)
        if extra_task_modules:
            name = "extra_task_modules"
            val_cli = getattr(self, name, None)
            val_hparams = extra_task_modules
            if val_hparams:
                if val_cli and val_hparams and val_cli != val_hparams:
                    logger.info(
                        f"CLI argument {name} ({val_cli}) is being overridden by hyperparameters:"
                        f"({val_hparams}). If it fails due to duplicate task names, remove the CLI argument and"
                        "consolidate as a determined hyperparameter instead."
                    )
                load_extra_tasks(val_hparams)

        self.hparams = Hyperparameters(**info.trial.hparams)

        # Warn about top-level CLI arguments superseded by hyperparameters.
        for name in [
            "llm_name",
            "llm_args",
            "output_dir",
            "hf_upload_dir",
            "hf_upload_repo",
            "wandb_project",
            "wandb_entity",
            "wandb_run_id",
            "wandb_upload_results",
            "description",
            "delete_output_dir_after_upload",
        ]:
            self._log_override(name, getattr(self, name, None), getattr(self.hparams, name, None))

        # Warn about task-level CLI arguments superseded by hyperparameters.
        for name in [
            "num_samples",
            "max_tokens",
            "num_fewshot",
            "task_name",
            "task_subjects",
            "batch_size",
            "hf_revision",
            "judge_model_name",
            "judge_model_args",
            "perturbation_config",
        ]:
            self._log_override(name, getattr(self, name, None), getattr(self.hparams.task_args, name, None))

        # Hyperparameters take precedence over CLI arguments.
        llm_name = self.hparams.llm_name or self.llm_name
        judge_model_name = self.hparams.task_args.judge_model_name or self.judge_model_name

        llm_class = _load_model(llm_name, models_path=self.models_path)
        llm_judge_class: type[BaseLLM] | None = (
            _load_model(judge_model_name, models_path=self.judge_models_path, info="judge")
            if judge_model_name
            else None
        )

        # For all optional hyperparameters, resort to the respective CLI
        # argument if the hyperparameter is not set.
        self.config = EvalConfig(
            llm_class=llm_class,
            llm_args=self.hparams.llm_args or self.llm_args,
            num_samples=self.hparams.task_args.num_samples or self.num_samples,
            max_tokens=self.hparams.task_args.max_tokens or self.max_tokens,
            num_fewshot=self.hparams.task_args.num_fewshot,
            task_name=self.hparams.task_args.task_name,
            task_subjects=self.hparams.task_args.task_subjects,
            hf_revision=self.hparams.task_args.hf_revision or self.hf_revision,
            perturbation_config=self.hparams.task_args.perturbation_config or self.perturbation_config,
            output_dir=self.hparams.output_dir,
            llm_judge_class=llm_judge_class,
            judge_model_args=self.hparams.task_args.judge_model_args or self.judge_model_args,
            hf_upload_dir=self.hparams.hf_upload_dir or self.hf_upload_dir,
            hf_upload_repo=self.hparams.hf_upload_repo or self.hf_upload_repo,
            wandb_project=self.hparams.wandb_project or self.wandb_project,
            wandb_entity=self.hparams.wandb_entity or self.wandb_entity,
            wandb_run_id=self.hparams.wandb_run_id or self.wandb_run_id,
            wandb_upload_results=self.hparams.wandb_upload_results or self.wandb_upload_results,
            batch_size=self.hparams.task_args.batch_size or self.batch_size,
            description=self.hparams.description or self.description,
            randomize_judge_order=self.randomize_judge_order,
            delete_output_dir_after_upload=self.hparams.delete_output_dir_after_upload
            or self.delete_output_dir_after_upload,
        )

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any | None,
    ) -> None:
        """Close the Determined core context, if one was started."""
        if self._core_context is not None:
            self._core_context.close()
            self._core_context = None

    def should_preempt(self) -> bool:
        """Return True when the cluster asks this trial to preempt."""
        if self._core_context is None:
            return False
        return self._core_context.preempt.should_preempt()

    def get_trial_id(self) -> int | None:
        """Return the Determined trial id, or None outside the context."""
        if self._core_context is None:
            return None
        # NOTE(review): reaches into a private attribute of the Determined
        # train context — may break on Determined upgrades.
        return self._core_context.train._trial_id
@@ -0,0 +1,121 @@
1
+ import importlib.util
2
+ import inspect
3
+ import sys
4
+ from contextlib import AbstractContextManager
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import eval_framework
10
+ from eval_framework.llm.base import BaseLLM
11
+ from eval_framework.tasks.eval_config import EvalConfig
12
+ from eval_framework.tasks.perturbation import PerturbationConfig
13
+
14
+
15
def import_models(models_file: PathLike | str) -> dict[str, type[BaseLLM]]:
    """Import a Python file and collect the ``BaseLLM`` subclasses it defines.

    Args:
        models_file: Path to a Python source file containing model classes.

    Returns:
        Mapping from class name to class object; ``BaseLLM`` itself is excluded.

    Raises:
        ImportError: If the file cannot be loaded as a module.
    """
    models_file = Path(models_file).resolve()
    library_path = Path(eval_framework.__path__[0]).resolve()

    if models_file.is_relative_to(library_path):
        # Files living inside the eval_framework package go through the
        # regular import machinery to avoid duplicate-module import issues.
        relative_path = models_file.relative_to(library_path.parent)
        dotted_name = ".".join(relative_path.with_suffix("").parts)
        module = importlib.import_module(dotted_name)
    else:
        # Arbitrary external file: load it from its location under its stem.
        dotted_name = models_file.stem
        spec = importlib.util.spec_from_file_location(dotted_name, str(models_file))
        if spec is None:
            raise ImportError(f"Could not load module '{models_file}'.")
        module = importlib.util.module_from_spec(spec)
        sys.modules[dotted_name] = module
        if spec.loader is None:
            raise ImportError(f"Could not load module '{models_file}'.")
        spec.loader.exec_module(module)

    return {
        member_name: member
        for member_name, member in inspect.getmembers(module, inspect.isclass)
        if issubclass(member, BaseLLM) and member is not BaseLLM
    }
47
+
48
+
49
class EvalContext(AbstractContextManager):
    """Base context manager holding all CLI-level evaluation settings.

    Subclasses populate ``self.config`` with an ``EvalConfig`` in ``__enter__``.
    The base class itself only stores arguments and provides no-op preemption
    hooks.
    """

    def __init__(
        self,
        llm_name: str,
        models_path: Path,
        num_samples: int | None = None,
        max_tokens: int | None = None,
        num_fewshot: int | None = None,
        task_name: str | None = None,
        task_subjects: list[str] | None = None,
        hf_revision: str | None = None,
        output_dir: Path | None = None,
        wandb_project: str | None = None,
        wandb_entity: str | None = None,
        wandb_run_id: str | None = None,
        wandb_upload_results: bool | None = None,
        hf_upload_dir: str | None = None,
        hf_upload_repo: str | None = None,
        llm_args: dict[str, Any] | None = None,
        judge_models_path: Path | None = None,
        judge_model_name: str | None = None,
        judge_model_args: dict[str, Any] | None = None,
        batch_size: int | None = None,
        description: str | None = None,
        perturbation_type: str | None = None,
        perturbation_probability: float | None = None,
        perturbation_seed: int | None = None,
        randomize_judge_order: bool = False,
        delete_output_dir_after_upload: bool | None = None,
    ) -> None:
        self.llm_name = llm_name
        self.models_path = models_path
        self.num_samples = num_samples
        self.max_tokens = max_tokens
        self.num_fewshot = num_fewshot
        self.task_name = task_name
        self.task_subjects = task_subjects
        self.hf_revision = hf_revision
        self.output_dir = output_dir
        self.wandb_project = wandb_project
        self.wandb_entity = wandb_entity
        self.wandb_run_id = wandb_run_id
        self.wandb_upload_results = wandb_upload_results
        self.hf_upload_dir = hf_upload_dir
        self.hf_upload_repo = hf_upload_repo
        # Normalize the mutable-dict arguments so attributes are never None.
        self.llm_args = {} if llm_args is None else llm_args
        self.judge_models_path = judge_models_path
        self.judge_model_name = judge_model_name
        self.judge_model_args = {} if judge_model_args is None else judge_model_args
        self.batch_size = batch_size
        self.description = description
        self.randomize_judge_order = randomize_judge_order
        self.delete_output_dir_after_upload = delete_output_dir_after_upload

        # Build a perturbation config only when a type or probability was
        # given; unset fields are dropped so PerturbationConfig defaults apply.
        if perturbation_type or perturbation_probability is not None:
            raw_fields = {
                "type": perturbation_type,
                "probability": perturbation_probability,
                "seed": perturbation_seed,
            }
            provided = {key: value for key, value in raw_fields.items() if value is not None}
            self.perturbation_config: PerturbationConfig | None = PerturbationConfig(**provided)
        else:
            self.perturbation_config = None

        # Filled in by subclasses inside __enter__.
        self.config: EvalConfig | None = None

    def should_preempt(self) -> bool:
        """Base contexts are never preempted."""
        return False

    def get_trial_id(self) -> int | None:
        """Base contexts have no associated trial."""
        return None
@@ -0,0 +1,78 @@
1
+ import importlib
2
+ from os import PathLike
3
+ from typing import Any
4
+
5
+ from eval_framework.context.eval import EvalContext, import_models
6
+ from eval_framework.llm.base import BaseLLM
7
+ from eval_framework.tasks.eval_config import EvalConfig
8
+
9
+
10
def _load_model(llm_name: str, models_path: str | PathLike | None, *, info: str = "") -> type[BaseLLM]:
    """Load a model class either from a models file or as a fully qualified module path.

    Args:
        llm_name: The name of the model class to load, or a fully qualified module path.
        models_path: The path to a Python file containing model class definitions.
        info: Additional info to include in error messages (e.g. "judge").

    Returns:
        The model class.

    Raises:
        ValueError: If the name cannot be resolved to a class.
    """
    # "LLM" or e.g. "LLM judge"; normalized once so messages never contain a
    # double space (the previous per-branch interpolation of ``info`` did).
    label = f"LLM {info.strip()}".strip()

    if models_path is None or "." in llm_name:
        # Without a models file the name must be a fully qualified module path.
        if "." not in llm_name:
            raise ValueError(f"{label} '{llm_name}' is not a fully qualified module path.")
        module_path, llm_class_name = llm_name.rsplit(".", 1)
        module = importlib.import_module(module_path)
        if not hasattr(module, llm_class_name):
            raise ValueError(f"LLM '{llm_class_name}' not found in module '{module_path}'.")
        return getattr(module, llm_class_name)
    else:
        models_dict = import_models(models_path)
        if llm_name not in models_dict:
            raise ValueError(f"{label} '{llm_name}' not found in {models_path}.")
        return models_dict[llm_name]
36
+
37
+
38
class LocalContext(EvalContext):
    """Evaluation context for local runs.

    Resolves the model (and optional judge) classes and assembles the
    ``EvalConfig`` directly from the CLI arguments; there is no cluster
    integration and nothing to clean up on exit.
    """

    def __enter__(self) -> "LocalContext":
        llm_class = _load_model(self.llm_name, models_path=self.models_path)

        # Resolve the judge model only when one was requested.
        self.llm_judge_class: type[BaseLLM] | None = None
        if self.judge_model_name is not None:
            self.llm_judge_class = _load_model(self.judge_model_name, models_path=self.judge_models_path, info="judge")

        self.config = EvalConfig(
            llm_class=llm_class,
            llm_args=self.llm_args,
            num_samples=self.num_samples,
            max_tokens=self.max_tokens,
            num_fewshot=self.num_fewshot,
            perturbation_config=self.perturbation_config,
            task_name=self.task_name,
            task_subjects=self.task_subjects,
            hf_revision=self.hf_revision,
            output_dir=self.output_dir,
            hf_upload_dir=self.hf_upload_dir,
            hf_upload_repo=self.hf_upload_repo,
            wandb_entity=self.wandb_entity,
            wandb_project=self.wandb_project,
            wandb_run_id=self.wandb_run_id,
            wandb_upload_results=self.wandb_upload_results,
            llm_judge_class=self.llm_judge_class,
            judge_model_args=self.judge_model_args,
            batch_size=self.batch_size,
            description=self.description,
            randomize_judge_order=self.randomize_judge_order,
            delete_output_dir_after_upload=self.delete_output_dir_after_upload,
        )

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any | None,
    ) -> None:
        # Nothing to release for local runs.
        pass
@@ -0,0 +1,234 @@
1
+ import logging
2
+ import math
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import wandb
7
+ from tqdm import tqdm
8
+
9
+ from eval_framework.metrics.base import BaseMetric
10
+ from eval_framework.metrics.efficiency.bytes_per_sequence_position import (
11
+ BytesCompletion,
12
+ BytesLoglikelihood,
13
+ SequencePositionsCompletion,
14
+ SequencePositionsLoglikelihood,
15
+ )
16
+ from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
17
+ from eval_framework.result_processors.base import Result, ResultProcessor
18
+ from eval_framework.shared.types import Completion, Loglikelihood
19
+ from eval_framework.tasks.base import ResponseType
20
+ from eval_framework.tasks.eval_config import EvalConfig
21
+ from eval_framework.tasks.registry import get_task
22
+ from eval_framework.utils.constants import RED, RESET
23
+ from eval_framework.utils.tqdm_handler import get_disable_bar_flag, safe_tqdm_write
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class EvaluationGenerator:
29
+ def __init__(self, config: EvalConfig, result_processor: ResultProcessor) -> None:
30
+ logger.info("EvaluationGenerator initialized")
31
+
32
+ self.few_shot = config.num_fewshot
33
+ self.config = config
34
+ self.num_samples = config.num_samples
35
+ self.max_tokens = config.max_tokens
36
+ self.result_processor = result_processor
37
+ self.save_intermediate_results = config.save_intermediate_results
38
+
39
+ task_class = get_task(config.task_name)
40
+ if task_class.RESPONSE_TYPE == ResponseType.COMPLETION:
41
+ self.metrics = task_class.METRICS + [BytesCompletion, SequencePositionsCompletion]
42
+ elif task_class.RESPONSE_TYPE == ResponseType.LOGLIKELIHOODS:
43
+ self.metrics = task_class.METRICS + [BytesLoglikelihood, SequencePositionsLoglikelihood]
44
+ else:
45
+ raise NotImplementedError
46
+
47
+ self.task_name = task_class.NAME
48
+
49
+ def _run_metric_calculators(self, responses: list[Completion | Loglikelihood]) -> list[Result]:
50
+ results: list[Result] = self.result_processor.load_metrics_results()
51
+ llm_name = self.result_processor.load_metadata()["llm_name"]
52
+
53
+ subject_result_id_existing = set()
54
+ for result in results:
55
+ subject_result_id_existing.add(f"{result.subject}_{result.id}_{result.metric_class_name}")
56
+
57
+ """
58
+ we have three dimensions: subject, metric, sample_id
59
+ we wanna average over sample_id
60
+ and also over all subjects by averaging over the averages
61
+ dict[metric, dict[subject, dict[sample_id, list[result]]]]
62
+ """
63
+ llm_judge = None
64
+ for metric_class in self.metrics:
65
+ metric: BaseMetric
66
+ if issubclass(metric_class, BaseLLMJudgeMetric):
67
+ if llm_judge is None:
68
+ assert self.config.llm_judge_class is not None, "The llm_judge_class must be defined in the config."
69
+ llm_judge = self.config.llm_judge_class(**self.config.judge_model_args)
70
+ metric = metric_class(
71
+ llm_judge=llm_judge,
72
+ randomize_order=self.config.randomize_judge_order,
73
+ )
74
+ else:
75
+ metric = metric_class()
76
+
77
+ logger.info(f"Starting calculation of {metric.NAME}")
78
+ safe_tqdm_write(f"INFO: Calculating {metric.NAME}")
79
+ for response in tqdm(responses, desc=f"Calculating {metric.NAME}", disable=get_disable_bar_flag()):
80
+ if f"{response.subject}_{response.id}_{metric.__class__.__name__}" in subject_result_id_existing:
81
+ continue
82
+
83
+ subject = response.subject
84
+ metric_results = metric.calculate(response)
85
+ for metric_result in metric_results:
86
+ if "/" in metric_result.metric_name:
87
+ metric_name, key = metric_result.metric_name.split("/")
88
+ else:
89
+ metric_name = metric_result.metric_name
90
+ key = None
91
+ completion = response.completion if isinstance(response, Completion) else str(response.ground_truth)
92
+
93
+ result = Result(
94
+ id=response.id,
95
+ metric_class_name=metric.__class__.__name__,
96
+ metric_name=metric_name,
97
+ num_fewshot=self.few_shot,
98
+ key=key,
99
+ subject=subject,
100
+ llm_name=llm_name,
101
+ task_name=self.task_name,
102
+ value=metric_result.value,
103
+ higher_is_better=metric_result.higher_is_better,
104
+ prompt=response.prompt,
105
+ response=completion,
106
+ llm_judge_prompt=metric_result.llm_judge_prompt,
107
+ llm_judge_response=metric_result.llm_judge_response,
108
+ code_execution_trace=metric_result.code_execution_trace,
109
+ error=metric_result.error,
110
+ )
111
+ results.append(result)
112
+ if self.save_intermediate_results:
113
+ self.result_processor.save_metrics_result(result)
114
+
115
+ logger.info(f"Completed calculation of {metric.NAME}")
116
+ safe_tqdm_write(f"INFO: Completed {metric.NAME}")
117
+
118
+ if not self.save_intermediate_results:
119
+ self.result_processor.save_metrics_results(results)
120
+ return results
121
+
122
+ def _aggregate_results(self, results: list[Result]) -> dict[str, float | None]:
123
+ data = pd.DataFrame([r.model_dump() for r in results])
124
+ if len(data) == 0:
125
+ return {}
126
+ data.fillna({"key": ""}, inplace=True)
127
+ metrics = sorted(data["metric_name"].unique())
128
+ aggregated_results: dict[str, float | None] = {}
129
+
130
+ for metric in metrics:
131
+ # filter for metric
132
+ data_subset = data[data["metric_name"] == metric][["subject", "key", "value", "error"]]
133
+
134
+ # filter and count errors
135
+ total_count = len(data_subset)
136
+
137
+ mask = data["error"].isnull()
138
+ data_subset_error_free = data_subset.loc[mask, ["subject", "key", "value"]]
139
+ # data_subset_error_free = data_subset[data_subset["error"].isnull()][["subject", "key", "value"]]
140
+
141
+ aggregated_results[f"ErrorFreeRatio {metric}"] = float(len(data_subset_error_free) / total_count)
142
+
143
+ # aggregate by key and subject first to have equal weights for all key / subject combinations
144
+ key_subject_mean = data_subset_error_free.groupby(["key", "subject"]).mean()
145
+ aggregated_results[f"Average {metric}"] = float(key_subject_mean[["value"]].mean()["value"])
146
+
147
+ std_err_mean_sum_of_squares = 0.0
148
+ std_err_mean_total_num_samples = 0.0
149
+ std_err_mean_num_subjects = 0
150
+
151
+ for column in ["key", "subject"]:
152
+ if len(data_subset[column].unique()) > 1:
153
+ for name, _group in key_subject_mean.groupby([column]):
154
+ mask = data_subset[column] == name[0]
155
+ group = data_subset.loc[mask, ["subject", "key", "value", "error"]]
156
+ # group = data_subset[data[column] == name][["subject", "key", "value", "error"]]
157
+ group_total_count = len(group)
158
+ group_error_free = group[group["error"].isnull()][["subject", "key", "value"]]
159
+ aggregated_results[f"ErrorFreeRatio {metric} - {name[0]}"] = float(
160
+ len(group_error_free) / group_total_count
161
+ )
162
+
163
+ group_key_subject_mean = group_error_free.groupby(["key", "subject"]).mean()
164
+ value = float(group_key_subject_mean[["value"]].mean()["value"])
165
+ aggregated_results[f"Average {metric} - {name[0]}"] = value if not math.isnan(value) else None
166
+
167
+ if not ("SequencePositions" in metric or "Bytes" in metric):
168
+ # calculate standard error for selected metrics
169
+ group_key_subject_std = group_error_free.groupby(["key", "subject"]).std()
170
+ std = float(group_key_subject_std[["value"]].mean()["value"])
171
+ num_samples = len(group_error_free)
172
+
173
+ if math.isnan(std) or num_samples == 0:
174
+ aggregated_results[f"StdErr {metric} - {name[0]}"] = None
175
+ else:
176
+ aggregated_results[f"StdErr {metric} - {name[0]}"] = std / np.sqrt(num_samples)
177
+ aggregated_results[f"NumSamples {metric} - {name[0]}"] = num_samples
178
+
179
+ std_err_mean_sum_of_squares += std**2 / num_samples
180
+ std_err_mean_total_num_samples += num_samples
181
+ std_err_mean_num_subjects += 1
182
+
183
+ if not ("SequencePositions" in metric or "Bytes" in metric):
184
+ # calculate standard error for selected metrics
185
+ if std_err_mean_total_num_samples > 0:
186
+ # calculate the standard error of the mean (SEM) for the aggregated results (eg. add in quadrature)
187
+ # SEM = sqrt(sum(variance_i * n_i) / i)
188
+ # where variance_i is the variance of each group and i is the number of groups
189
+ # (the combined mean is also not weighted by the number of samples)
190
+ if math.isnan(std) or std_err_mean_total_num_samples == 0:
191
+ aggregated_results[f"StdErr {metric}"] = None
192
+ else:
193
+ aggregated_results[f"StdErr {metric}"] = np.sqrt(
194
+ std_err_mean_sum_of_squares / std_err_mean_num_subjects
195
+ )
196
+ aggregated_results[f"NumSamples {metric}"] = std_err_mean_total_num_samples
197
+ else:
198
+ # if there are no sub-groups to combine, calculate the SEM here directly
199
+ key_subject_std = data_subset_error_free.groupby(["key", "subject"]).std()
200
+ std = float(key_subject_std[["value"]].mean()["value"])
201
+ num_samples = len(data_subset_error_free)
202
+ if math.isnan(std) or num_samples == 0:
203
+ aggregated_results[f"StdErr {metric}"] = None
204
+ else:
205
+ aggregated_results[f"StdErr {metric}"] = std / np.sqrt(num_samples)
206
+ aggregated_results[f"NumSamples {metric}"] = num_samples
207
+
208
+ if (
209
+ "Average Bytes" in aggregated_results
210
+ and "Average SequencePositions" in aggregated_results
211
+ and aggregated_results["Average Bytes"]
212
+ and aggregated_results["Average SequencePositions"]
213
+ ):
214
+ aggregated_results["Average Bytes per Sequence Position"] = (
215
+ aggregated_results["Average Bytes"] / aggregated_results["Average SequencePositions"]
216
+ )
217
+
218
+ return aggregated_results
219
+
220
+ def run_eval(self) -> list[Result]:
221
+ """Runs evaluation using saved completions."""
222
+ logger.info("Running evaluation...")
223
+ responses = self.result_processor.load_responses()
224
+ if not responses:
225
+ raise ValueError("No saved completions found. Run 'run_completions' first.")
226
+
227
+ metrics_results = self._run_metric_calculators(responses)
228
+ aggregated_results = self._aggregate_results(metrics_results)
229
+
230
+ wandb.log(aggregated_results)
231
+ self.result_processor.save_aggregated_results(aggregated_results)
232
+ logger.info(aggregated_results)
233
+ logger.info(f"{RED}[ Evaluation completed and results saved! ]{RESET}")
234
+ return metrics_results
class LogicError(Exception):
    """Exception raised for logic errors within the eval framework."""
@@ -0,0 +1,5 @@
1
+ The IFEval implementation here is taken verbatim (apart from minor details, such as adding Swedish to `LANGUAGE_CODES`)
2
+ from [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks/ifeval) as of Dec 19 2024.
3
+ The original repository is https://github.com/google-research/google-research/tree/master/instruction_following_eval .
4
+
5
+ Because this code mirrors the upstream implementation verbatim, it is intentionally excluded from unit testing.