eval_framework-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +170 -0
  5. eval_framework/context/eval.py +114 -0
  6. eval_framework/context/local.py +52 -0
  7. eval_framework/evaluation_generator.py +231 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +323 -0
  16. eval_framework/llm/base.py +58 -0
  17. eval_framework/llm/huggingface.py +332 -0
  18. eval_framework/llm/mistral.py +73 -0
  19. eval_framework/llm/models.py +16 -0
  20. eval_framework/llm/openai.py +205 -0
  21. eval_framework/llm/vllm.py +438 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +187 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/bleu.py +76 -0
  29. eval_framework/metrics/completion/chrf.py +62 -0
  30. eval_framework/metrics/completion/code_assertion.py +44 -0
  31. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  32. eval_framework/metrics/completion/comet.py +56 -0
  33. eval_framework/metrics/completion/concordance_index.py +38 -0
  34. eval_framework/metrics/completion/csv_format.py +102 -0
  35. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  36. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  37. eval_framework/metrics/completion/f1.py +42 -0
  38. eval_framework/metrics/completion/format_checker.py +56 -0
  39. eval_framework/metrics/completion/grid_difference.py +77 -0
  40. eval_framework/metrics/completion/ifeval.py +73 -0
  41. eval_framework/metrics/completion/json_format.py +171 -0
  42. eval_framework/metrics/completion/language_checker.py +74 -0
  43. eval_framework/metrics/completion/length_control.py +83 -0
  44. eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
  45. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  46. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  47. eval_framework/metrics/completion/repetition.py +88 -0
  48. eval_framework/metrics/completion/rouge_1.py +35 -0
  49. eval_framework/metrics/completion/rouge_2.py +45 -0
  50. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  51. eval_framework/metrics/completion/rouge_l.py +52 -0
  52. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  53. eval_framework/metrics/completion/ter.py +67 -0
  54. eval_framework/metrics/completion/text_counter.py +182 -0
  55. eval_framework/metrics/efficiency/__init__.py +0 -0
  56. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  57. eval_framework/metrics/llm/__init__.py +0 -0
  58. eval_framework/metrics/llm/base.py +8 -0
  59. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  60. eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
  61. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  62. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  63. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  64. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  65. eval_framework/metrics/llm/graders/language.py +56 -0
  66. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  67. eval_framework/metrics/llm/graders/models.py +74 -0
  68. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  69. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  70. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  71. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  72. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  73. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  74. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  75. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  76. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  77. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
  78. eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
  79. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  80. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  81. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  82. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  83. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  84. eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
  85. eval_framework/py.typed +0 -0
  86. eval_framework/response_generator.py +416 -0
  87. eval_framework/result_processors/__init__.py +0 -0
  88. eval_framework/result_processors/base.py +74 -0
  89. eval_framework/result_processors/hf_processor.py +87 -0
  90. eval_framework/result_processors/result_processor.py +129 -0
  91. eval_framework/run.py +314 -0
  92. eval_framework/run_direct.py +42 -0
  93. eval_framework/shared/types.py +227 -0
  94. eval_framework/tasks/__init__.py +6 -0
  95. eval_framework/tasks/base.py +314 -0
  96. eval_framework/tasks/benchmarks/__init__.py +0 -0
  97. eval_framework/tasks/benchmarks/arc.py +46 -0
  98. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  99. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  100. eval_framework/tasks/benchmarks/belebele.py +60 -0
  101. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  102. eval_framework/tasks/benchmarks/casehold.py +47 -0
  103. eval_framework/tasks/benchmarks/chembench.py +85 -0
  104. eval_framework/tasks/benchmarks/copa.py +39 -0
  105. eval_framework/tasks/benchmarks/duc.py +91 -0
  106. eval_framework/tasks/benchmarks/flores200.py +62 -0
  107. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  108. eval_framework/tasks/benchmarks/gpqa.py +177 -0
  109. eval_framework/tasks/benchmarks/gsm8k.py +148 -0
  110. eval_framework/tasks/benchmarks/hellaswag.py +44 -0
  111. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  112. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  113. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  114. eval_framework/tasks/benchmarks/include.py +119 -0
  115. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  116. eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
  117. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  118. eval_framework/tasks/benchmarks/mmlu.py +190 -0
  119. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  120. eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
  121. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  122. eval_framework/tasks/benchmarks/openbookqa.py +37 -0
  123. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  124. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  125. eval_framework/tasks/benchmarks/piqa.py +39 -0
  126. eval_framework/tasks/benchmarks/quality.py +56 -0
  127. eval_framework/tasks/benchmarks/sciq.py +44 -0
  128. eval_framework/tasks/benchmarks/sphyr.py +75 -0
  129. eval_framework/tasks/benchmarks/squad.py +89 -0
  130. eval_framework/tasks/benchmarks/struct_eval.py +110 -0
  131. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  132. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  133. eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
  134. eval_framework/tasks/benchmarks/winogender.py +39 -0
  135. eval_framework/tasks/benchmarks/winogrande.py +44 -0
  136. eval_framework/tasks/benchmarks/winox.py +57 -0
  137. eval_framework/tasks/benchmarks/wmt.py +160 -0
  138. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  139. eval_framework/tasks/eval_config.py +112 -0
  140. eval_framework/tasks/perturbation.py +83 -0
  141. eval_framework/tasks/registry.py +186 -0
  142. eval_framework/tasks/task_loader.py +80 -0
  143. eval_framework/tasks/task_names.py +138 -0
  144. eval_framework/tasks/utils.py +578 -0
  145. eval_framework/utils/constants.py +9 -0
  146. eval_framework/utils/generate_task_docs.py +229 -0
  147. eval_framework/utils/helpers.py +3 -0
  148. eval_framework/utils/logging.py +50 -0
  149. eval_framework/utils/packaging.py +52 -0
  150. eval_framework-0.2.0.dist-info/METADATA +514 -0
  151. eval_framework-0.2.0.dist-info/RECORD +161 -0
  152. eval_framework-0.2.0.dist-info/WHEEL +4 -0
  153. eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
  154. template_formatting/README.md +83 -0
  155. template_formatting/__init__.py +0 -0
  156. template_formatting/formatter.py +536 -0
  157. template_formatting/mistral_formatter.py +159 -0
  158. template_formatting/py.typed +0 -0
  159. template_formatting/tests/test_formatter_eval.py +408 -0
  160. template_formatting/tests/test_formatter_scaling.py +253 -0
  161. template_formatting/tests/test_mistral_formatter.py +136 -0
@@ -0,0 +1,76 @@
+ import sacrebleu
+
+ from eval_framework.exceptions import LogicError
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion
+
+
+ class BLEU(BaseMetric[Completion]):
+     """The Bilingual Evaluation Understudy score, or BLEU for short, is a metric
+     for evaluating a generated sentence against a reference sentence. It compares
+     n-grams in the candidate translation with n-grams in the reference text, where
+     a 1-gram or unigram is a single token and a bigram is a word pair. The
+     comparison is made regardless of word order.
+     Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
+     Paper: https://www.aclweb.org/anthology/P02-1040/
+     """
+
+     NAME = "BLEU"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         scores = []
+         for ground_truth in response.ground_truth_list:
+             if ground_truth == "" or ground_truth is None:
+                 raise LogicError("When calculating BLEU we need a ground truth.")
+
+             sacre_formatted_completion = [response.completion]
+             sacre_formatted_ground_truth = [[ground_truth]]
+             scores.append(sacrebleu.corpus_bleu(sacre_formatted_completion, sacre_formatted_ground_truth).score)
+
+         return [
+             MetricResult(metric_name=self.NAME, value=float(max(scores)), higher_is_better=True, error=response.error)
+         ]
+
+
+ class LINEWISE_BLEU(BaseMetric[Completion]):
+     """Maximum line-level BLEU score."""
+
+     NAME = "Linewise BLEU"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         scores = []
+         for ground_truth in response.ground_truth_list:
+             for sentence in response.completion.split("\n"):
+                 if sentence == "":
+                     continue
+
+                 if ground_truth == "" or ground_truth is None:
+                     raise LogicError("When calculating BLEU we need a ground truth.")
+
+                 sacre_formatted_completion = [sentence]
+                 sacre_formatted_ground_truth = [[ground_truth]]
+                 scores.append(sacrebleu.corpus_bleu(sacre_formatted_completion, sacre_formatted_ground_truth).score)
+
+         return [
+             MetricResult(
+                 metric_name=self.NAME, value=float(max(scores, default=0)), higher_is_better=True, error=response.error
+             )
+         ]
+
+
+ class ResponseToOriginalBLEU(BaseMetric[Completion]):
+     NAME = "Response to Original BLEU"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         score = sacrebleu.corpus_bleu([response.completion], [[response.last_user_instruction]]).score
+         # scaled to [0, 1] to make aggregation easier
+         return [MetricResult(metric_name=self.NAME, value=score / 100, higher_is_better=True, error=response.error)]
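For intuition, the max-over-references pattern in BLEU.calculate above can be reproduced standalone. A minimal sketch with invented strings, assuming sacrebleu is installed:

import sacrebleu

# Hypothetical completion and reference list, for illustration only.
completion = "the cat sat on the mat"
references = ["the cat sat on the mat", "a cat was sitting on the mat"]

# Score the completion against each reference separately and keep the best,
# mirroring the per-ground-truth loop in BLEU.calculate.
scores = [sacrebleu.corpus_bleu([completion], [[ref]]).score for ref in references]
print(max(scores))  # 100.0 for the exact match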
@@ -0,0 +1,62 @@
+ import sacrebleu
+
+ from eval_framework.exceptions import LogicError
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion
+
+
+ class CHRF(BaseMetric[Completion]):
+     """chrF++ is a tool for automatic evaluation of machine translation output
+     based on character n-gram precision and recall enhanced with word n-grams.
+     Source: https://github.com/m-popovic/chrF
+     Paper: https://www.aclweb.org/anthology/W15-3049.pdf
+     """
+
+     NAME = "chrF"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         scores = []
+         for ground_truth in response.ground_truth_list:
+             if ground_truth == "" or ground_truth is None:
+                 raise LogicError("When calculating chrF we need a ground truth.")
+
+             sacre_formatted_completion = [response.completion]
+             sacre_formatted_ground_truth = [[ground_truth]]
+             scores.append(sacrebleu.corpus_chrf(sacre_formatted_completion, sacre_formatted_ground_truth).score)
+
+         return [
+             MetricResult(metric_name=self.NAME, value=float(max(scores)), higher_is_better=True, error=response.error)
+         ]
+
+
+ class LINEWISE_CHRF(BaseMetric[Completion]):
+     """Maximum line-level chrF++ (character n-gram F-score) score.
+     Paper: https://aclanthology.org/W15-3049/"""
+
+     NAME = "Linewise chrF"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         scores = []
+         for ground_truth in response.ground_truth_list:
+             for sentence in response.completion.split("\n"):
+                 if sentence == "":
+                     continue
+
+                 if ground_truth == "" or ground_truth is None:
+                     raise LogicError("When calculating chrF we need a ground truth.")
+
+                 sacre_formatted_completion = [sentence]
+                 sacre_formatted_ground_truth = [[ground_truth]]
+                 scores.append(sacrebleu.corpus_chrf(sacre_formatted_completion, sacre_formatted_ground_truth).score)
+
+         return [
+             MetricResult(
+                 metric_name=self.NAME, value=float(max(scores, default=0)), higher_is_better=True, error=response.error
+             )
+         ]
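The linewise variant scores each non-empty line of the completion independently and keeps the best match. A minimal sketch with invented strings, assuming sacrebleu is installed:

import sacrebleu

completion = "first draft line\nthe quick brown fox\n"
reference = "the quick brown fox"

# Score each non-empty line against the reference, as LINEWISE_CHRF does;
# max(..., default=0) guards against completions with no non-empty lines.
scores = [
    sacrebleu.corpus_chrf([line], [[reference]]).score
    for line in completion.split("\n")
    if line != ""
]
print(max(scores, default=0))  # the exactly matching line scores 100.0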
@@ -0,0 +1,44 @@
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion, Error
+ from eval_framework.tasks.utils import run_python_code
+
+
+ class CodeCompletionAssertion(BaseMetric[Completion]):
+     NAME = "Code Completion Accuracy"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         code = response.completion
+         output = run_python_code(code, image="python:3.12-slim")
+
+         # Split on whitespace and filter out empty strings; this always yields
+         # a list, which is empty when the execution produced no output.
+         output_parts = [part for part in output.split() if part.strip()]
+
+         if not output_parts:
+             last_output = ""
+         else:
+             last_output = output_parts[-1]
+
+         success = last_output == "True"
+         error = (
+             None
+             if success
+             else Error(
+                 error_class="CodeCompletionAssertionError",
+                 message=f"Expected 'True' but got '{last_output}'",
+                 traceback=output,
+             )
+         )
+
+         return [
+             MetricResult(
+                 metric_name=self.NAME,
+                 value=1.0 if success else 0.0,
+                 higher_is_better=True,
+                 error=error,
+                 code_execution_trace=output,
+             )
+         ]
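One subtlety worth noting: because str.split() with no argument splits on any whitespace, the check above inspects the final whitespace-separated token of the sandbox output, not the final line. A small illustration with a made-up output string:

# Hypothetical sandbox output; the assertion result is the trailing token.
output = "test_add passed\ntest_sub passed\nTrue"

output_parts = [part for part in output.split() if part.strip()]
last_output = output_parts[-1] if output_parts else ""
print(last_output == "True")  # True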
@@ -0,0 +1,126 @@
+ import traceback
+ from collections.abc import Callable
+ from typing import Self
+
+ from pydantic import Field
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import BaseMetricContext, Completion, Error, extract_context_metric
+ from eval_framework.tasks.utils import CallableSerializer, ExecutionResult, execute_python_code_with_tests
+
+
+ class CodeExecutionBaseContext(BaseMetricContext):
+     run_env: str = Field(description="Name of the docker image to run unit tests inside")
+     code_prompt: str = Field(description="Prompt to the LLM for code generation")
+     test_code: str = Field(description="Python code that contains the logic for unit test execution")
+     benchmark_timeout: int = Field(default=60, description="Time in seconds for the full test execution run")
+     package_downloads: dict[str, str | None] = Field(
+         description="A dictionary mapping required packages to their PyPI names for installation into the LLM sandbox"
+     )
+
+
+ class CodeExecutionPassAtOneContext(CodeExecutionBaseContext):
+     snippet_merge_fn: str = Field(
+         description="Logic for merging LLM-generated code with test execution code; "
+         "this code will be passed into the sandbox to run the testing process. "
+         "This code is serialized."
+     )
+     output_parse_fn: str = Field(
+         description="Logic for parsing the output of the test code execution run within the LLM sandbox. "
+         "This code is serialized."
+     )
+
+
+ class RealtimeCodeExectionContext(CodeExecutionBaseContext):
+     snippet_merge_fn: Callable[[str, str], str] = Field(
+         description="Logic for merging LLM-generated code with test execution code; "
+         "this code will be passed into the sandbox to run the testing process. "
+         "This code is deserialized."
+     )
+     output_parse_fn: Callable[[str], ExecutionResult] = Field(
+         description="Logic for parsing the output of the test code execution run within the LLM sandbox. "
+         "This code is deserialized."
+     )
+
+     @classmethod
+     def from_context(cls, context: CodeExecutionPassAtOneContext) -> Self:
+         return cls(
+             run_env=context.run_env,
+             code_prompt=context.code_prompt,
+             test_code=context.test_code,
+             benchmark_timeout=context.benchmark_timeout,
+             snippet_merge_fn=CallableSerializer.decode(context.snippet_merge_fn),
+             output_parse_fn=CallableSerializer.decode(context.output_parse_fn),
+             package_downloads=context.package_downloads,
+         )
+
+
+ class CodeExecutionPassAtOne(BaseMetric[Completion]):
+     NAME = "code-execution-pass@1"
+
+     def __init__(self) -> None:
+         self.k = 1
+         # NOTE: this serializer should be the same class as initialized in the benchmark
+         self.serializer = CallableSerializer()
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+         try:
+             context = extract_context_metric(response, CodeExecutionPassAtOneContext)
+             parsed_context = RealtimeCodeExectionContext.from_context(context)
+         except Exception as e:
+             raise Exception(f"Failed to rebuild parsing functions => {e}") from e
+
+         n = 1  # we only support N=1 at the moment
+         try:
+             c, output = self._count_correct_samples(response.completion, parsed_context)
+         except Exception as e:
+             error = Error(error_class=e.__class__.__name__, message=str(e), traceback=traceback.format_exc())
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=error)]
+
+         pass_at_k_value = estimate_pass_at_k(n, c, self.k)
+         return [
+             MetricResult(
+                 metric_name=self.NAME,
+                 value=pass_at_k_value,
+                 higher_is_better=True,
+                 error=response.error,
+                 code_execution_trace=output,
+             )
+         ]
+
+     def _count_correct_samples(self, completion: str, context: RealtimeCodeExectionContext) -> tuple[int, str]:
+         result = execute_python_code_with_tests(
+             code=completion,
+             test_code=context.test_code,
+             package_mapping=context.package_downloads,
+             merge_code_fn=context.snippet_merge_fn,
+             image=context.run_env,
+             timeout=context.benchmark_timeout,
+             parse_output_fn=context.output_parse_fn,
+         )
+         return (1 if result.success else 0), result.output
+
+
+ def estimate_pass_at_k(n: int, c: int, k: int) -> float:
+     """
+     Estimates pass@k for a single problem.
+
+     Parameters:
+         n (int): Total number of generated samples.
+         c (int): Number of correct samples.
+         k (int): Number of attempts or samples considered.
+
+     Returns:
+         float: The pass@k value.
+     """
+     if n - c < k:
+         return 1.0
+
+     # Calculate the probability that none of the k drawn samples is correct,
+     # then return its complement.
+     probability = 1.0
+     for i in range(k):
+         probability *= (n - c - i) / (n - i)
+
+     return 1.0 - probability
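A quick sanity check of the estimator above, with illustrative values (assuming estimate_pass_at_k is in scope). pass@k equals 1 - C(n-c, k) / C(n, k): the probability that at least one of k drawn samples is correct, given c correct samples out of n:

print(estimate_pass_at_k(n=5, c=2, k=1))  # 0.4, which is c / n when k = 1
print(estimate_pass_at_k(n=5, c=2, k=3))  # ~0.9, i.e. 1 - C(3,3) / C(5,3)
print(estimate_pass_at_k(n=1, c=1, k=1))  # 1.0 via the n - c < k short-circuit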
@@ -0,0 +1,56 @@
+ import torch
+ from comet import download_model, load_from_checkpoint
+
+ from eval_framework.exceptions import LogicError
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion, UntemplatedPrompt
+ from eval_framework.utils.constants import ROOT_DIR
+
+ SAVING_DIR = ROOT_DIR / "comet_model"
+
+
+ class COMET(BaseMetric[Completion]):
+     """COMET is a neural, multilingual framework for evaluating machine translation
+     quality. It leverages cross-lingual pretrained language models to achieve
+     state-of-the-art correlation with human judgments.
+     Note: this requires a Hugging Face token with access to the model: https://huggingface.co/Unbabel/XCOMET-XL
+     Source: https://github.com/Unbabel/COMET
+     Paper: https://arxiv.org/abs/2009.09025
+     """
+
+     NAME = "COMET"
+
+     def __init__(self) -> None:
+         checkpoint_path = download_model("Unbabel/XCOMET-XL", saving_directory=SAVING_DIR)
+         self.model = load_from_checkpoint(checkpoint_path)
+         assert torch.cuda.is_available(), "COMET requires a GPU"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         if (
+             response.context is None
+             or not isinstance(response.context, UntemplatedPrompt)
+             or response.context.untemplated_prompt == ""
+         ):
+             raise LogicError("When calculating COMET we need an untemplated prompt.")
+
+         scores = []
+         for ground_truth in response.ground_truth_list:
+             if ground_truth == "" or ground_truth is None:
+                 raise LogicError("When calculating COMET we need a ground truth.")
+
+             data = [
+                 {
+                     "src": response.context.untemplated_prompt.strip(),
+                     "mt": response.completion.strip(),
+                     "ref": ground_truth.strip(),
+                 },
+             ]
+             with torch.no_grad():
+                 model_output = self.model.predict(data, gpus=1)
+             scores.append(model_output.system_score)
+
+         return [
+             MetricResult(metric_name=self.NAME, value=float(max(scores)), higher_is_better=True, error=response.error)
+         ]
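The predict call above consumes source/hypothesis/reference triplets. A sketch of that input shape with invented sentences (running it for real requires a GPU and a Hugging Face token with access to Unbabel/XCOMET-XL, so the scoring lines are left commented):

# Each item pairs the untemplated source prompt ("src"), the model's
# translation ("mt"), and one reference ("ref"), as built in COMET.calculate.
data = [
    {
        "src": "Der Hund schläft.",    # hypothetical source sentence
        "mt": "The dog is sleeping.",  # hypothetical model output
        "ref": "The dog sleeps.",      # hypothetical reference
    },
]
# model_output = model.predict(data, gpus=1)
# model_output.system_score holds the aggregate score.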
@@ -0,0 +1,38 @@
+ import ast
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion
+
+
+ class ConcordanceIndex(BaseMetric[Completion]):
+     NAME = "ConcordanceIndex"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         ground_truths = [gt for gt in response.ground_truth_list if gt is not None]
+         if not ground_truths:
+             return [MetricResult(metric_name=self.NAME, value=0.0, higher_is_better=True, error=response.error)]
+
+         concordance_index = max([calculate_concordance_index(gt, response.completion) for gt in ground_truths])
+         return [
+             MetricResult(metric_name=self.NAME, value=concordance_index, higher_is_better=True, error=response.error)
+         ]
+
+
+ def calculate_concordance_index(
+     ground_truth: str,
+     completion: str,
+ ) -> float:
+     ground_truth_arr = ast.literal_eval(ground_truth)
+     completion_arr = ast.literal_eval(completion)
+
+     if len(ground_truth_arr) != len(completion_arr):
+         return 0.0
+
+     concordance_count = 0
+     for gt, c in zip(ground_truth_arr, completion_arr):
+         concordance_count += 1 if gt == c else 0
+
+     return concordance_count / len(ground_truth_arr)
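For illustration, calculate_concordance_index parses both strings as Python list literals and returns the fraction of positions that agree (note that ast.literal_eval raises on strings that are not valid literals). Assuming the function above is in scope:

print(calculate_concordance_index("[1, 2, 3]", "[1, 2, 4]"))  # 0.666..., 2 of 3 positions match
print(calculate_concordance_index("[1, 2, 3]", "[1, 2]"))     # 0.0, length mismatch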
@@ -0,0 +1,102 @@
+ import json
+
+ from pydantic import BaseModel
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion
+
+ SEPARATOR_MAP = {"comma": ",", "semicolon": ";", "space": " ", "tab": "\t"}
+
+
+ class CSVFormatEvaluation(BaseModel):
+     implicit: bool = False
+     has_csv: bool = False
+     is_separator_respected: bool = False
+     is_column_count_respected: bool = False
+
+
+ class CSVFormat(BaseMetric[Completion]):
+     NAME = "CSV Format"
+     KEYS = ["has_csv", "is_separator_respected", "is_column_count_respected"]
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [
+                 MetricResult(metric_name=f"{self.NAME}/{k}", value=None, higher_is_better=True, error=response.error)
+                 for k in self.KEYS
+             ]
+
+         if response.completion == "":
+             return [
+                 MetricResult(metric_name=f"{self.NAME}/{k}", value=0.0, higher_is_better=True, error=response.error)
+                 for k in self.KEYS
+             ]
+
+         grading = evaluate_csv_format(response)
+
+         results = []
+         for key in self.KEYS:
+             result = MetricResult(
+                 metric_name=f"{self.NAME}/{key}",
+                 value=float(getattr(grading, key)),
+                 higher_is_better=True,
+                 error=response.error,
+             )
+             results.append(result)
+         return results
+
+
+ def extract_csv_from_text(text: str, min_rows: int = 2, min_columns: int = 2) -> tuple[list[str] | None, str | None]:
+     lines = text.split("\n")
+     delimiters = set(SEPARATOR_MAP.values())
+     best_delimiter = None
+     csv_lines: list[str] = []
+
+     # Iterate over lines to find potential delimiters and consistent substring counts
+     for i, line in enumerate(lines):
+         for delimiter in delimiters:
+             substrings = line.split(delimiter)
+             if len(substrings) < min_columns:
+                 continue
+
+             current_csv_lines = [line]
+             for j in range(i + 1, len(lines)):
+                 next_line = lines[j]
+                 next_substrings = next_line.split(delimiter)
+                 if len(next_substrings) != len(substrings):
+                     break
+                 current_csv_lines.append(next_line)
+             if len(current_csv_lines) >= min_rows and len(current_csv_lines) > len(csv_lines):
+                 best_delimiter = delimiter
+                 csv_lines = current_csv_lines
+
+     if not csv_lines:
+         return None, None
+
+     return csv_lines, best_delimiter
+
+
+ def evaluate_csv_format(response: Completion) -> CSVFormatEvaluation:
+     expected_output = json.loads(str(response.ground_truth))
+
+     expected_separator_code = expected_output["separator"]
+     csv_lines, separator = extract_csv_from_text(response.completion)
+
+     if not csv_lines:
+         return CSVFormatEvaluation(has_csv=False, implicit=not expected_separator_code)
+
+     csv_format_evaluation = CSVFormatEvaluation(has_csv=True, implicit=not expected_separator_code)
+
+     if not expected_separator_code:
+         csv_format_evaluation.is_separator_respected = separator in SEPARATOR_MAP.values()
+     else:
+         csv_format_evaluation.is_separator_respected = separator == SEPARATOR_MAP.get(expected_separator_code)
+
+     expected_column_count = len(expected_output["columns"])
+     column_counts = [len(line.split(separator)) for line in csv_lines]
+
+     csv_format_evaluation.is_column_count_respected = all(
+         column_count == expected_column_count for column_count in column_counts
+     )
+
+     return csv_format_evaluation
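To see the delimiter detection at work, here is a small illustration with an invented completion that embeds a table between prose lines (assuming extract_csv_from_text is in scope):

text = "Here is the data you asked for:\nname,age\nalice,30\nbob,25\nThanks!"
csv_lines, delimiter = extract_csv_from_text(text)
print(delimiter)  # ","
print(csv_lines)  # ['name,age', 'alice,30', 'bob,25']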
@@ -0,0 +1,49 @@
+ import re
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion, Error
+
+
+ class CWEAccuracy(BaseMetric[Completion]):
+     """Metric for Common Word Extraction tasks"""
+
+     NAME = "CWEAccuracy"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         ground_truths = [gt for gt in response.ground_truth_list if gt is not None]
+         if not ground_truths:
+             return [MetricResult(metric_name=self.NAME, value=0.0, higher_is_better=True, error=response.error)]
+
+         try:
+             # Get the model's answer
+             model_answer = response.completion
+
+             # Check if all words in the correct answer are present in the model's answer
+             is_correct = self._is_answer_correct(ground_truths, model_answer)
+
+             return [
+                 MetricResult(
+                     metric_name=self.NAME, value=1.0 if is_correct else 0.0, higher_is_better=True, error=response.error
+                 )
+             ]
+         except Exception as e:
+             error = Error(error_class=e.__class__.__name__, message=str(e), traceback="")
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=error)]
+
+     def _is_answer_correct(self, correct_answer: list[str], model_answer: str) -> bool:
+         """Check if all words in correct_answer are present in model_answer as whole words"""
+         model_answer = model_answer.strip().lower()
+         correct_answer = [correct.strip().lower() for correct in correct_answer]
+
+         # For each word in the correct answer, check if it exists as a whole word in the model answer
+         for word in correct_answer:
+             # Create a regex pattern that matches the word as a whole word;
+             # \b represents a word boundary
+             pattern = r"\b" + re.escape(word) + r"\b"
+             if not re.search(pattern, model_answer):
+                 return False
+
+         return True
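The word-boundary pattern is what separates whole-word matches from substring hits. A brief illustration with made-up strings:

import re

model_answer = "the words were cat, tree and sun"
for word in ["cat", "cats"]:
    pattern = r"\b" + re.escape(word) + r"\b"
    print(word, bool(re.search(pattern, model_answer)))
# cat True   (whole-word match; the trailing comma is a word boundary)
# cats False (not present, so _is_answer_correct would return False)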
@@ -0,0 +1,65 @@
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion, Error
+
+
+ class ExponentialSimilarity(BaseMetric[Completion]):
+     NAME = "ExponentialSimilarity"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         ground_truths = [gt for gt in response.ground_truth_list if gt is not None]
+         if not ground_truths:
+             return [MetricResult(metric_name=self.NAME, value=0.0, higher_is_better=True, error=response.error)]
+
+         try:
+             # Try to calculate the exponential similarity for each ground truth
+             similarities = []
+             for gt in ground_truths:
+                 try:
+                     gt_float = float(gt)
+                     completion_float = float(response.completion)
+                     similarities.append(calculate_exponential_similarity(gt_float, completion_float))
+                 except (ValueError, TypeError):
+                     # Skip this ground truth if conversion fails
+                     continue
+
+             # If we have any valid similarities, return the max
+             if similarities:
+                 return [
+                     MetricResult(
+                         metric_name=self.NAME, value=max(similarities), higher_is_better=True, error=response.error
+                     )
+                 ]
+             else:
+                 # If all conversions failed, return an error
+                 error = Error(
+                     error_class="ValueError",
+                     message="Could not convert ground truth or completion to float",
+                     traceback="",
+                 )
+                 return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=error)]
+         except Exception as e:
+             error = Error(error_class=e.__class__.__name__, message=str(e), traceback="")
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=error)]
+
+
+ def calculate_exponential_similarity(p_true: float, p_pred: float) -> float:
+     """
+     Compute the exponential similarity (SpaceDigest version) between
+     the gold percentage and the predicted value.
+
+     Parameters:
+     - p_true (float): The gold/reference percentage.
+     - p_pred (float): The predicted scalar.
+
+     Constants:
+     - d (float): Base of the exponent, fixed to 2.
+     - c (float): Coefficient in the exponent, fixed to 10.
+
+     Returns:
+     - float: Similarity score between 0 and 1.
+     """
+     d = 2
+     c = 10
+
+     return d ** (-c * abs(p_true / 100 - p_pred / 100))
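The score decays exponentially with the absolute gap between the two percentages. Some illustrative values (assuming calculate_exponential_similarity is in scope):

print(calculate_exponential_similarity(40.0, 40.0))  # 1.0 for an exact match
print(calculate_exponential_similarity(40.0, 30.0))  # ~0.5, since 2 ** (-10 * 0.1) = 0.5
print(calculate_exponential_similarity(90.0, 40.0))  # ~0.03, since 2 ** (-10 * 0.5) = 0.03125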
@@ -0,0 +1,42 @@
+ from collections import Counter
+ from typing import Any
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion
+
+
+ class F1(BaseMetric[Completion]):
+     NAME = "F1"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         ground_truths = [gt for gt in response.ground_truth_list if gt is not None]
+         if not ground_truths:
+             return [MetricResult(metric_name=self.NAME, value=0.0, higher_is_better=True, error=response.error)]
+
+         hyp_tokens = response.completion.lower().split()
+         f1_scores = [calculate_f1(gt.lower().split(), hyp_tokens) for gt in ground_truths]
+         max_f1 = max(f1_scores)
+
+         return [MetricResult(metric_name=self.NAME, value=max_f1, higher_is_better=True, error=response.error)]
+
+
+ def calculate_f1(ref_tokens: list[Any], hyp_tokens: list[Any]) -> float:
+     """Calculate the F1 score between two texts based on token overlap."""
+     if not ref_tokens and not hyp_tokens:
+         return 1.0
+     if not ref_tokens or not hyp_tokens:
+         return 0.0
+
+     common = Counter(ref_tokens) & Counter(hyp_tokens)
+     num_same = sum(common.values())
+
+     if num_same == 0:
+         return 0.0
+
+     precision = num_same / len(hyp_tokens)
+     recall = num_same / len(ref_tokens)
+
+     return 2 * precision * recall / (precision + recall)
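calculate_f1 is a bag-of-tokens F1: multiset overlap drives both precision and recall. A worked example with invented token lists (assuming the function above is in scope):

ref = "the cat sat on the mat".split()
hyp = "the cat sat".split()
# 3 overlapping tokens: precision = 3/3, recall = 3/6,
# so F1 = 2 * 1.0 * 0.5 / 1.5.
print(calculate_f1(ref, hyp))  # ~0.667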