eval-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +170 -0
- eval_framework/context/eval.py +114 -0
- eval_framework/context/local.py +52 -0
- eval_framework/evaluation_generator.py +231 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +323 -0
- eval_framework/llm/base.py +58 -0
- eval_framework/llm/huggingface.py +332 -0
- eval_framework/llm/mistral.py +73 -0
- eval_framework/llm/models.py +16 -0
- eval_framework/llm/openai.py +205 -0
- eval_framework/llm/vllm.py +438 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +187 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +171 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +8 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +416 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +74 -0
- eval_framework/result_processors/hf_processor.py +87 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/run.py +314 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +314 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/arc.py +46 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +39 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +62 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +177 -0
- eval_framework/tasks/benchmarks/gsm8k.py +148 -0
- eval_framework/tasks/benchmarks/hellaswag.py +44 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +190 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +37 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +39 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +44 -0
- eval_framework/tasks/benchmarks/sphyr.py +75 -0
- eval_framework/tasks/benchmarks/squad.py +89 -0
- eval_framework/tasks/benchmarks/struct_eval.py +110 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
- eval_framework/tasks/benchmarks/winogender.py +39 -0
- eval_framework/tasks/benchmarks/winogrande.py +44 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +112 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +80 -0
- eval_framework/tasks/task_names.py +138 -0
- eval_framework/tasks/utils.py +578 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/generate_task_docs.py +229 -0
- eval_framework/utils/helpers.py +3 -0
- eval_framework/utils/logging.py +50 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework-0.2.0.dist-info/METADATA +514 -0
- eval_framework-0.2.0.dist-info/RECORD +161 -0
- eval_framework-0.2.0.dist-info/WHEEL +4 -0
- eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +536 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
- template_formatting/tests/test_formatter_eval.py +408 -0
- template_formatting/tests/test_formatter_scaling.py +253 -0
- template_formatting/tests/test_mistral_formatter.py +136 -0
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import yaml # type: ignore[import-untyped]
|
|
6
|
+
from pydantic import BaseModel, ConfigDict
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseConfig(BaseModel):
    """Base class for immutable pydantic configs with YAML (de)serialization."""

    # Forbid unknown keys, freeze instances after construction, and allow
    # field names starting with "model_" (pydantic reserves that prefix).
    model_config = ConfigDict(extra="forbid", frozen=True, protected_namespaces=())

    def as_dict(self) -> dict[str, Any]:
        """Return the config as a plain dict with only primitive values.

        Paths are converted to strings and Enums to their values so the result
        can be serialized (e.g. to YAML/JSON) without custom encoders.
        """

        def simplify_recursive(obj: Any) -> Any:
            # Walk nested dicts/lists and normalize leaf values.
            if isinstance(obj, dict):
                return {key: simplify_recursive(value) for key, value in obj.items()}
            elif isinstance(obj, list):
                return [simplify_recursive(item) for item in obj]
            elif isinstance(obj, Path):
                return str(obj)
            elif isinstance(obj, Enum):
                return obj.value
            else:
                return obj

        return simplify_recursive(self.model_dump())

    @classmethod
    def from_yaml(cls, yml_filename: str | Path) -> "BaseConfig":
        """Load a config instance from the YAML file *yml_filename*.

        Uses ``yaml.safe_load`` instead of ``yaml.load(..., FullLoader)``:
        config files are external input and must not be able to construct
        arbitrary Python objects during parsing.
        """
        with open(yml_filename) as conf_file:
            config_dict = yaml.safe_load(conf_file)

        return cls(**config_dict)

    def save(self, out_file: Path) -> None:
        """Serialize the config to *out_file* as YAML (JSON-compatible values only)."""
        with open(out_file, "w", encoding="UTF-8") as f:
            yaml.safe_dump(self.model_dump(mode="json"), f)
