eval_framework-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +170 -0
  5. eval_framework/context/eval.py +114 -0
  6. eval_framework/context/local.py +52 -0
  7. eval_framework/evaluation_generator.py +231 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +323 -0
  16. eval_framework/llm/base.py +58 -0
  17. eval_framework/llm/huggingface.py +332 -0
  18. eval_framework/llm/mistral.py +73 -0
  19. eval_framework/llm/models.py +16 -0
  20. eval_framework/llm/openai.py +205 -0
  21. eval_framework/llm/vllm.py +438 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +187 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/bleu.py +76 -0
  29. eval_framework/metrics/completion/chrf.py +62 -0
  30. eval_framework/metrics/completion/code_assertion.py +44 -0
  31. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  32. eval_framework/metrics/completion/comet.py +56 -0
  33. eval_framework/metrics/completion/concordance_index.py +38 -0
  34. eval_framework/metrics/completion/csv_format.py +102 -0
  35. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  36. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  37. eval_framework/metrics/completion/f1.py +42 -0
  38. eval_framework/metrics/completion/format_checker.py +56 -0
  39. eval_framework/metrics/completion/grid_difference.py +77 -0
  40. eval_framework/metrics/completion/ifeval.py +73 -0
  41. eval_framework/metrics/completion/json_format.py +171 -0
  42. eval_framework/metrics/completion/language_checker.py +74 -0
  43. eval_framework/metrics/completion/length_control.py +83 -0
  44. eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
  45. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  46. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  47. eval_framework/metrics/completion/repetition.py +88 -0
  48. eval_framework/metrics/completion/rouge_1.py +35 -0
  49. eval_framework/metrics/completion/rouge_2.py +45 -0
  50. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  51. eval_framework/metrics/completion/rouge_l.py +52 -0
  52. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  53. eval_framework/metrics/completion/ter.py +67 -0
  54. eval_framework/metrics/completion/text_counter.py +182 -0
  55. eval_framework/metrics/efficiency/__init__.py +0 -0
  56. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  57. eval_framework/metrics/llm/__init__.py +0 -0
  58. eval_framework/metrics/llm/base.py +8 -0
  59. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  60. eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
  61. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  62. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  63. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  64. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  65. eval_framework/metrics/llm/graders/language.py +56 -0
  66. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  67. eval_framework/metrics/llm/graders/models.py +74 -0
  68. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  69. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  70. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  71. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  72. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  73. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  74. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  75. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  76. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  77. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
  78. eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
  79. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  80. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  81. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  82. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  83. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  84. eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
  85. eval_framework/py.typed +0 -0
  86. eval_framework/response_generator.py +416 -0
  87. eval_framework/result_processors/__init__.py +0 -0
  88. eval_framework/result_processors/base.py +74 -0
  89. eval_framework/result_processors/hf_processor.py +87 -0
  90. eval_framework/result_processors/result_processor.py +129 -0
  91. eval_framework/run.py +314 -0
  92. eval_framework/run_direct.py +42 -0
  93. eval_framework/shared/types.py +227 -0
  94. eval_framework/tasks/__init__.py +6 -0
  95. eval_framework/tasks/base.py +314 -0
  96. eval_framework/tasks/benchmarks/__init__.py +0 -0
  97. eval_framework/tasks/benchmarks/arc.py +46 -0
  98. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  99. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  100. eval_framework/tasks/benchmarks/belebele.py +60 -0
  101. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  102. eval_framework/tasks/benchmarks/casehold.py +47 -0
  103. eval_framework/tasks/benchmarks/chembench.py +85 -0
  104. eval_framework/tasks/benchmarks/copa.py +39 -0
  105. eval_framework/tasks/benchmarks/duc.py +91 -0
  106. eval_framework/tasks/benchmarks/flores200.py +62 -0
  107. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  108. eval_framework/tasks/benchmarks/gpqa.py +177 -0
  109. eval_framework/tasks/benchmarks/gsm8k.py +148 -0
  110. eval_framework/tasks/benchmarks/hellaswag.py +44 -0
  111. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  112. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  113. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  114. eval_framework/tasks/benchmarks/include.py +119 -0
  115. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  116. eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
  117. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  118. eval_framework/tasks/benchmarks/mmlu.py +190 -0
  119. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  120. eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
  121. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  122. eval_framework/tasks/benchmarks/openbookqa.py +37 -0
  123. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  124. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  125. eval_framework/tasks/benchmarks/piqa.py +39 -0
  126. eval_framework/tasks/benchmarks/quality.py +56 -0
  127. eval_framework/tasks/benchmarks/sciq.py +44 -0
  128. eval_framework/tasks/benchmarks/sphyr.py +75 -0
  129. eval_framework/tasks/benchmarks/squad.py +89 -0
  130. eval_framework/tasks/benchmarks/struct_eval.py +110 -0
  131. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  132. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  133. eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
  134. eval_framework/tasks/benchmarks/winogender.py +39 -0
  135. eval_framework/tasks/benchmarks/winogrande.py +44 -0
  136. eval_framework/tasks/benchmarks/winox.py +57 -0
  137. eval_framework/tasks/benchmarks/wmt.py +160 -0
  138. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  139. eval_framework/tasks/eval_config.py +112 -0
  140. eval_framework/tasks/perturbation.py +83 -0
  141. eval_framework/tasks/registry.py +186 -0
  142. eval_framework/tasks/task_loader.py +80 -0
  143. eval_framework/tasks/task_names.py +138 -0
  144. eval_framework/tasks/utils.py +578 -0
  145. eval_framework/utils/constants.py +9 -0
  146. eval_framework/utils/generate_task_docs.py +229 -0
  147. eval_framework/utils/helpers.py +3 -0
  148. eval_framework/utils/logging.py +50 -0
  149. eval_framework/utils/packaging.py +52 -0
  150. eval_framework-0.2.0.dist-info/METADATA +514 -0
  151. eval_framework-0.2.0.dist-info/RECORD +161 -0
  152. eval_framework-0.2.0.dist-info/WHEEL +4 -0
  153. eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
  154. template_formatting/README.md +83 -0
  155. template_formatting/__init__.py +0 -0
  156. template_formatting/formatter.py +536 -0
  157. template_formatting/mistral_formatter.py +159 -0
  158. template_formatting/py.typed +0 -0
  159. template_formatting/tests/test_formatter_eval.py +408 -0
  160. template_formatting/tests/test_formatter_scaling.py +253 -0
  161. template_formatting/tests/test_mistral_formatter.py +136 -0
@@ -0,0 +1,7 @@
+ from importlib.metadata import version
+
+ __version__ = version("eval-framework")
+
+ del version
+
+ __all__ = ["__version__"]
@@ -0,0 +1,36 @@
+ from enum import Enum
+ from pathlib import Path
+ from typing import Any
+
+ import yaml  # type: ignore[import-untyped]
+ from pydantic import BaseModel, ConfigDict
+
+
+ class BaseConfig(BaseModel):
+     model_config = ConfigDict(extra="forbid", frozen=True, protected_namespaces=())
+
+     def as_dict(self) -> dict[str, Any]:
+         def simplify_recursive(obj: Any) -> Any:
+             if isinstance(obj, dict):
+                 return {key: simplify_recursive(value) for key, value in obj.items()}
+             elif isinstance(obj, list):
+                 return [simplify_recursive(item) for item in obj]
+             elif isinstance(obj, Path):
+                 return str(obj)
+             elif isinstance(obj, Enum):
+                 return obj.value
+             else:
+                 return obj
+
+         return simplify_recursive(self.model_dump())
+
+     @classmethod
+     def from_yaml(cls, yml_filename: str | Path) -> "BaseConfig":
+         with open(yml_filename) as conf_file:
+             config_dict = yaml.load(conf_file, Loader=yaml.FullLoader)
+
+         return cls(**config_dict)
+
+     def save(self, out_file: Path) -> None:
+         with open(out_file, "w", encoding="UTF-8") as f:
+             yaml.safe_dump(self.model_dump(mode="json"), f)
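For orientation, a minimal sketch of how a `BaseConfig` subclass round-trips through YAML. The `RunConfig` class and its fields are hypothetical, chosen only to illustrate `save`, `from_yaml`, and `as_dict`:

```python
from pathlib import Path

from eval_framework.base_config import BaseConfig


class RunConfig(BaseConfig):  # hypothetical subclass, for illustration only
    task_name: str
    num_fewshot: int = 0
    output_dir: Path = Path("results")


config = RunConfig(task_name="MMLU", num_fewshot=5)
config.save(Path("run.yaml"))               # serialized via yaml.safe_dump of model_dump(mode="json")
restored = RunConfig.from_yaml("run.yaml")  # parsed with yaml.FullLoader and re-validated by pydantic
print(restored.as_dict())                   # Path and Enum values reduced to plain strings/values
```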
File without changes
@@ -0,0 +1,170 @@
+ import logging
+ from pathlib import Path
+ from typing import Annotated, Any
+
+ from determined._info import get_cluster_info
+ from determined.core._context import Context
+ from determined.core._context import init as determined_core_init
+ from determined.core._distributed import DummyDistributedContext
+ from pydantic import AfterValidator, BaseModel, ConfigDict
+
+ from eval_framework.context.eval import EvalContext, import_models
+ from eval_framework.llm.base import BaseLLM
+ from eval_framework.tasks.eval_config import EvalConfig
+ from eval_framework.tasks.perturbation import PerturbationConfig
+ from eval_framework.tasks.registry import validate_task_name
+ from eval_framework.tasks.task_loader import load_extra_tasks
+
+ logger = logging.getLogger(__name__)
+
+
+ class TaskArgs(BaseModel):
+     model_config = ConfigDict(extra="forbid")
+     task_name: Annotated[str, AfterValidator(validate_task_name)]
+     num_fewshot: int
+     num_samples: int | None = None
+     max_tokens: int | None = None
+     batch_size: int | None = None
+     judge_model_name: str | None = None
+     judge_model_args: dict[str, Any] = {}
+     task_subjects: list[str] | None = None
+     hf_revision: str | None = None
+     perturbation_config: PerturbationConfig | None = None
+
+
+ class Hyperparameters(BaseModel):
+     model_config = ConfigDict(extra="forbid")
+     llm_name: str
+     output_dir: Path
+     hf_upload_dir: str | None = None
+     hf_upload_repo: str | None = None
+     wandb_project: str | None = None
+     wandb_entity: str | None = None
+     wandb_run_id: str | None = None
+     description: str | None = None
+     task_args: TaskArgs
+     llm_args: dict[str, Any] | None = {}
+     extra_task_modules: list[str] | None = None
+
+
+ class DeterminedContext(EvalContext):
+     def __init__(self, **kwargs: Any) -> None:
+         super().__init__(**kwargs)
+         self._core_context: Context | None = None
+
+     def __enter__(self) -> "DeterminedContext":
+         distributed_context = DummyDistributedContext()
+         self._core_context = determined_core_init(distributed=distributed_context)
+         self._core_context.start()
+         info = get_cluster_info()
+
+         if info is None:
+             raise RuntimeError("Failed to retrieve cluster info.")
+
+         # Load extra tasks if specified first
+         extra_task_modules = info.trial.hparams.get("extra_task_modules", None)
+         if extra_task_modules:
+             name = "extra_task_modules"
+             val_cli = getattr(self, name, None)
+             val_hparams = extra_task_modules
+             if val_hparams:
+                 if val_cli and val_hparams and val_cli != val_hparams:
+                     logger.info(
+                         f"CLI argument {name} ({val_cli}) is being overridden by hyperparameters:"
+                         f"({val_hparams}). If it fails due to duplicate task names, remove the CLI argument and"
+                         "consolidate as a determined hyperparameter instead."
+                     )
+                 load_extra_tasks(val_hparams)
+
+         self.hparams = Hyperparameters(**info.trial.hparams)
+
+         for name in [
+             "llm_name",
+             "llm_args",
+             "output_dir",
+             "hf_upload_dir",
+             "hf_upload_repo",
+             "wandb_project",
+             "wandb_entity",
+             "wandb_run_id",
+             "description",
+         ]:
+             val_cli = getattr(self, name, None)
+             val_hparams = getattr(self.hparams, name, None)
+             if val_cli and val_hparams and val_cli != val_hparams:
+                 logger.info(f"CLI argument {name} ({val_cli}) is being overridden by hyperparameters: ({val_hparams}).")
+
+         for name in [
+             "num_samples",
+             "max_tokens",
+             "num_fewshot",
+             "task_name",
+             "task_subjects",
+             "batch_size",
+             "hf_revision",
+             "judge_model_name",
+             "judge_model_args",
+             "perturbation_config",
+         ]:
+             val_cli = getattr(self, name, None)
+             val_hparams = getattr(self.hparams.task_args, name, None)
+             if val_cli and val_hparams and val_cli != val_hparams:
+                 logger.info(f"CLI argument {name} ({val_cli}) is being overridden by hyperparameters: ({val_hparams}).")
+
+         models = import_models(self.models_path)
+         if self.hparams.llm_name not in models:
+             raise ValueError(f"LLM '{self.hparams.llm_name}' not found.")
+         llm_class = models[self.hparams.llm_name]
+
+         llm_judge_class: type[BaseLLM] | None = None
+         judge_model_name = self.hparams.task_args.judge_model_name or self.judge_model_name
+         if self.judge_models_path is not None and judge_model_name is not None:
+             judge_models = import_models(self.judge_models_path)
+             if judge_model_name not in judge_models:
+                 raise ValueError(f"LLM judge '{judge_model_name}' not found.")
+             llm_judge_class = judge_models[judge_model_name]
+
+         # for all optional hyperparameters, resort to the respective CLI argument if the hyperparameter is not set
+         self.config = EvalConfig(
+             llm_class=llm_class,
+             llm_args=self.hparams.llm_args or self.llm_args,
+             num_samples=self.hparams.task_args.num_samples or self.num_samples,
+             max_tokens=self.hparams.task_args.max_tokens or self.max_tokens,
+             num_fewshot=self.hparams.task_args.num_fewshot,
+             task_name=self.hparams.task_args.task_name,
+             task_subjects=self.hparams.task_args.task_subjects,
+             hf_revision=self.hparams.task_args.hf_revision or self.hf_revision,
+             perturbation_config=self.hparams.task_args.perturbation_config or self.perturbation_config,
+             output_dir=self.hparams.output_dir,
+             llm_judge_class=llm_judge_class,
+             judge_model_args=self.hparams.task_args.judge_model_args or self.judge_model_args,
+             hf_upload_dir=self.hparams.hf_upload_dir or self.hf_upload_dir,
+             hf_upload_repo=self.hparams.hf_upload_repo or self.hf_upload_repo,
+             wandb_project=self.hparams.wandb_project or self.wandb_project,
+             wandb_entity=self.hparams.wandb_entity or self.wandb_entity,
+             wandb_run_id=self.hparams.wandb_run_id or self.wandb_run_id,
+             batch_size=self.hparams.task_args.batch_size or self.batch_size,
+             description=self.hparams.description or self.description,
+         )
+
+         return self
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_value: BaseException | None,
+         traceback: Any | None,
+     ) -> None:
+         if self._core_context is not None:
+             self._core_context.close()
+             self._core_context = None
+
+     def should_preempt(self) -> bool:
+         if self._core_context is None:
+             return False
+         return self._core_context.preempt.should_preempt()
+
+     def get_trial_id(self) -> int | None:
+         if self._core_context is None:
+             return None
+         return self._core_context.train._trial_id
@@ -0,0 +1,114 @@
+ import importlib.util
+ import inspect
+ import sys
+ from contextlib import AbstractContextManager
+ from pathlib import Path
+ from typing import Any
+
+ import eval_framework
+ from eval_framework.llm.base import BaseLLM
+ from eval_framework.tasks.eval_config import EvalConfig
+ from eval_framework.tasks.perturbation import PerturbationConfig
+
+
+ def import_models(models_file: Path | str) -> dict[str, type[BaseLLM]]:
+     models_file = Path(models_file).resolve()
+     library_path = Path(eval_framework.__path__[0]).resolve()
+
+     # Imports from the eval_framework module need special care to avoid
+     # import issues
+     if models_file.is_relative_to(library_path):
+         relative_path = models_file.relative_to(library_path.parent)
+         module_name = ".".join(relative_path.with_suffix("").parts)
+         module = importlib.import_module(module_name)
+     else:
+         module_name = models_file.stem
+
+         spec = importlib.util.spec_from_file_location(module_name, str(models_file))
+
+         if spec is None:
+             raise ImportError(f"Could not load module '{models_file}'.")
+
+         module = importlib.util.module_from_spec(spec)
+         sys.modules[module_name] = module
+
+         if spec.loader is None:
+             raise ImportError(f"Could not load module '{models_file}'.")
+
+         spec.loader.exec_module(module)
+
+     subclasses = {}
+     for name, clazz in inspect.getmembers(module, inspect.isclass):
+         if issubclass(clazz, BaseLLM) and clazz is not BaseLLM:
+             subclasses[name] = clazz
+
+     return subclasses
+
+
+ class EvalContext(AbstractContextManager):
+     def __init__(
+         self,
+         llm_name: str,
+         models_path: Path,
+         num_samples: int | None = None,
+         max_tokens: int | None = None,
+         num_fewshot: int | None = None,
+         task_name: str | None = None,
+         task_subjects: list[str] | None = None,
+         hf_revision: str | None = None,
+         output_dir: Path | None = None,
+         wandb_project: str | None = None,
+         wandb_entity: str | None = None,
+         wandb_run_id: str | None = None,
+         hf_upload_dir: str | None = None,
+         hf_upload_repo: str | None = None,
+         llm_args: dict[str, Any] | None = None,
+         judge_models_path: Path | None = None,
+         judge_model_name: str | None = None,
+         judge_model_args: dict[str, Any] | None = None,
+         batch_size: int | None = None,
+         description: str | None = None,
+         perturbation_type: str | None = None,
+         perturbation_probability: float | None = None,
+         perturbation_seed: int | None = None,
+     ) -> None:
+         self.llm_name = llm_name
+         self.models_path = models_path
+         self.num_samples = num_samples
+         self.max_tokens = max_tokens
+         self.num_fewshot = num_fewshot
+         self.task_name = task_name
+         self.task_subjects = task_subjects
+         self.hf_revision = hf_revision
+         self.output_dir = output_dir
+         self.wandb_project = wandb_project
+         self.wandb_entity = wandb_entity
+         self.wandb_run_id = wandb_run_id
+         self.hf_upload_dir = hf_upload_dir
+         self.hf_upload_repo = hf_upload_repo
+         self.llm_args = llm_args
+         self.judge_models_path = judge_models_path
+         self.judge_model_name = judge_model_name
+         self.judge_model_args = judge_model_args
+         self.batch_size = batch_size
+         self.description = description
+
+         if perturbation_type or perturbation_probability is not None:
+             perturbation = {
+                 "type": perturbation_type,
+                 "probability": perturbation_probability,
+                 "seed": perturbation_seed,
+             }
+             self.perturbation_config: PerturbationConfig | None = PerturbationConfig(
+                 **{k: v for k, v in perturbation.items() if v is not None}
+             )
+         else:
+             self.perturbation_config = None
+
+         self.config: EvalConfig | None = None
+
+     def should_preempt(self) -> bool:
+         return False
+
+     def get_trial_id(self) -> int | None:
+         return None
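The discovery logic in `import_models` can be illustrated with a hypothetical user-supplied models file; the file name `my_models.py` and the class `MyTinyModel` are placeholders, not part of this release:

```python
# my_models.py -- a hypothetical file used as models_path
from eval_framework.llm.base import BaseLLM


class MyTinyModel(BaseLLM):
    ...  # the BaseLLM interface would be implemented here


# Elsewhere: import_models loads the file by path (or by module name if it lives
# inside eval_framework) and returns every BaseLLM subclass it defines.
from eval_framework.context.eval import import_models

models = import_models("my_models.py")
print(models)  # e.g. {"MyTinyModel": <class 'my_models.MyTinyModel'>}
```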
@@ -0,0 +1,52 @@
+ from typing import Any
+
+ from eval_framework.context.eval import EvalContext, import_models
+ from eval_framework.llm.base import BaseLLM
+ from eval_framework.tasks.eval_config import EvalConfig
+
+
+ class LocalContext(EvalContext):
+     def __enter__(self) -> "LocalContext":
+         models = import_models(self.models_path)
+         if self.llm_name not in models:
+             raise ValueError(f"LLM '{self.llm_name}' not found.")
+         llm_class = models[self.llm_name]
+
+         self.llm_judge_class: type[BaseLLM] | None = None
+         if self.judge_models_path is not None and self.judge_model_name is not None:
+             judge_models = import_models(self.judge_models_path)
+             if self.judge_model_name not in judge_models:
+                 raise ValueError(f"LLM judge '{self.judge_model_name}' not found.")
+             self.llm_judge_class = judge_models[self.judge_model_name]
+
+         self.config = EvalConfig(
+             llm_class=llm_class,
+             llm_args=self.llm_args,
+             num_samples=self.num_samples,
+             max_tokens=self.max_tokens,
+             num_fewshot=self.num_fewshot,
+             perturbation_config=self.perturbation_config,
+             task_name=self.task_name,
+             task_subjects=self.task_subjects,
+             hf_revision=self.hf_revision,
+             output_dir=self.output_dir,
+             hf_upload_dir=self.hf_upload_dir,
+             hf_upload_repo=self.hf_upload_repo,
+             wandb_entity=self.wandb_entity,
+             wandb_project=self.wandb_project,
+             wandb_run_id=self.wandb_run_id,
+             llm_judge_class=self.llm_judge_class,
+             judge_model_args=self.judge_model_args,
+             batch_size=self.batch_size,
+             description=self.description,
+         )
+
+         return self
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_value: BaseException | None,
+         traceback: Any | None,
+     ) -> None:
+         pass
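A minimal sketch of driving an evaluation setup through `LocalContext`; all names and paths below are placeholders and assume only the constructor arguments defined on `EvalContext` above:

```python
from pathlib import Path

from eval_framework.context.local import LocalContext

# Placeholder names and paths; LocalContext resolves them only when entering the context.
with LocalContext(
    llm_name="MyTinyModel",
    models_path=Path("my_models.py"),
    task_name="MMLU",
    num_fewshot=5,
    num_samples=100,
    output_dir=Path("results"),
) as ctx:
    assert ctx.config is not None  # the EvalConfig assembled in __enter__
    print(ctx.config.task_name)
```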
@@ -0,0 +1,231 @@
+ import logging
+ import math
+
+ import numpy as np
+ import pandas as pd
+ import wandb
+ from tqdm import tqdm
+
+ from eval_framework.metrics.base import BaseMetric
+ from eval_framework.metrics.efficiency.bytes_per_sequence_position import (
+     BytesCompletion,
+     BytesLoglikelihood,
+     SequencePositionsCompletion,
+     SequencePositionsLoglikelihood,
+ )
+ from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
+ from eval_framework.result_processors.base import Result, ResultProcessor
+ from eval_framework.shared.types import Completion, Loglikelihood
+ from eval_framework.tasks.base import ResponseType
+ from eval_framework.tasks.eval_config import EvalConfig
+ from eval_framework.tasks.registry import get_task
+ from eval_framework.utils.constants import RED, RESET
+
+ logger = logging.getLogger(__name__)
+
+
+ class EvaluationGenerator:
+     def __init__(self, config: EvalConfig, result_processor: ResultProcessor) -> None:
+         logger.info("EvaluationGenerator initialized")
+
+         self.few_shot = config.num_fewshot
+         self.config = config
+         self.num_samples = config.num_samples
+         self.max_tokens = config.max_tokens
+         self.result_processor = result_processor
+         self.save_intermediate_results = config.save_intermediate_results
+
+         task_class = get_task(config.task_name)
+         if task_class.RESPONSE_TYPE == ResponseType.COMPLETION:
+             self.metrics = task_class.METRICS + [BytesCompletion, SequencePositionsCompletion]
+         elif task_class.RESPONSE_TYPE == ResponseType.LOGLIKELIHOODS:
+             self.metrics = task_class.METRICS + [BytesLoglikelihood, SequencePositionsLoglikelihood]
+         else:
+             raise NotImplementedError
+
+         self.task_name = task_class.NAME
+
+     def _run_metric_calculators(self, responses: list[Completion | Loglikelihood]) -> list[Result]:
+         results: list[Result] = self.result_processor.load_metrics_results()
+         llm_name = self.result_processor.load_metadata()["llm_name"]
+
+         subject_result_id_existing = set()
+         for result in results:
+             subject_result_id_existing.add(f"{result.subject}_{result.id}_{result.metric_class_name}")
+
+         """
+         we have three dimensions: subject, metric, sample_id
+         we wanna average over sample_id
+         and also over all subjects by averaging over the averages
+         dict[metric, dict[subject, dict[sample_id, list[result]]]]
+         """
+         llm_judge = None
+         for metric_class in self.metrics:
+             metric: BaseMetric
+             if issubclass(metric_class, BaseLLMJudgeMetric):
+                 if llm_judge is None:
+                     assert self.config.llm_judge_class is not None, "The llm_judge_class must be defined in the config."
+                     llm_judge = self.config.llm_judge_class(**self.config.judge_model_args)
+                 metric = metric_class(llm_judge=llm_judge)
+             else:
+                 metric = metric_class()
+
+             logger.info(f"Starting calculation of {metric.NAME}")
+             tqdm.write(f"INFO: Calculating {metric.NAME}")
+             for response in tqdm(responses, desc=f"Calculating {metric.NAME}"):
+                 if f"{response.subject}_{response.id}_{metric.__class__.__name__}" in subject_result_id_existing:
+                     continue
+
+                 subject = response.subject
+                 metric_results = metric.calculate(response)
+                 for metric_result in metric_results:
+                     if "/" in metric_result.metric_name:
+                         metric_name, key = metric_result.metric_name.split("/")
+                     else:
+                         metric_name = metric_result.metric_name
+                         key = None
+                     completion = response.completion if isinstance(response, Completion) else str(response.ground_truth)
+
+                     result = Result(
+                         id=response.id,
+                         metric_class_name=metric.__class__.__name__,
+                         metric_name=metric_name,
+                         num_fewshot=self.few_shot,
+                         key=key,
+                         subject=subject,
+                         llm_name=llm_name,
+                         task_name=self.task_name,
+                         value=metric_result.value,
+                         higher_is_better=metric_result.higher_is_better,
+                         prompt=response.prompt,
+                         response=completion,
+                         llm_judge_prompt=metric_result.llm_judge_prompt,
+                         llm_judge_response=metric_result.llm_judge_response,
+                         code_execution_trace=metric_result.code_execution_trace,
+                         error=metric_result.error,
+                     )
+                     results.append(result)
+                     if self.save_intermediate_results:
+                         self.result_processor.save_metrics_result(result)
+
+             logger.info(f"Completed calculation of {metric.NAME}")
+             tqdm.write(f"INFO: Completed {metric.NAME}")
+
+         if not self.save_intermediate_results:
+             self.result_processor.save_metrics_results(results)
+         return results
+
+     def _aggregate_results(self, results: list[Result]) -> dict[str, float | None]:
+         data = pd.DataFrame([r.model_dump() for r in results])
+         if len(data) == 0:
+             return {}
+         data.fillna({"key": ""}, inplace=True)
+         metrics = sorted(data["metric_name"].unique())
+         aggregated_results: dict[str, float | None] = {}
+
+         for metric in metrics:
+             # filter for metric
+             data_subset = data[data["metric_name"] == metric][["subject", "key", "value", "error"]]
+
+             # filter and count errors
+             total_count = len(data_subset)
+
+             mask = data["error"].isnull()
+             data_subset_error_free = data_subset.loc[mask, ["subject", "key", "value"]]
+             # data_subset_error_free = data_subset[data_subset["error"].isnull()][["subject", "key", "value"]]
+
+             aggregated_results[f"ErrorFreeRatio {metric}"] = float(len(data_subset_error_free) / total_count)
+
+             # aggregate by key and subject first to have equal weights for all key / subject combinations
+             key_subject_mean = data_subset_error_free.groupby(["key", "subject"]).mean()
+             aggregated_results[f"Average {metric}"] = float(key_subject_mean[["value"]].mean()["value"])
+
+             std_err_mean_sum_of_squares = 0.0
+             std_err_mean_total_num_samples = 0.0
+             std_err_mean_num_subjects = 0
+
+             for column in ["key", "subject"]:
+                 if len(data_subset[column].unique()) > 1:
+                     for name, _group in key_subject_mean.groupby([column]):
+                         mask = data_subset[column] == name[0]
+                         group = data_subset.loc[mask, ["subject", "key", "value", "error"]]
+                         # group = data_subset[data[column] == name][["subject", "key", "value", "error"]]
+                         group_total_count = len(group)
+                         group_error_free = group[group["error"].isnull()][["subject", "key", "value"]]
+                         aggregated_results[f"ErrorFreeRatio {metric} - {name[0]}"] = float(
+                             len(group_error_free) / group_total_count
+                         )
+
+                         group_key_subject_mean = group_error_free.groupby(["key", "subject"]).mean()
+                         value = float(group_key_subject_mean[["value"]].mean()["value"])
+                         aggregated_results[f"Average {metric} - {name[0]}"] = value if not math.isnan(value) else None
+
+                         if not ("SequencePositions" in metric or "Bytes" in metric):
+                             # calculate standard error for selected metrics
+                             group_key_subject_std = group_error_free.groupby(["key", "subject"]).std()
+                             std = float(group_key_subject_std[["value"]].mean()["value"])
+                             num_samples = len(group_error_free)
+
+                             if math.isnan(std) or num_samples == 0:
+                                 aggregated_results[f"StdErr {metric} - {name[0]}"] = None
+                             else:
+                                 aggregated_results[f"StdErr {metric} - {name[0]}"] = std / np.sqrt(num_samples)
+                                 aggregated_results[f"NumSamples {metric} - {name[0]}"] = num_samples
+
+                                 std_err_mean_sum_of_squares += std**2 / num_samples
+                                 std_err_mean_total_num_samples += num_samples
+                                 std_err_mean_num_subjects += 1
+
+             if not ("SequencePositions" in metric or "Bytes" in metric):
+                 # calculate standard error for selected metrics
+                 if std_err_mean_total_num_samples > 0:
+                     # calculate the standard error of the mean (SEM) for the aggregated results (eg. add in quadrature)
+                     # SEM = sqrt(sum(variance_i * n_i) / i)
+                     # where variance_i is the variance of each group and i is the number of groups
+                     # (the combined mean is also not weighted by the number of samples)
+                     if math.isnan(std) or std_err_mean_total_num_samples == 0:
+                         aggregated_results[f"StdErr {metric}"] = None
+                     else:
+                         aggregated_results[f"StdErr {metric}"] = np.sqrt(
+                             std_err_mean_sum_of_squares / std_err_mean_num_subjects
+                         )
+                         aggregated_results[f"NumSamples {metric}"] = std_err_mean_total_num_samples
+                 else:
+                     # if there are no sub-groups to combine, calculate the SEM here directly
+                     key_subject_std = data_subset_error_free.groupby(["key", "subject"]).std()
+                     std = float(key_subject_std[["value"]].mean()["value"])
+                     num_samples = len(data_subset_error_free)
+                     if math.isnan(std) or num_samples == 0:
+                         aggregated_results[f"StdErr {metric}"] = None
+                     else:
+                         aggregated_results[f"StdErr {metric}"] = std / np.sqrt(num_samples)
+                         aggregated_results[f"NumSamples {metric}"] = num_samples
+
+         if (
+             "Average Bytes" in aggregated_results
+             and "Average SequencePositions" in aggregated_results
+             and aggregated_results["Average Bytes"]
+             and aggregated_results["Average SequencePositions"]
+         ):
+             aggregated_results["Average Bytes per Sequence Position"] = (
+                 aggregated_results["Average Bytes"] / aggregated_results["Average SequencePositions"]
+             )
+
+         return aggregated_results
+
+     def run_eval(self) -> list[Result]:
+         """Runs evaluation using saved completions."""
+         logger.info("Running evaluation...")
+         responses = self.result_processor.load_responses()
+         if not responses:
+             raise ValueError("No saved completions found. Run 'run_completions' first.")
+
+         metrics_results = self._run_metric_calculators(responses)
+         aggregated_results = self._aggregate_results(metrics_results)
+
+         wandb.log(aggregated_results)
+
+         self.result_processor.save_aggregated_results(aggregated_results)
+         logger.info(aggregated_results)
+         logger.info(f"{RED}[ Evaluation completed and results saved! ]{RESET}")
+         return metrics_results
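To make the aggregation in `_aggregate_results` concrete, here is a simplified, self-contained sketch (toy numbers, a single key per subject, no dependence on the framework's `Result` type) of the two-stage averaging and the quadrature-style combined standard error as the code computes it:

```python
import math

# Toy per-sample scores for one metric, keyed by subject (the "key" dimension is held constant).
scores = {
    "subject_a": [1.0, 0.0, 1.0, 1.0],
    "subject_b": [0.0, 0.0, 1.0],
}

# Stage 1: mean per subject, so every subject carries equal weight regardless of sample count.
subject_means = {s: sum(v) / len(v) for s, v in scores.items()}

# Stage 2: unweighted mean over the subject means (the reported "Average <metric>").
average = sum(subject_means.values()) / len(subject_means)

# Combined standard error in the style of the quadrature above:
# accumulate var_i / n_i per subject, then take sqrt(sum / number_of_subjects).
sum_of_squares = 0.0
for values in scores.values():
    mean = sum(values) / len(values)
    var = sum((x - mean) ** 2 for x in values) / (len(values) - 1)  # sample variance, like pandas .std()
    sum_of_squares += var / len(values)
std_err = math.sqrt(sum_of_squares / len(scores))

print(f"Average: {average:.3f}, StdErr: {std_err:.3f}")
```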
@@ -0,0 +1,2 @@
+ class LogicError(Exception):
+     pass
@@ -0,0 +1,5 @@
+ IFEval implementation here is taken 1:1 (besides some minuscule details, like adding Swedish to `LANGUAGE_CODES`)
+ from [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks/ifeval) as of Dec 19 2024.
+ The original repository is https://github.com/google-research/google-research/tree/master/instruction_following_eval .
+
+ This code doesn't have to be unit-tested.