eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
eval_framework/result_processors/result_processor.py ADDED
@@ -0,0 +1,129 @@
+ import hashlib
+ import importlib.metadata
+ import json
+ import logging
+ import os
+ from datetime import datetime
+ from pathlib import Path
+
+ import jsonlines
+ from pydantic import RootModel
+
+ from eval_framework.result_processors.base import Result, ResultProcessor
+ from eval_framework.shared.types import Completion, Loglikelihood
+ from eval_framework.tasks.eval_config import EvalConfig
+
+ MAIN = "eval_framework_results"
+ logger = logging.getLogger(__name__)
+
+
+ class ResultsFileProcessor(ResultProcessor):
+     def __init__(self, output_dir: Path) -> None:
+         self.output_dir = output_dir
+         os.makedirs(self.output_dir, exist_ok=True)
+
+     def save_metadata(self, metadata: dict) -> None:
+         with open(self.output_dir / "metadata.json", "w") as f:
+             json.dump(metadata, f, indent=2, sort_keys=True)
+
+     def load_metadata(self) -> dict:
+         metadata_file = self.output_dir / "metadata.json"
+         if os.path.exists(metadata_file):
+             with open(metadata_file) as f:
+                 return json.load(f)
+         else:
+             logger.info("No metadata found.")
+             return {}
+
+     def save_responses(self, responses: list[Completion | Loglikelihood]) -> None:
+         responses_data = [response.model_dump(mode="json", serialize_as_any=True) for response in responses]
+         with jsonlines.open(self.output_dir / "output.jsonl", "w") as f:
+             f.write_all(responses_data)
+
+     def save_response(self, response: Completion | Loglikelihood) -> None:
+         with jsonlines.open(self.output_dir / "output.jsonl", "a") as f:
+             f.write(response.model_dump(mode="json", serialize_as_any=True))
+
+     def load_responses(self) -> list[Completion | Loglikelihood]:
+         output_file = self.output_dir / "output.jsonl"
+         broken_file = output_file.with_suffix(f".jsonl.broken.{datetime.now().strftime('%Y%m%d_%H%M%S')}")
+
+         try:
+             Item = RootModel[Loglikelihood | Completion]
+             with jsonlines.open(output_file, "r") as f:
+                 responses = [Item.model_validate(x).root for x in f]
+         except FileNotFoundError:
+             logger.info("No saved completions found.")
+             responses = []
+         except (json.decoder.JSONDecodeError, jsonlines.jsonlines.InvalidLineError):
+             logger.info(f"Error decoding JSON, the file is corrupted. It will be renamed to {broken_file} and ignored")
+             output_file.rename(broken_file)
+             responses = []
+
+         ids_list = [(resp.id, resp.subject) for resp in responses]
+         if len(ids_list) != len(set(ids_list)) and "mtbench" not in str(output_file):
+             logger.info(
+                 f"Error: {len(ids_list) - len(set(ids_list))} duplicate response IDs found, the file is corrupted. "
+                 f"It will be renamed to {broken_file} and ignored"
+             )
+             output_file.rename(broken_file)
+             responses = []
+
+         return responses
+
+     def save_metrics_results(self, results: list[Result]) -> None:
+         result_data = [x.model_dump(mode="json") for x in results]
+         with jsonlines.open(self.output_dir / "results.jsonl", "w") as f:
+             f.write_all(result_data)
+
+     def save_metrics_result(self, result: Result) -> None:
+         with jsonlines.open(self.output_dir / "results.jsonl", "a") as f:
+             f.write(result.model_dump(mode="json"))
+
+     def save_aggregated_results(self, results: dict[str, float | None]) -> None:
+         with open(self.output_dir / "aggregated_results.json", "w") as f:
+             json.dump(results, f, indent=4, sort_keys=True)
+
+     def load_metrics_results(self) -> list[Result]:
+         results_file = self.output_dir / "results.jsonl"
+         try:
+             with jsonlines.open(results_file, "r") as f:
+                 result_data = [Result.model_validate(x) for x in f]
+             return result_data
+         except FileNotFoundError:
+             logger.info("No saved metrics found.")
+             return []
+
+
+ def generate_output_dir(llm_name: str, config: EvalConfig) -> Path:
+     # get the package version
+     version_str = f"v{importlib.metadata.version('eval_framework')}"
+
+     # Handle None values for num_fewshot and num_samples
+     fewshot_str = f"fewshot_{config.num_fewshot}" if config.num_fewshot is not None else "fewshot_None"
+     samples_str = f"samples_{config.num_samples}" if config.num_samples is not None else "samples_None"
+     tokens_str = f"tokens_{config.max_tokens}" if config.max_tokens is not None else ""
+
+     # Serialize key parameters for inclusion in the name
+     params_str = f"{fewshot_str}__{samples_str}"
+     if tokens_str:
+         params_str += f"__{tokens_str}"
+
+     # Serialize the full config for hashing
+     # Convert the config to a dict and sort keys to ensure consistent hashing
+     config_json = config.model_json_robust_subset_dump()
+     config_hash = hashlib.sha256(config_json.encode("utf-8")).hexdigest()[:5]  # Short hash of 5 characters
+
+     # Include the hash in the directory name
+     dir_name = f"{params_str}_{config_hash}"
+
+     # add timestamp to dir in debug mode
+     if os.environ.get("DEBUG", "FALSE").lower() == "true":
+         # Generate the timestamp
+         timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
+         dir_name += f"_{timestamp}"
+
+     # Combine all components to form the full output directory path
+     output_dir = config.output_dir / llm_name / f"{version_str}_{config.task_name}" / dir_name
+
+     return output_dir
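For orientation only (not part of the wheel), a minimal sketch of the directory-naming scheme that `generate_output_dir` implements, re-derived on assumed example values (the real function takes an `EvalConfig`):

    # Illustration of the naming scheme above; all concrete values are assumptions.
    import hashlib
    from pathlib import Path

    num_fewshot, num_samples, max_tokens = 5, 10, None   # assumed config values
    params_str = f"fewshot_{num_fewshot}__samples_{num_samples}"
    if max_tokens is not None:
        params_str += f"__tokens_{max_tokens}"
    config_hash = hashlib.sha256(b"<serialized config>").hexdigest()[:5]  # short 5-char hash
    print(Path("outputs") / "my-llm" / "v0.2.7_ARC" / f"{params_str}_{config_hash}")
    # -> outputs/my-llm/v0.2.7_ARC/fewshot_5__samples_10_<hash>
    # With DEBUG=true, a _YYYYMMDDTHHMMSS suffix is appended to the last component.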
eval_framework/result_processors/wandb_uploader.py ADDED
@@ -0,0 +1,137 @@
+ """
+ Module for writing result folder to a W&B artifact
+ """
+
+ import hashlib
+ import logging
+ import subprocess
+ from collections.abc import Callable
+ from pathlib import Path
+
+ import wandb
+ from wandb.sdk.artifacts._validators import NAME_MAXLEN
+
+ from eval_framework.result_processors.base import ResultsUploader
+ from eval_framework.tasks.eval_config import EvalConfig
+
+ logger = logging.getLogger(__name__)
+
+ ArtifactUploadFunction = Callable[[str, str, list[Path]], str | None]  # returns reference path for W&B or None
+ _ARTIFACT_UPLOAD_FUNCTION: ArtifactUploadFunction | None = None
+
+
+ def register_artifact_upload_function(func: ArtifactUploadFunction | None) -> None:
+     global _ARTIFACT_UPLOAD_FUNCTION
+     _ARTIFACT_UPLOAD_FUNCTION = func
+
+
+ def artifact_upload_function(artifact_name: str, subpath: str, file_paths: list[Path]) -> str | None:
+     if _ARTIFACT_UPLOAD_FUNCTION is None:
+         return None
+     logger.info(f"Uploading '{artifact_name}'.")
+     reference_path = _ARTIFACT_UPLOAD_FUNCTION(artifact_name, subpath, file_paths)
+     if reference_path is None:
+         logger.warning("Failed uploading, the custom upload function returned an empty destination path!")
+     else:
+         logger.info(f"Successfully uploaded to {reference_path}.")
+     return reference_path
+
+
+ class WandbUploader(ResultsUploader):
+     def __init__(
+         self,
+         config: EvalConfig,
+         include_all: bool = True,
+         compress_non_json: bool = True,
+         wandb_registry: str | None = None,
+     ) -> None:
+         if not config.wandb_upload_results:
+             logger.warning("Results will not be persisted in WandB (`wandb_upload_results` not set).")
+             return
+         if config.output_dir is None:
+             raise ValueError("Output directory is not set in the configuration.")
+         if wandb.run is None or wandb.run.settings._noop:
+             logger.warning("Results will not be persisted in WandB (no WandB run active / `wandb_project` not set).")
+             return
+
+         self._include_all = include_all
+         self._compress_non_json = compress_non_json
+         self._wandb_registry = wandb_registry
+
+     def upload(self, llm_name: str, config: EvalConfig, output_dir: Path) -> bool:
+         if hasattr(self, "_wandb_registry") is False:
+             return False  # not initialized
+
+         try:
+             if self._include_all and self._compress_non_json:
+                 # note: individual gz files can be easily read by `less` or `zgrep`, unlike a tar of multiple files
+                 subprocess.run(
+                     ["find", output_dir, "-type", "f", "!", "-name", "*.json", "-exec", "gzip", "-k", "{}", ";"],
+                     check=True,
+                 )
+                 file_paths = list(output_dir.glob("*.json")) + list(output_dir.glob("*.gz"))
+             elif self._include_all:
+                 file_paths = list(output_dir.glob("*"))
+             else:
+                 file_paths = list(output_dir.glob("*.json"))
+
+             artifact_name = self._get_artifact_name(llm_name, config)
+             alias_name = self._get_alias(output_dir)
+
+             try:
+                 rel_upload_dir = str(output_dir.relative_to(config.output_dir))
+                 reference_path = artifact_upload_function(artifact_name, rel_upload_dir, file_paths)
+             except Exception as e:
+                 logger.error(f"Problem during artifact upload function, aborting registration: {e}.")
+                 return False
+
+             artifact = wandb.Artifact(name=artifact_name, type="eval")  # note: metadata is added from run automatically
+             if reference_path:
+                 artifact.add_reference(reference_path, checksum=True)
+             else:
+                 logger.info("Uploading results directly to WandB.")
+                 for fp in file_paths:
+                     artifact.add_file(str(fp))
+
+             # Because metadata and e.g. logs are added to the artifact, we get a new version every time!
+             # To mitigate this, we also add an alias based on the content hash of the "pure" result files.
+             wandb.log_artifact(artifact, aliases=[alias_name])
+             if self._wandb_registry:
+                 artifact.link(f"wandb-registry-{self._wandb_registry}/{artifact_name}")
+             logger.info(f"Successfully registered '{artifact_name}' in WandB")
+
+         finally:
+             for fp in file_paths:
+                 if fp.suffix == ".gz" and self._compress_non_json:
+                     fp.unlink(missing_ok=True)
+
+         return True
+
+     def _get_artifact_name(self, llm_name: str, config: EvalConfig) -> str:
+         llm_name = llm_name.replace("/", "_")  # assuming this has class name and checkpoint name in it
+
+         # As in generate_output_dir() for consistency, though shorter.
+         # But we don't include the eval_framework version and timestamp here (-> handled via W&B versioning)
+         fewshot_str = f"fs{config.num_fewshot}" if config.num_fewshot is not None else ""
+         samples_str = f"s{config.num_samples}" if config.num_samples is not None else ""
+         tokens_str = f"t{config.max_tokens}" if config.max_tokens is not None else ""
+         params_str = f"{fewshot_str}{samples_str}{tokens_str}"
+
+         config_json = config.model_json_robust_subset_dump()
+         config_hash = hashlib.sha256(config_json.encode("utf-8")).hexdigest()[:5]
+
+         # Respect W&B artifact name length limit
+         eval_name = f"__{config.task_name}__{params_str}_{config_hash}"
+         max_llm_name_len = NAME_MAXLEN - len(eval_name)
+         return llm_name[:max_llm_name_len] + eval_name
+
+     def _get_alias(self, output_dir: Path) -> str:
+         digests = []
+         # These files don't contain result-irrelevant things such as timestamps or paths, and their fields are ordered.
+         # This makes them good for generating a hash that identifies the actual results and not "random" metadata.
+         for filename in ["aggregated_results.json", "output.jsonl", "results.jsonl"]:
+             if (output_dir / filename).exists():
+                 with open(output_dir / filename, "rb") as f:
+                     digests.append(hashlib.file_digest(f, "sha256").hexdigest())
+         hash = hashlib.sha256("".join(digests).encode("utf-8")).hexdigest()
+         return f"H-{hash[:10]}"
eval_framework/run.py ADDED
@@ -0,0 +1,369 @@
+ import argparse
+ import datetime
+ import logging
+ from pathlib import Path
+ from typing import Any
+
+ try:
+     from eval_framework.context.determined import DeterminedContext
+ except ImportError:
+     DeterminedContext = None  # type: ignore
+
+
+ from eval_framework.context.local import LocalContext
+ from eval_framework.main import main
+ from eval_framework.tasks.task_loader import load_extra_tasks
+ from eval_framework.utils.logging import setup_logging
+
+ logger = logging.getLogger(__name__)
+
+ CONTEXT = {
+     "local": LocalContext,
+     "determined": DeterminedContext,
+ }
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--context",
+         type=str,
+         required=False,
+         default="local",
+         choices=["local", "determined"],
+         help="The context in which the evaluation is run.",
+     )
+     parser.add_argument(
+         "--models",
+         type=Path,
+         required=False,
+         default=Path(__file__).parent / "llm" / "models.py",
+         help="The path to the Python module file containing model classes.",
+     )
+     parser.add_argument(
+         "--extra-task-modules",
+         nargs="*",
+         default=[],
+         required=False,
+         help="List of files and folders containing additional task definitions.",
+     )
+     parser.add_argument(
+         "--llm-name",
+         type=str,
+         required=False,
+         help=(
+             "Either a full import path for a model (e.g., `eval_framework.llm.huggingface.HFLLM`) or the "
+             "name of a class derived from `eval_framework.llm.base.BaseLLM` that can be found in the "
+             "models file. The resulting model is instantiated with the arguments provided via `--llm-args`."
+         ),
+     )
+     parser.add_argument(
+         "--llm-args",
+         type=str,
+         nargs="+",
+         default=(),
+         required=False,
+         help="The arguments to pass to the LLM as key=value pairs.",
+     )
+     parser.add_argument(
+         "--num-samples", type=int, required=False, help="The number of samples per subject to evaluate."
+     )
+     parser.add_argument(
+         "--max-tokens",
+         type=int,
+         required=False,
+         help="The maximum number of tokens to generate for each sample. Overwrites any task default value.",
+     )
+     parser.add_argument(
+         "--num-fewshot", type=int, required=False, default=0, help="The number of fewshot examples to use."
+     )
+     parser.add_argument("--task-name", type=str, required=False, help="The name of the task to evaluate.")
+     parser.add_argument(
+         "--randomize-judge-order",
+         action="store_true",
+         help="Randomize the order of answers presented to the LLM judge to mitigate position bias.",
+     )
+
+     # Perturbation arguments
+     parser.add_argument(
+         "--perturbation-type",
+         type=str,
+         required=False,
+         choices=[
+             "editor",
+             "permute",
+             "replace",
+             "delete",
+             "uppercase",
+         ],
+         help=(
+             "The type of perturbation to apply. Note that this may not make sense for some prompts, for example those "
+             "containing math and code."
+         ),
+     )
+     parser.add_argument(
+         "--perturbation-probability",
+         type=float,
+         required=False,
+         default=None,
+         help="The probability of applying a perturbation to each word or character (between 0.0 and 1.0).",
+     )
+     parser.add_argument(
+         "--perturbation-seed",
+         type=int,
+         required=False,
+         default=42,
+         help="Random seed controlling perturbations.",
+     )
+     parser.add_argument(
+         "--task-subjects",
+         type=str,
+         nargs="+",
+         required=False,
+         help=(
+             "The subjects of the task to evaluate. If empty, all subjects are evaluated. Subjects in the form of "
+             "tuples can be specified in a comma-delimited way, possibly using wildcard * in some dimensions of a "
+             "tuple, e.g., 'DE_DE, *' or 'FR_FR, astronomy'."
+         ),
+     )
+     parser.add_argument(
+         "--hf-revision",
+         type=str,
+         required=False,
+         default=None,
+         help="A tag name, a branch name, or commit hash for the task HF dataset.",
+     )
+     parser.add_argument(
+         "--judge-models",
+         type=Path,
+         required=False,
+         help="The path to the Python module file containing LLM judge model classes.",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=Path,
+         default="outputs",
+         required=False,
+         help="The path for the evaluation outputs.",
+     )
+     parser.add_argument(
+         "--hf-upload-repo",
+         type=str,
+         default="Aleph-Alpha/evaluation-results",
+         required=False,
+         help="Customizable path for the HuggingFace git repository where runs will be stored.",
+     )
+     parser.add_argument(
+         "--hf-upload-dir",
+         type=str,
+         default="",
+         required=False,
+         help="Folder name for the HuggingFace git repository where runs will be stored.",
+     )
+     parser.add_argument(
+         "--wandb-project",
+         type=str,
+         default=None,
+         required=False,
+         help=(
+             "The name of the Weights & Biases project to log runs to. "
+             "The environment variable WANDB_API_KEY must be set."
+         ),
+     )
+     parser.add_argument(
+         "--wandb-entity",
+         type=str,
+         default=None,
+         required=False,
+         help="The name of the Weights & Biases entity to log runs to. Defaults to the user's default entity.",
+     )
+     parser.add_argument(
+         "--wandb-run-id",
+         type=str,
+         default=None,
+         required=False,
+         help=(
+             "The ID of an existing Weights & Biases run to resume. "
+             "If not given, creates a new run. If given and exists, "
+             "will continue the run but will overwrite the Python command logged in wandb."
+         ),
+     )
+     parser.add_argument(
+         "--wandb-upload-results",
+         action=argparse.BooleanOptionalAction,
+         required=False,
+         default=True,
+         help=("Whether to upload results as an artifact to Weights & Biases (default: True). Needs `--wandb-project`."),
+     )
+     parser.add_argument(
+         "--description",
+         type=str,
+         required=False,
+         help="Description of the run. This will be added to the metadata of the run to help with bookkeeping.",
+     )
+     parser.add_argument(
+         "--batch-size",
+         type=int,
+         default=1,
+         required=False,
+         help=(
+             "Size of batch of samples to send to the LLM for evaluation in parallel. "
+             "Use 1 for sequential running (default)."
+         ),
+     )
+     parser.add_argument(
+         "--save-logs",
+         action="store_true",
+         default=True,
+         required=False,
+         help="Whether to save logs to a file in the output directory (default: True).",
+     )
+
+     parser.add_argument(
+         "--judge-model-name",
+         type=str,
+         required=False,
+         help=(
+             "Either a full import path for a judge (e.g., `eval_framework.llm.huggingface.HFLLM`) or the "
+             "name of a class derived from `eval_framework.llm.base.BaseLLM` that can be found in the "
+             "models file. The resulting judge model is instantiated with the arguments provided via "
+             "`--judge-model-args`."
+         ),
+     )
+     parser.add_argument(
+         "--judge-model-args",
+         type=str,
+         required=False,
+         nargs="+",
+         default=(),
+         help="The args of the judge model used.",
+     )
+     parser.add_argument(
+         "--resource-cleanup",
+         action="store_true",
+         required=False,
+         default=False,
+         help=("Add this flag to free up GPU resources between response generation and evaluation"),
+     )
+     parser.add_argument(
+         "--delete-output-dir-after-upload",
+         action="store_true",
+         required=False,
+         default=False,
+         help=("Add this flag to remove the output directory after a successful upload to HF or WandB."),
+     )
+     parser.add_argument(
+         "-v",
+         "--verbosity",
+         type=int,
+         nargs="?",
+         default=1,
+         choices=[0, 1, 2],
+         help="Set the logging verbosity level: 0=critical, 1=info, 2=debug",
+     )
+
+     llm_args: dict[str, Any] = {}
+     args = parser.parse_args()
+
+     for arg in args.llm_args:
+         if "=" in arg:
+             key, value = arg.split("=", 1)
+
+             # Handle nested keys like "sampling_params.temperature=0.7"
+             if "." in key:
+                 nested_key, sub_key = key.split(".", 1)
+                 if nested_key not in llm_args:
+                     llm_args[nested_key] = {}
+                 llm_args[nested_key][sub_key] = value
+             else:
+                 llm_args[key] = value
+
+     args.llm_args = llm_args
+
+     judge_model_args = {}
+     for arg in args.judge_model_args:
+         if "=" in arg:
+             key, value = arg.split("=", 1)
+             judge_model_args[key] = value
+
+     args.judge_model_args = judge_model_args
+
+     # if args.extra_task_modules:
+     #     # Convert the comma-separated string into a list
+     #     args.extra_task_modules = [file_or_dir.strip() for file_or_dir in args.extra_task_modules.split(",")]
+     # else:
+     #     args.extra_task_modules = None
+
+     return args
+
+
+ def run_with_kwargs(kwargs: dict) -> None:
+     # Setup logging for the output directory
+     output_dir = kwargs.get("output_dir", "results")
+     log_level = kwargs.get("verbosity", 1)
+     setup_logging(output_dir, log_level=log_level)
+
+     logger.info(kwargs)
+
+     now = datetime.datetime.now()
+     logger.info(f"starting time: {now}")
+
+     if kwargs.get("extra_task_modules"):
+         load_extra_tasks(kwargs["extra_task_modules"])
+
+     context_name = kwargs.pop("context")
+
+     context = CONTEXT[context_name](
+         llm_name=kwargs["llm_name"],
+         models_path=kwargs["models"],
+         num_samples=kwargs["num_samples"],
+         max_tokens=kwargs["max_tokens"],
+         num_fewshot=kwargs["num_fewshot"],
+         task_name=kwargs["task_name"],
+         task_subjects=kwargs["task_subjects"],
+         hf_revision=kwargs["hf_revision"],
+         output_dir=kwargs["output_dir"],
+         wandb_project=kwargs["wandb_project"],
+         wandb_entity=kwargs["wandb_entity"],
+         wandb_run_id=kwargs["wandb_run_id"],
+         wandb_upload_results=kwargs["wandb_upload_results"],
+         hf_upload_dir=kwargs["hf_upload_dir"],
+         hf_upload_repo=kwargs["hf_upload_repo"],
+         llm_args=kwargs["llm_args"],
+         judge_models_path=kwargs["judge_models"],
+         judge_model_name=kwargs["judge_model_name"],
+         judge_model_args=kwargs["judge_model_args"],
+         batch_size=kwargs["batch_size"],
+         description=kwargs["description"],
+         perturbation_type=kwargs["perturbation_type"],
+         perturbation_probability=kwargs["perturbation_probability"],
+         perturbation_seed=kwargs["perturbation_seed"],
+         randomize_judge_order=kwargs["randomize_judge_order"],
+         delete_output_dir_after_upload=kwargs["delete_output_dir_after_upload"],
+         # save_logs=kwargs["save_logs"],
+     )
+
+     with context as ctx:
+         if ctx.config is None:
+             raise ValueError(f"Context configuration is not set for '{type(ctx)}'.")
+
+         main(
+             llm=ctx.config.llm_class(**ctx.config.llm_args),
+             config=ctx.config,
+             should_preempt_callable=ctx.should_preempt,
+             trial_id=ctx.get_trial_id(),
+             resource_cleanup=kwargs.pop("resource_cleanup", False),
+             verbosity=log_level,
+         )
+
+     logger.info(f"time since start: {datetime.datetime.now() - now}")
+
+
+ def run() -> None:
+     run_with_kwargs(vars(parse_args()))
+
+
+ # Enable execution via `python -m eval_framework.run`. Useful for
+ # debugging via `debugpy -m eval_framework.run`
+ if __name__ == "__main__":
+     run()
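For illustration only (not part of the wheel), a short sketch of how the `--llm-args` parsing in `parse_args` above turns key=value pairs, including dotted keys, into a nested dict; the argument values are assumptions, and note that values remain strings:

    # Re-derivation of the --llm-args parsing with assumed example values.
    raw_args = ["model_name=my-org/my-model", "sampling_params.temperature=0.7"]
    llm_args: dict = {}
    for arg in raw_args:
        key, value = arg.split("=", 1)
        if "." in key:  # nested keys like "sampling_params.temperature"
            nested_key, sub_key = key.split(".", 1)
            llm_args.setdefault(nested_key, {})[sub_key] = value
        else:
            llm_args[key] = value
    print(llm_args)
    # {'model_name': 'my-org/my-model', 'sampling_params': {'temperature': '0.7'}}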
eval_framework/run_direct.py ADDED
@@ -0,0 +1,42 @@
+ import datetime
+ import logging
+ from pathlib import Path
+
+ from eval_framework.run import run_with_kwargs
+
+ logger = logging.getLogger(__name__)
+
+ if __name__ == "__main__":
+     logger.info(Path("models.py"))
+     now = datetime.datetime.now()
+     logger.info(f"starting time: {now}")
+     # insert API token here
+
+     # main block for local debugging
+     kwargs = {
+         "context": "local",
+         "models": Path("src/eval_framework/llm/models.py"),
+         "judge_models": None,
+         "judge_model_name": None,
+         "judge_model_args": {},
+         # ---
+         "llm_name": "Llama31_8B_Instruct_API",
+         "llm_args": {},
+         "num_samples": 1,
+         "max_tokens": None,
+         "num_fewshot": 4,
+         "task_name": "Math",  # complete task
+         "task_subjects": None,
+         "hf_revision": None,
+         "output_dir": Path("outputs"),
+         "hf_upload_dir": "",
+         "description": "",
+         "batch_size": 1,
+         "perturbation_type": None,
+         "perturbation_probability": None,
+         "perturbation_seed": None,
+         "save_logs": True,
+     }
+     run_with_kwargs(kwargs)
+
+     logger.info(f"time since start: {datetime.datetime.now() - now}")