|
|
File without changes
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Annotated, Any
|
|
4
|
+
|
|
5
|
+
from determined._info import get_cluster_info
|
|
6
|
+
from determined.core._context import Context
|
|
7
|
+
from determined.core._context import init as determined_core_init
|
|
8
|
+
from determined.core._distributed import DummyDistributedContext
|
|
9
|
+
from pydantic import AfterValidator, BaseModel, ConfigDict
|
|
10
|
+
|
|
11
|
+
from eval_framework.context.eval import EvalContext, import_models
|
|
12
|
+
from eval_framework.llm.base import BaseLLM
|
|
13
|
+
from eval_framework.tasks.eval_config import EvalConfig
|
|
14
|
+
from eval_framework.tasks.perturbation import PerturbationConfig
|
|
15
|
+
from eval_framework.tasks.registry import validate_task_name
|
|
16
|
+
from eval_framework.tasks.task_loader import load_extra_tasks
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TaskArgs(BaseModel):
    """Task-level settings supplied via the Determined experiment hyperparameters."""

    model_config = ConfigDict(extra="forbid")
    # Validated against the task registry at parse time via validate_task_name.
    task_name: Annotated[str, AfterValidator(validate_task_name)]
    num_fewshot: int
    num_samples: int | None = None
    max_tokens: int | None = None
    batch_size: int | None = None
    judge_model_name: str | None = None
    # NOTE: pydantic deep-copies field defaults, so the mutable {} default is safe here.
    judge_model_args: dict[str, Any] = {}
    task_subjects: list[str] | None = None
    hf_revision: str | None = None
    perturbation_config: PerturbationConfig | None = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Hyperparameters(BaseModel):
    """Top-level Determined trial hyperparameters for an evaluation run."""

    model_config = ConfigDict(extra="forbid")
    llm_name: str
    output_dir: Path
    hf_upload_dir: str | None = None
    hf_upload_repo: str | None = None
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_run_id: str | None = None
    description: str | None = None
    # Nested task-specific settings (see TaskArgs).
    task_args: TaskArgs
    # NOTE: pydantic deep-copies field defaults, so the mutable {} default is safe here.
    llm_args: dict[str, Any] | None = {}
    # Extra task modules are consumed before this model is constructed
    # (see DeterminedContext.__enter__) but must still be declared because
    # extra="forbid" would otherwise reject the key.
    extra_task_modules: list[str] | None = None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class DeterminedContext(EvalContext):
    """Eval context for runs managed by a Determined cluster.

    On ``__enter__`` it starts a Determined core context, merges the trial's
    hyperparameters with the CLI arguments (hyperparameters win), resolves the
    model and judge classes, and builds ``self.config``.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Created lazily in __enter__, released in __exit__.
        self._core_context: Context | None = None

    def _warn_cli_overrides(self, source: Any, names: list[str]) -> None:
        """Log every CLI argument in *names* that a differing hyperparameter on *source* overrides."""
        for name in names:
            val_cli = getattr(self, name, None)
            val_hparams = getattr(source, name, None)
            if val_cli and val_hparams and val_cli != val_hparams:
                logger.info(f"CLI argument {name} ({val_cli}) is being overridden by hyperparameters: ({val_hparams}).")

    def __enter__(self) -> "DeterminedContext":
        distributed_context = DummyDistributedContext()
        self._core_context = determined_core_init(distributed=distributed_context)
        self._core_context.start()
        info = get_cluster_info()

        if info is None:
            raise RuntimeError("Failed to retrieve cluster info.")

        # Load extra task modules first so that task-name validation inside
        # Hyperparameters(...) below can see the newly registered tasks.
        name = "extra_task_modules"
        val_hparams = info.trial.hparams.get(name, None)
        if val_hparams:
            val_cli = getattr(self, name, None)
            if val_cli and val_cli != val_hparams:
                logger.info(
                    f"CLI argument {name} ({val_cli}) is being overridden by hyperparameters:"
                    f"({val_hparams}). If it fails due to duplicate task names, remove the CLI argument and"
                    "consolidate as a determined hyperparameter instead."
                )
            load_extra_tasks(val_hparams)

        self.hparams = Hyperparameters(**info.trial.hparams)

        # Warn about every CLI argument that the hyperparameters silently win over.
        self._warn_cli_overrides(
            self.hparams,
            [
                "llm_name",
                "llm_args",
                "output_dir",
                "hf_upload_dir",
                "hf_upload_repo",
                "wandb_project",
                "wandb_entity",
                "wandb_run_id",
                "description",
            ],
        )
        self._warn_cli_overrides(
            self.hparams.task_args,
            [
                "num_samples",
                "max_tokens",
                "num_fewshot",
                "task_name",
                "task_subjects",
                "batch_size",
                "hf_revision",
                "judge_model_name",
                "judge_model_args",
                "perturbation_config",
            ],
        )

        models = import_models(self.models_path)
        if self.hparams.llm_name not in models:
            raise ValueError(f"LLM '{self.hparams.llm_name}' not found.")
        llm_class = models[self.hparams.llm_name]

        llm_judge_class: type[BaseLLM] | None = None
        judge_model_name = self.hparams.task_args.judge_model_name or self.judge_model_name
        if self.judge_models_path is not None and judge_model_name is not None:
            judge_models = import_models(self.judge_models_path)
            if judge_model_name not in judge_models:
                raise ValueError(f"LLM judge '{judge_model_name}' not found.")
            llm_judge_class = judge_models[judge_model_name]

        # for all optional hyperparameters, resort to the respective CLI argument if the hyperparameter is not set
        self.config = EvalConfig(
            llm_class=llm_class,
            llm_args=self.hparams.llm_args or self.llm_args,
            num_samples=self.hparams.task_args.num_samples or self.num_samples,
            max_tokens=self.hparams.task_args.max_tokens or self.max_tokens,
            num_fewshot=self.hparams.task_args.num_fewshot,
            task_name=self.hparams.task_args.task_name,
            task_subjects=self.hparams.task_args.task_subjects,
            hf_revision=self.hparams.task_args.hf_revision or self.hf_revision,
            perturbation_config=self.hparams.task_args.perturbation_config or self.perturbation_config,
            output_dir=self.hparams.output_dir,
            llm_judge_class=llm_judge_class,
            judge_model_args=self.hparams.task_args.judge_model_args or self.judge_model_args,
            hf_upload_dir=self.hparams.hf_upload_dir or self.hf_upload_dir,
            hf_upload_repo=self.hparams.hf_upload_repo or self.hf_upload_repo,
            wandb_project=self.hparams.wandb_project or self.wandb_project,
            wandb_entity=self.hparams.wandb_entity or self.wandb_entity,
            wandb_run_id=self.hparams.wandb_run_id or self.wandb_run_id,
            batch_size=self.hparams.task_args.batch_size or self.batch_size,
            description=self.hparams.description or self.description,
        )

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any | None,
    ) -> None:
        # Always release the Determined core context, also on error.
        if self._core_context is not None:
            self._core_context.close()
            self._core_context = None

    def should_preempt(self) -> bool:
        """Return True when the cluster has asked this trial to preempt."""
        if self._core_context is None:
            return False
        return self._core_context.preempt.should_preempt()

    def get_trial_id(self) -> int | None:
        """Return the Determined trial id, or None before __enter__ / after __exit__."""
        if self._core_context is None:
            return None
        return self._core_context.train._trial_id
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import importlib.util
|
|
2
|
+
import inspect
|
|
3
|
+
import sys
|
|
4
|
+
from contextlib import AbstractContextManager
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import eval_framework
|
|
9
|
+
from eval_framework.llm.base import BaseLLM
|
|
10
|
+
from eval_framework.tasks.eval_config import EvalConfig
|
|
11
|
+
from eval_framework.tasks.perturbation import PerturbationConfig
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def import_models(models_file: Path | str) -> dict[str, type[BaseLLM]]:
    """Import *models_file* and return every BaseLLM subclass it defines, keyed by class name."""
    resolved = Path(models_file).resolve()
    package_root = Path(eval_framework.__path__[0]).resolve()

    # Files living inside the installed eval_framework package are imported by
    # their dotted module name rather than by file path, so the module is not
    # loaded a second time under a different module object.
    if resolved.is_relative_to(package_root):
        dotted_name = ".".join(resolved.relative_to(package_root.parent).with_suffix("").parts)
        module = importlib.import_module(dotted_name)
    else:
        stem_name = resolved.stem

        spec = importlib.util.spec_from_file_location(stem_name, str(resolved))
        if spec is None:
            raise ImportError(f"Could not load module '{resolved}'.")

        module = importlib.util.module_from_spec(spec)
        # Register before executing so imports inside the module can resolve it.
        sys.modules[stem_name] = module

        if spec.loader is None:
            raise ImportError(f"Could not load module '{resolved}'.")

        spec.loader.exec_module(module)

    # Collect concrete model classes; the abstract BaseLLM itself is excluded.
    return {
        cls_name: cls_obj
        for cls_name, cls_obj in inspect.getmembers(module, inspect.isclass)
        if issubclass(cls_obj, BaseLLM) and cls_obj is not BaseLLM
    }
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class EvalContext(AbstractContextManager):
    """Container for all CLI-provided evaluation settings.

    Subclasses (e.g. local or cluster contexts) build ``self.config``
    (an ``EvalConfig``) in their ``__enter__``.
    """

    def __init__(
        self,
        llm_name: str,
        models_path: Path,
        num_samples: int | None = None,
        max_tokens: int | None = None,
        num_fewshot: int | None = None,
        task_name: str | None = None,
        task_subjects: list[str] | None = None,
        hf_revision: str | None = None,
        output_dir: Path | None = None,
        wandb_project: str | None = None,
        wandb_entity: str | None = None,
        wandb_run_id: str | None = None,
        hf_upload_dir: str | None = None,
        hf_upload_repo: str | None = None,
        llm_args: dict[str, Any] | None = None,
        judge_models_path: Path | None = None,
        judge_model_name: str | None = None,
        judge_model_args: dict[str, Any] | None = None,
        batch_size: int | None = None,
        description: str | None = None,
        perturbation_type: str | None = None,
        perturbation_probability: float | None = None,
        perturbation_seed: int | None = None,
    ) -> None:
        # Plain attribute copies of the CLI arguments; subclasses may combine
        # these with other sources before assembling self.config.
        self.llm_name = llm_name
        self.models_path = models_path
        self.num_samples = num_samples
        self.max_tokens = max_tokens
        self.num_fewshot = num_fewshot
        self.task_name = task_name
        self.task_subjects = task_subjects
        self.hf_revision = hf_revision
        self.output_dir = output_dir
        self.wandb_project = wandb_project
        self.wandb_entity = wandb_entity
        self.wandb_run_id = wandb_run_id
        self.hf_upload_dir = hf_upload_dir
        self.hf_upload_repo = hf_upload_repo
        self.llm_args = llm_args
        self.judge_models_path = judge_models_path
        self.judge_model_name = judge_model_name
        self.judge_model_args = judge_model_args
        self.batch_size = batch_size
        self.description = description

        # Build a PerturbationConfig only when a type or probability was given.
        # NOTE(review): by operator precedence this reads as
        # `perturbation_type or (perturbation_probability is not None)`; a
        # perturbation_seed passed on its own is silently dropped — confirm intended.
        if perturbation_type or perturbation_probability is not None:
            perturbation = {
                "type": perturbation_type,
                "probability": perturbation_probability,
                "seed": perturbation_seed,
            }
            # Drop unset keys so PerturbationConfig defaults apply.
            self.perturbation_config: PerturbationConfig | None = PerturbationConfig(
                **{k: v for k, v in perturbation.items() if v is not None}
            )
        else:
            self.perturbation_config = None

        # Populated by subclasses in __enter__.
        self.config: EvalConfig | None = None

    def should_preempt(self) -> bool:
        # Base contexts are never preempted; cluster-aware subclasses override.
        return False

    def get_trial_id(self) -> int | None:
        # No trial id outside of a cluster-managed run.
        return None
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from eval_framework.context.eval import EvalContext, import_models
|
|
4
|
+
from eval_framework.llm.base import BaseLLM
|
|
5
|
+
from eval_framework.tasks.eval_config import EvalConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class LocalContext(EvalContext):
    """Eval context for plain local runs.

    ``__enter__`` resolves the model (and optional judge) class from the given
    module paths and assembles ``self.config`` directly from the CLI arguments.
    """

    def __enter__(self) -> "LocalContext":
        available_models = import_models(self.models_path)
        if self.llm_name not in available_models:
            raise ValueError(f"LLM '{self.llm_name}' not found.")

        # Resolve the judge only when both a judge module path and a name were given.
        self.llm_judge_class: type[BaseLLM] | None = None
        if self.judge_models_path is not None and self.judge_model_name is not None:
            available_judges = import_models(self.judge_models_path)
            if self.judge_model_name not in available_judges:
                raise ValueError(f"LLM judge '{self.judge_model_name}' not found.")
            self.llm_judge_class = available_judges[self.judge_model_name]

        # Locally, the CLI arguments map one-to-one onto the eval configuration.
        self.config = EvalConfig(
            llm_class=available_models[self.llm_name],
            llm_args=self.llm_args,
            num_samples=self.num_samples,
            max_tokens=self.max_tokens,
            num_fewshot=self.num_fewshot,
            perturbation_config=self.perturbation_config,
            task_name=self.task_name,
            task_subjects=self.task_subjects,
            hf_revision=self.hf_revision,
            output_dir=self.output_dir,
            hf_upload_dir=self.hf_upload_dir,
            hf_upload_repo=self.hf_upload_repo,
            wandb_entity=self.wandb_entity,
            wandb_project=self.wandb_project,
            wandb_run_id=self.wandb_run_id,
            llm_judge_class=self.llm_judge_class,
            judge_model_args=self.judge_model_args,
            batch_size=self.batch_size,
            description=self.description,
        )

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: Any | None,
    ) -> None:
        # Nothing to release for local runs.
        pass
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import math
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import wandb
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from eval_framework.metrics.base import BaseMetric
|
|
10
|
+
from eval_framework.metrics.efficiency.bytes_per_sequence_position import (
|
|
11
|
+
BytesCompletion,
|
|
12
|
+
BytesLoglikelihood,
|
|
13
|
+
SequencePositionsCompletion,
|
|
14
|
+
SequencePositionsLoglikelihood,
|
|
15
|
+
)
|
|
16
|
+
from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
|
|
17
|
+
from eval_framework.result_processors.base import Result, ResultProcessor
|
|
18
|
+
from eval_framework.shared.types import Completion, Loglikelihood
|
|
19
|
+
from eval_framework.tasks.base import ResponseType
|
|
20
|
+
from eval_framework.tasks.eval_config import EvalConfig
|
|
21
|
+
from eval_framework.tasks.registry import get_task
|
|
22
|
+
from eval_framework.utils.constants import RED, RESET
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class EvaluationGenerator:
|
|
28
|
+
def __init__(self, config: EvalConfig, result_processor: ResultProcessor) -> None:
|
|
29
|
+
logger.info("EvaluationGenerator initialized")
|
|
30
|
+
|
|
31
|
+
self.few_shot = config.num_fewshot
|
|
32
|
+
self.config = config
|
|
33
|
+
self.num_samples = config.num_samples
|
|
34
|
+
self.max_tokens = config.max_tokens
|
|
35
|
+
self.result_processor = result_processor
|
|
36
|
+
self.save_intermediate_results = config.save_intermediate_results
|
|
37
|
+
|
|
38
|
+
task_class = get_task(config.task_name)
|
|
39
|
+
if task_class.RESPONSE_TYPE == ResponseType.COMPLETION:
|
|
40
|
+
self.metrics = task_class.METRICS + [BytesCompletion, SequencePositionsCompletion]
|
|
41
|
+
elif task_class.RESPONSE_TYPE == ResponseType.LOGLIKELIHOODS:
|
|
42
|
+
self.metrics = task_class.METRICS + [BytesLoglikelihood, SequencePositionsLoglikelihood]
|
|
43
|
+
else:
|
|
44
|
+
raise NotImplementedError
|
|
45
|
+
|
|
46
|
+
self.task_name = task_class.NAME
|
|
47
|
+
|
|
48
|
+
    def _run_metric_calculators(self, responses: list[Completion | Loglikelihood]) -> list[Result]:
        """Compute all configured metrics for *responses* and return the full result list.

        Previously persisted results are loaded first and skipped, so an
        interrupted run can be resumed without recomputing metrics.
        """
        results: list[Result] = self.result_processor.load_metrics_results()
        llm_name = self.result_processor.load_metadata()["llm_name"]

        # Keys of results already on disk, used to skip recomputation below.
        subject_result_id_existing = set()
        for result in results:
            subject_result_id_existing.add(f"{result.subject}_{result.id}_{result.metric_class_name}")

        """
        we have three dimensions: subject, metric, sample_id
        we wanna average over sample_id
        and also over all subjects by averaging over the averages
        dict[metric, dict[subject, dict[sample_id, list[result]]]]
        """
        # The judge LLM is expensive to construct, so it is created lazily once
        # and shared by every judge-based metric.
        llm_judge = None
        for metric_class in self.metrics:
            metric: BaseMetric
            if issubclass(metric_class, BaseLLMJudgeMetric):
                if llm_judge is None:
                    assert self.config.llm_judge_class is not None, "The llm_judge_class must be defined in the config."
                    llm_judge = self.config.llm_judge_class(**self.config.judge_model_args)
                metric = metric_class(llm_judge=llm_judge)
            else:
                metric = metric_class()

            logger.info(f"Starting calculation of {metric.NAME}")
            tqdm.write(f"INFO: Calculating {metric.NAME}")
            for response in tqdm(responses, desc=f"Calculating {metric.NAME}"):
                # Skip responses whose result for this metric already exists.
                if f"{response.subject}_{response.id}_{metric.__class__.__name__}" in subject_result_id_existing:
                    continue

                subject = response.subject
                metric_results = metric.calculate(response)
                for metric_result in metric_results:
                    # Metric names of the form "name/key" carry a sub-key.
                    # NOTE(review): split("/") raises if a name ever contains
                    # more than one slash — confirm names are single-slash only.
                    if "/" in metric_result.metric_name:
                        metric_name, key = metric_result.metric_name.split("/")
                    else:
                        metric_name = metric_result.metric_name
                        key = None
                    # Loglikelihood responses have no completion text; record
                    # the stringified ground truth in its place.
                    completion = response.completion if isinstance(response, Completion) else str(response.ground_truth)

                    result = Result(
                        id=response.id,
                        metric_class_name=metric.__class__.__name__,
                        metric_name=metric_name,
                        num_fewshot=self.few_shot,
                        key=key,
                        subject=subject,
                        llm_name=llm_name,
                        task_name=self.task_name,
                        value=metric_result.value,
                        higher_is_better=metric_result.higher_is_better,
                        prompt=response.prompt,
                        response=completion,
                        llm_judge_prompt=metric_result.llm_judge_prompt,
                        llm_judge_response=metric_result.llm_judge_response,
                        code_execution_trace=metric_result.code_execution_trace,
                        error=metric_result.error,
                    )
                    results.append(result)
                    if self.save_intermediate_results:
                        self.result_processor.save_metrics_result(result)

            logger.info(f"Completed calculation of {metric.NAME}")
            tqdm.write(f"INFO: Completed {metric.NAME}")

        # When not saving incrementally, persist everything in one batch at the end.
        if not self.save_intermediate_results:
            self.result_processor.save_metrics_results(results)
        return results
|
|
117
|
+
|
|
118
|
+
def _aggregate_results(self, results: list[Result]) -> dict[str, float | None]:
|
|
119
|
+
data = pd.DataFrame([r.model_dump() for r in results])
|
|
120
|
+
if len(data) == 0:
|
|
121
|
+
return {}
|
|
122
|
+
data.fillna({"key": ""}, inplace=True)
|
|
123
|
+
metrics = sorted(data["metric_name"].unique())
|
|
124
|
+
aggregated_results: dict[str, float | None] = {}
|
|
125
|
+
|
|
126
|
+
for metric in metrics:
|
|
127
|
+
# filter for metric
|
|
128
|
+
data_subset = data[data["metric_name"] == metric][["subject", "key", "value", "error"]]
|
|
129
|
+
|
|
130
|
+
# filter and count errors
|
|
131
|
+
total_count = len(data_subset)
|
|
132
|
+
|
|
133
|
+
mask = data["error"].isnull()
|
|
134
|
+
data_subset_error_free = data_subset.loc[mask, ["subject", "key", "value"]]
|
|
135
|
+
# data_subset_error_free = data_subset[data_subset["error"].isnull()][["subject", "key", "value"]]
|
|
136
|
+
|
|
137
|
+
aggregated_results[f"ErrorFreeRatio {metric}"] = float(len(data_subset_error_free) / total_count)
|
|
138
|
+
|
|
139
|
+
# aggregate by key and subject first to have equal weights for all key / subject combinations
|
|
140
|
+
key_subject_mean = data_subset_error_free.groupby(["key", "subject"]).mean()
|
|
141
|
+
aggregated_results[f"Average {metric}"] = float(key_subject_mean[["value"]].mean()["value"])
|
|
142
|
+
|
|
143
|
+
std_err_mean_sum_of_squares = 0.0
|
|
144
|
+
std_err_mean_total_num_samples = 0.0
|
|
145
|
+
std_err_mean_num_subjects = 0
|
|
146
|
+
|
|
147
|
+
for column in ["key", "subject"]:
|
|
148
|
+
if len(data_subset[column].unique()) > 1:
|
|
149
|
+
for name, _group in key_subject_mean.groupby([column]):
|
|
150
|
+
mask = data_subset[column] == name[0]
|
|
151
|
+
group = data_subset.loc[mask, ["subject", "key", "value", "error"]]
|
|
152
|
+
# group = data_subset[data[column] == name][["subject", "key", "value", "error"]]
|
|
153
|
+
group_total_count = len(group)
|
|
154
|
+
group_error_free = group[group["error"].isnull()][["subject", "key", "value"]]
|
|
155
|
+
aggregated_results[f"ErrorFreeRatio {metric} - {name[0]}"] = float(
|
|
156
|
+
len(group_error_free) / group_total_count
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
group_key_subject_mean = group_error_free.groupby(["key", "subject"]).mean()
|
|
160
|
+
value = float(group_key_subject_mean[["value"]].mean()["value"])
|
|
161
|
+
aggregated_results[f"Average {metric} - {name[0]}"] = value if not math.isnan(value) else None
|
|
162
|
+
|
|
163
|
+
if not ("SequencePositions" in metric or "Bytes" in metric):
|
|
164
|
+
# calculate standard error for selected metrics
|
|
165
|
+
group_key_subject_std = group_error_free.groupby(["key", "subject"]).std()
|
|
166
|
+
std = float(group_key_subject_std[["value"]].mean()["value"])
|
|
167
|
+
num_samples = len(group_error_free)
|
|
168
|
+
|
|
169
|
+
if math.isnan(std) or num_samples == 0:
|
|
170
|
+
aggregated_results[f"StdErr {metric} - {name[0]}"] = None
|
|
171
|
+
else:
|
|
172
|
+
aggregated_results[f"StdErr {metric} - {name[0]}"] = std / np.sqrt(num_samples)
|
|
173
|
+
aggregated_results[f"NumSamples {metric} - {name[0]}"] = num_samples
|
|
174
|
+
|
|
175
|
+
std_err_mean_sum_of_squares += std**2 / num_samples
|
|
176
|
+
std_err_mean_total_num_samples += num_samples
|
|
177
|
+
std_err_mean_num_subjects += 1
|
|
178
|
+
|
|
179
|
+
if not ("SequencePositions" in metric or "Bytes" in metric):
|
|
180
|
+
# calculate standard error for selected metrics
|
|
181
|
+
if std_err_mean_total_num_samples > 0:
|
|
182
|
+
# calculate the standard error of the mean (SEM) for the aggregated results (eg. add in quadrature)
|
|
183
|
+
# SEM = sqrt(sum(variance_i * n_i) / i)
|
|
184
|
+
# where variance_i is the variance of each group and i is the number of groups
|
|
185
|
+
# (the combined mean is also not weighted by the number of samples)
|
|
186
|
+
if math.isnan(std) or std_err_mean_total_num_samples == 0:
|
|
187
|
+
aggregated_results[f"StdErr {metric}"] = None
|
|
188
|
+
else:
|
|
189
|
+
aggregated_results[f"StdErr {metric}"] = np.sqrt(
|
|
190
|
+
std_err_mean_sum_of_squares / std_err_mean_num_subjects
|
|
191
|
+
)
|
|
192
|
+
aggregated_results[f"NumSamples {metric}"] = std_err_mean_total_num_samples
|
|
193
|
+
else:
|
|
194
|
+
# if there are no sub-groups to combine, calculate the SEM here directly
|
|
195
|
+
key_subject_std = data_subset_error_free.groupby(["key", "subject"]).std()
|
|
196
|
+
std = float(key_subject_std[["value"]].mean()["value"])
|
|
197
|
+
num_samples = len(data_subset_error_free)
|
|
198
|
+
if math.isnan(std) or num_samples == 0:
|
|
199
|
+
aggregated_results[f"StdErr {metric}"] = None
|
|
200
|
+
else:
|
|
201
|
+
aggregated_results[f"StdErr {metric}"] = std / np.sqrt(num_samples)
|
|
202
|
+
aggregated_results[f"NumSamples {metric}"] = num_samples
|
|
203
|
+
|
|
204
|
+
if (
|
|
205
|
+
"Average Bytes" in aggregated_results
|
|
206
|
+
and "Average SequencePositions" in aggregated_results
|
|
207
|
+
and aggregated_results["Average Bytes"]
|
|
208
|
+
and aggregated_results["Average SequencePositions"]
|
|
209
|
+
):
|
|
210
|
+
aggregated_results["Average Bytes per Sequence Position"] = (
|
|
211
|
+
aggregated_results["Average Bytes"] / aggregated_results["Average SequencePositions"]
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
return aggregated_results
|
|
215
|
+
|
|
216
|
+
def run_eval(self) -> list[Result]:
    """Score previously saved completions and persist the aggregates.

    Loads the responses stored by ``run_completions``, runs all metric
    calculators over them, aggregates the per-sample results, and writes
    both the aggregates and a wandb log entry.

    Returns:
        The per-sample metric results.

    Raises:
        ValueError: If no saved completions are found.
    """
    logger.info("Running evaluation...")
    stored_responses = self.result_processor.load_responses()
    if not stored_responses:
        raise ValueError("No saved completions found. Run 'run_completions' first.")

    # Per-sample scoring first, then the roll-up into summary statistics.
    per_sample_results = self._run_metric_calculators(stored_responses)
    summary = self._aggregate_results(per_sample_results)

    wandb.log(summary)

    self.result_processor.save_aggregated_results(summary)
    logger.info(summary)
    logger.info(f"{RED}[ Evaluation completed and results saved! ]{RESET}")
    return per_sample_results
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
The IFEval implementation here is taken 1:1 (besides some minuscule details, such as adding Swedish to `LANGUAGE_CODES`)
|
|
2
|
+
from [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/lm_eval/tasks/ifeval) as of Dec 19 2024.
|
|
3
|
+
The original repository is https://github.com/google-research/google-research/tree/master/instruction_following_eval .
|
|
4
|
+
|
|
5
|
+
As vendored third-party code, this does not need to be unit-tested here.
|