eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,88 @@
1
+ import re
2
+ from collections import Counter
3
+ from collections.abc import Sequence
4
+ from typing import Final
5
+
6
+ from eval_framework.metrics.base import BaseMetric, MetricResult
7
+ from eval_framework.shared.types import Completion
8
+
9
+
10
class WordRepetition(BaseMetric[Completion]):
    """Flag completions that contain repeated word sequences.

    The completion is split into words and every run of ``window_size``
    consecutive words is counted. The metric value is 1.0 when at least one
    such run occurs more than ``min_repetitions`` times, and 0.0 otherwise.
    With ``window_size=2`` and ``min_repetitions=1`` the text
    "hello world hello world" triggers the metric, because the pair
    ("hello", "world") occurs twice, i.e. it repeats once.
    """

    NAME = "WordRepetition"
    HIGHER_IS_BETTER: Final[bool] = False

    def __init__(self, window_size: int = 128, min_repetitions: int = 1) -> None:
        """
        Configure the repetition detector.

        Args:
            window_size (int): Number of consecutive words that form one
                sequence.
            min_repetitions (int): How many times a sequence has to repeat
                (occurrences beyond the first) before it is reported. Use 1
                to catch any repetition.

        Raises:
            ValueError: If ``min_repetitions`` or ``window_size`` is below 1.
        """
        super().__init__()
        self.window_size = window_size
        self.min_repetitions = min_repetitions

        # min_repetitions is validated first, so it wins when both are invalid.
        if not self.min_repetitions >= 1:
            raise ValueError("min_repetitions must be at least 1")

        if not self.window_size >= 1:
            raise ValueError("window_size must be at least 1")

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Return one result: 1.0 / 0.0 for repetition found / not found.

        When the response carries an error, the value is None and the error
        is passed through unchanged.
        """
        if response.error is not None:
            score = None
        else:
            score = float(
                _has_repetition(
                    text=response.completion,
                    window_size=self.window_size,
                    min_repetitions=self.min_repetitions,
                )
            )

        return [
            MetricResult(
                metric_name=self.NAME,
                value=score,
                higher_is_better=self.HIGHER_IS_BETTER,
                error=response.error,
            )
        ]
72
+
73
+
74
def _has_repetition(text: str, window_size: int, min_repetitions: int) -> bool:
    """Check if the text contains any word sequence of a given size that repeats.

    Args:
        text: The text to inspect.
        window_size: Number of consecutive words per sequence.
        min_repetitions: Number of repeats (occurrences beyond the first) a
            sequence needs in order to count.

    Returns:
        True if some sequence occurs more than ``min_repetitions`` times.
    """
    occurrences = Counter(_word_sequences(_to_words(text), window_size))
    # A sequence seen `count` times has repeated `count - 1` times, hence the
    # strict comparison. The generator lets any() stop at the first hit.
    return any(count > min_repetitions for count in occurrences.values())
79
+
80
+
81
+ def _to_words(text: str) -> Sequence[str]:
82
+ """A somewhat crude function to tokenize a string into words."""
83
+ return re.findall(r"\b\w+\b", text, re.UNICODE)
84
+
85
+
86
+ def _word_sequences(text_words: Sequence[str], window_size: int) -> Sequence[Sequence[str]]:
87
+ """Get all contiguous sub-sequences of a given size from a word sequence."""
88
+ return [tuple(text_words[i : i + window_size]) for i in range(len(text_words) - window_size + 1)]
@@ -0,0 +1,35 @@
1
+ from eval_framework.exceptions import LogicError
2
+ from eval_framework.metrics.base import BaseMetric, MetricResult
3
+ from eval_framework.metrics.completion.f1 import calculate_f1
4
+ from eval_framework.shared.types import Completion
5
+
6
+
7
class ROUGE_1(BaseMetric[Completion]):
    """ROUGE-1: best unigram-based score of the completion over all ground truths."""

    NAME = "ROUGE-1"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Score the completion against each ground truth and keep the maximum.

        Returns a None value when the response carries an error and 0.0 for
        an empty completion.

        Raises:
            LogicError: If any ground truth is None.
        """

        def as_result(score: float | None) -> list[MetricResult]:
            return [MetricResult(metric_name=self.NAME, value=score, higher_is_better=True, error=response.error)]

        if response.error is not None:
            return as_result(None)

        if response.completion == "":
            return as_result(0.0)
        if None in response.ground_truth_list:
            raise LogicError("When calculating ROUGE-1 ground_truth cannot be None.")

        # ROUGE-1 compares the texts on the level of single words (unigrams).
        best = max([_calculate_rouge_1(response.completion, truth) for truth in response.ground_truth_list])  # type: ignore[arg-type]
        return as_result(float(best))
24
+
25
+
26
def _calculate_rouge_1(candidate: str, reference: str) -> float:
    """
    Return the ROUGE-1 score between candidate and reference texts.

    Both texts are split on whitespace and the token lists are handed to
    ``calculate_f1`` with the reference tokens first.
    """
    return calculate_f1(reference.split(), candidate.split())
@@ -0,0 +1,45 @@
1
+ from eval_framework.exceptions import LogicError
2
+ from eval_framework.metrics.base import BaseMetric, MetricResult
3
+ from eval_framework.metrics.completion.f1 import calculate_f1
4
+ from eval_framework.shared.types import Completion
5
+
6
+
7
class ROUGE_2(BaseMetric[Completion]):
    """ROUGE-2: best bigram-based score of the completion over all ground truths."""

    NAME = "ROUGE-2"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Score the completion against each ground truth and keep the maximum.

        Returns a None value when the response carries an error and 0.0 for
        an empty completion.

        Raises:
            LogicError: If any ground truth is None.
        """

        def as_result(score: float | None) -> list[MetricResult]:
            return [MetricResult(metric_name=self.NAME, value=score, higher_is_better=True, error=response.error)]

        if response.error is not None:
            return as_result(None)

        if response.completion == "":
            return as_result(0.0)
        if None in response.ground_truth_list:
            raise LogicError("When calculating ROUGE-2 ground_truth cannot be None.")

        # Bigrams make the score sensitive, to some extent, to word order and
        # to which words occur next to each other.
        best = max([_calculate_rouge_2(response.completion, truth) for truth in response.ground_truth_list])  # type: ignore[arg-type]
        return as_result(float(best))
25
+
26
+
27
+ def _generate_bigrams(tokens: list[str]) -> list[tuple[str, str]]:
28
+ """Generate bigrams from a list of tokens."""
29
+ return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
30
+
31
+
32
def _calculate_rouge_2(completion: str, ground_truth: str) -> float:
    """
    Return the ROUGE-2 score between candidate and reference texts.

    Both texts are split on whitespace, turned into bigrams, and the bigram
    lists are handed to ``calculate_f1`` with the reference bigrams first.
    """
    candidate_bigrams = _generate_bigrams(completion.split())
    reference_bigrams = _generate_bigrams(ground_truth.split())
    return calculate_f1(reference_bigrams, candidate_bigrams)
@@ -0,0 +1,36 @@
1
+ from eval_framework.exceptions import LogicError
2
+ from eval_framework.metrics.base import BaseMetric, MetricResult
3
+ from eval_framework.metrics.completion.rouge_1 import ROUGE_1
4
+ from eval_framework.metrics.completion.rouge_2 import ROUGE_2
5
+ from eval_framework.metrics.completion.rouge_l import ROUGE_L
6
+ from eval_framework.shared.types import Completion
7
+
8
+
9
class ROUGE_GEOMETRIC_MEAN(BaseMetric[Completion]):
    """Geometric mean of the ROUGE-1, ROUGE-2 and ROUGE-L scores."""

    NAME = "ROUGE-Geometric-Mean"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Combine the three ROUGE metrics into a single value.

        Returns a None value when the response carries an error or when any
        of the three component metrics yields None, and 0.0 for an empty
        completion.

        Raises:
            LogicError: If any ground truth is None.
        """

        def as_result(score: float | None) -> list[MetricResult]:
            return [MetricResult(metric_name=self.NAME, value=score, higher_is_better=True, error=response.error)]

        if response.error is not None:
            return as_result(None)
        if response.completion == "":
            return as_result(0.0)
        for truth in response.ground_truth_list:
            if truth is None:
                raise LogicError("When calculating ROUGE Geometric Mean ground_truth cannot be None.")

        # Each component metric returns a one-element list; take its value.
        unigram_score = ROUGE_1().calculate(response)[0].value
        bigram_score = ROUGE_2().calculate(response)[0].value
        lcs_score = ROUGE_L().calculate(response)[0].value

        if None in (unigram_score, bigram_score, lcs_score):
            return as_result(None)

        cube_root_of_product = (unigram_score * bigram_score * lcs_score) ** (1 / 3)
        return as_result(float(cube_root_of_product))
@@ -0,0 +1,52 @@
1
+ from eval_framework.exceptions import LogicError
2
+ from eval_framework.metrics.base import BaseMetric, MetricResult
3
+ from eval_framework.shared.types import Completion
4
+
5
+
6
class ROUGE_L(BaseMetric[Completion]):
    """ROUGE-L: best longest-common-subsequence score over all ground truths."""

    NAME = "ROUGE-L"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Score the completion against each ground truth and keep the maximum.

        Returns a None value when the response carries an error and 0.0 for
        an empty completion.

        Raises:
            LogicError: If any ground truth is None.
        """

        def as_result(score: float | None) -> list[MetricResult]:
            return [MetricResult(metric_name=self.NAME, value=score, higher_is_better=True, error=response.error)]

        if response.error is not None:
            return as_result(None)

        if response.completion == "":
            return as_result(0.0)
        if None in response.ground_truth_list:
            raise LogicError("When calculating ROUGE-L ground_truth cannot be None.")

        # ROUGE-L is an F1 score built on the Longest Common Subsequence (LCS)
        # between a candidate summary and a reference summary.
        best = max([_calculate_rouge_l(response.completion, truth) for truth in response.ground_truth_list])  # type: ignore[arg-type]
        return as_result(float(best))
24
+
25
+
26
+ def _longest_common_subsequence_length(candidate_tokens: list[str], reference_tokens: list[str]) -> int:
27
+ candidate_len, reference_len = len(candidate_tokens), len(reference_tokens)
28
+ lcs_matrix = [[0] * (reference_len + 1) for _ in range(candidate_len + 1)]
29
+
30
+ for i in range(candidate_len + 1):
31
+ for j in range(reference_len + 1):
32
+ if i == 0 or j == 0:
33
+ lcs_matrix[i][j] = 0
34
+ elif candidate_tokens[i - 1] == reference_tokens[j - 1]:
35
+ lcs_matrix[i][j] = lcs_matrix[i - 1][j - 1] + 1
36
+ else:
37
+ lcs_matrix[i][j] = max(lcs_matrix[i - 1][j], lcs_matrix[i][j - 1])
38
+
39
+ return lcs_matrix[candidate_len][reference_len]
40
+
41
+
42
def _calculate_rouge_l(completion: str, ground_truth: str) -> float:
    """Return the ROUGE-L F1 score between a completion and one ground truth.

    Both texts are split on whitespace. Precision is the LCS length divided
    by the number of completion tokens, recall is the LCS length divided by
    the number of ground-truth tokens, and the result is their harmonic mean.

    Returns:
        0.0 when the texts share no token (this also covers either text
        having no tokens at all), otherwise the F1 score.
    """
    # Tokenize once and reuse the lists for both the LCS and the lengths.
    candidate_tokens = completion.split()
    reference_tokens = ground_truth.split()

    lcs_length = _longest_common_subsequence_length(candidate_tokens, reference_tokens)
    if lcs_length == 0:
        return 0.0

    # A positive LCS means both lists are non-empty and both ratios are
    # positive, so neither division below can hit zero.
    precision = lcs_length / len(candidate_tokens)
    recall = lcs_length / len(reference_tokens)
    return (2 * precision * recall) / (precision + recall)
@@ -0,0 +1,248 @@
1
+ import csv
2
+ import io
3
+ import json
4
+ import tomllib
5
+ from typing import Any
6
+
7
+ import xmltodict
8
+ import yaml
9
+ from lxml import etree
10
+
11
+ from eval_framework.metrics.base import BaseMetric, MetricResult
12
+ from eval_framework.shared.types import BaseMetricContext, Completion, extract_context_metric
13
+
14
+
15
class StructMetricContext(BaseMetricContext):
    """Per-sample context read by StructMetric."""

    # Name of the expected format; StructMetric lower-cases it and accepts
    # "json", "yaml", "toml", "xml" and "csv".
    output_type: str
    # Paths that are looked up in the parsed completion via path_exists().
    paths: list[str]
18
+
19
+
20
+ class StructMetric(BaseMetric[Completion]):
21
+ NAME = "StructMetric"
22
+
23
+ def calculate(self, response: Completion) -> list[MetricResult]:
24
+ if response.error is not None:
25
+ return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
26
+
27
+ context = extract_context_metric(response, StructMetricContext)
28
+
29
+ output_type = context.output_type
30
+
31
+ try:
32
+ match output_type.lower():
33
+ case "json":
34
+ result = json.loads(response.completion)
35
+ case "yaml":
36
+ result = list(yaml.safe_load_all(response.completion))
37
+ if isinstance(result, list) and len(result) == 1:
38
+ result = result[0]
39
+ else:
40
+ raise yaml.YAMLError("Multiple documents found in YAML")
41
+ case "toml":
42
+ result = tomllib.loads(response.completion)
43
+ case "xml":
44
+ result = xmltodict.parse(response.completion)
45
+ case "csv":
46
+ csv_output = csv.DictReader(io.StringIO(response.completion))
47
+ # Check for unclosed quotes
48
+ if response.completion.count('"') % 2 != 0:
49
+ raise csv.Error("Unclosed quote in CSV")
50
+ if not csv_output.fieldnames:
51
+ raise csv.Error("CSV has no headers")
52
+ result = {"csv_headers": csv_output.fieldnames, "csv_rows": list(csv_output)}
53
+ case _:
54
+ raise ValueError(f"Unsupported format: {output_type}")
55
+ valid_format = 1.0
56
+ except (json.JSONDecodeError, yaml.YAMLError, tomllib.TOMLDecodeError, csv.Error, Exception):
57
+ valid_format = 0.0
58
+
59
+ has_required_fields = 0.0
60
+ if valid_format == 1:
61
+ # assert "paths" in response.eval_kwargs, "Paths must be provided in eval_kwargs"
62
+ assert context.paths is not None, "Paths must be provided in context"
63
+ paths = context.paths
64
+ assert isinstance(paths, list), "Paths must be a list of strings"
65
+ valid_paths = 0
66
+ for path in paths:
67
+ if path_exists(result, path):
68
+ valid_paths += 1
69
+ has_required_fields = valid_paths / len(paths) if paths else 1.0
70
+
71
+ return [
72
+ MetricResult(
73
+ metric_name=f"{self.NAME}/valid_format",
74
+ value=valid_format,
75
+ higher_is_better=True,
76
+ ),
77
+ MetricResult(
78
+ metric_name=f"{self.NAME}/has_keywords",
79
+ value=has_required_fields,
80
+ higher_is_better=True,
81
+ ),
82
+ ]
83
+
84
+
85
def is_valid_html(html: str) -> bool:
    """Return True if lxml's non-recovering HTML parser accepts the text without logging any error."""
    strict_parser = etree.HTMLParser(recover=False)
    try:
        etree.fromstring(html.encode("utf-8"), strict_parser)
    except etree.XMLSyntaxError:
        return False
    else:
        # Parsing succeeded; the text is still rejected if the parser logged anything.
        return len(strict_parser.error_log) == 0
92
+
93
+
94
class RenderableStructMetricContext(BaseMetricContext):
    """Per-sample context read by RenderableStructMetric."""

    # Name of the expected format; RenderableStructMetric lower-cases it and
    # accepts only "html".
    output_type: str
    # Keywords searched for in the completion text, ignoring case.
    keywords: list[str]
97
+
98
+
99
class RenderableStructMetric(StructMetric):
    """Check that a completion is valid HTML and mentions the expected keywords.

    Two results are produced: ``<NAME>/valid_format`` (1.0 / 0.0 from
    is_valid_html) and ``<NAME>/has_keywords``, the fraction of context
    keywords that occur in the completion ignoring case (1.0 when no
    keywords are given).
    """

    NAME = "RenderableStructMetric"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Validate the completion and count the keywords it contains.

        When the response carries an error, a single result named ``NAME``
        with a None value is returned instead of the two sub-metrics.

        Raises:
            ValueError: If the context's output type is anything but "html".
        """
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        context = extract_context_metric(response, RenderableStructMetricContext)
        output_type = context.output_type

        if output_type.lower() != "html":
            raise ValueError(f"Unsupported format for RenderableStructMetric: {output_type}")
        valid_format = float(is_valid_html(response.completion))

        assert context.keywords is not None, "Keywords must be provided in context"
        keywords = context.keywords
        assert isinstance(keywords, list), "Keywords must be a list of strings"

        if keywords:
            # Lower-case the completion once; every keyword is matched against it.
            haystack = response.completion.lower()
            hits = sum(1 for keyword in keywords if keyword.lower() in haystack)
            has_keywords = hits / len(keywords)
        else:
            has_keywords = 1.0

        return [
            MetricResult(
                metric_name=f"{self.NAME}/valid_format",
                value=valid_format,
                higher_is_better=True,
            ),
            MetricResult(
                metric_name=f"{self.NAME}/has_keywords",
                value=has_keywords,
                higher_is_better=True,
            ),
        ]
138
+
139
+
140
+ # adapted from: https://github.com/TIGER-AI-Lab/StructEval/blob/main/structeval/eval_engine/eval_utils.py
141
def tokenize_path(path: str) -> list[str]:
    """
    Split a dot-notation path into tokens, honouring back-ticks and brackets.

    Rules:
        * A path starting with "csv::" is returned whole, as a single token.
        * "." separates tokens; empty tokens are dropped.
        * Back-ticks toggle a quoted state in which "." and "[" are ordinary
          characters. The back-ticks themselves are not part of the token.
        * Outside back-ticks, "[" starts a token that runs up to and
          including the next "]" (e.g. "[0]").

    Args:
        path: The path string (e.g. "users.0.name" or "users[0].name")

    Returns:
        List of path tokens

    Raises:
        ValueError: If an unquoted "[" has no closing "]".
    """
    if path.startswith("csv::"):
        return [path]

    tokens: list[str] = []
    pending: list[str] = []  # characters of the token currently being read
    quoted = False

    def flush() -> None:
        # Emit the pending token, if any, and start a new one.
        if pending:
            tokens.append("".join(pending))
            pending.clear()

    cursor = 0
    while cursor < len(path):
        char = path[cursor]

        if char == "[" and not quoted:
            flush()
            closing = path.find("]", cursor)
            if closing == -1:
                raise ValueError(f"Unclosed '[' in path: {path}")
            # The bracket expression is copied verbatim, brackets included.
            tokens.append(path[cursor : closing + 1])
            cursor = closing + 1
            continue

        if char == "`":
            quoted = not quoted
        elif char == "." and not quoted:
            flush()
        else:
            pending.append(char)
        cursor += 1

    flush()
    return tokens
194
+
195
+
196
+ # adapted from: https://github.com/TIGER-AI-Lab/StructEval/blob/main/structeval/eval_engine/eval_utils.py
197
def path_exists(data: Any, path: str) -> bool:
    """
    Check if a path exists in a structured data object.

    The path is tokenized with tokenize_path() and the tokens are consumed one
    by one while descending into ``data``:

        * "csv::<header>" on a dict that has a "csv_headers" key: true if the
          header is listed there and this is the last token.
        * "*": true if the node is a list and any element matches the rest.
        * "[n]": the node must be a list and ``n`` a valid non-negative index.
        * anything else: the node must be a dict containing the token as a
          key; a token "@name" additionally falls back to the key "name".

    Args:
        data: The structured data to check
        path: The path to check (dot notation)

    Returns:
        True if path exists, False otherwise
    """

    def descend(node: Any, remaining: list[str]) -> bool:
        if not remaining:
            return True
        head, tail = remaining[0], remaining[1:]
        node_is_dict = isinstance(node, dict)

        # CSV header rule; the header token has to be the final one.
        if node_is_dict and "csv_headers" in node and head.startswith("csv::"):
            return head[5:] in node["csv_headers"] and not tail

        # Wildcard over list elements
        if head == "*":
            return isinstance(node, list) and any(descend(child, tail) for child in node)

        # Fixed index [n]
        if head.startswith("[") and head.endswith("]"):
            try:
                index = int(head[1:-1])
            except ValueError:
                return False
            return isinstance(node, list) and 0 <= index < len(node) and descend(node[index], tail)

        # Dict key handling (JSON/YAML/TOML/XML)
        if not node_is_dict:
            return False
        # Literal key match (works for "@id" too)
        if head in node:
            return descend(node[head], tail)
        # XML attribute fallback: "@id" -> "id"
        if head.startswith("@") and head[1:] in node:
            return descend(node[head[1:]], tail)
        return False

    return descend(data, tokenize_path(path))
@@ -0,0 +1,67 @@
1
+ import sacrebleu
2
+
3
+ from eval_framework.exceptions import LogicError
4
+ from eval_framework.metrics.base import BaseMetric, MetricResult
5
+ from eval_framework.shared.types import Completion
6
+
7
+
8
class TER(BaseMetric[Completion]):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references. The lowest score over all ground truths is reported.
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
    """

    NAME = "TER"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Score the whole completion against every ground truth and keep the minimum.

        Returns a None value when the response carries an error.

        Raises:
            LogicError: If a ground truth is None or the empty string.
        """
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=False, error=response.error)]

        scores = []
        for ground_truth in response.ground_truth_list:
            if ground_truth is None or ground_truth == "":
                raise LogicError("When calculating TER we need a ground truth.")

            # sacrebleu works on corpora: one hypothesis, one reference stream.
            hypotheses = [response.completion]
            references = [[ground_truth]]
            scores.append(sacrebleu.corpus_ter(hypotheses, references).score)

        best = float(min(scores))
        return [MetricResult(metric_name=self.NAME, value=best, higher_is_better=False, error=response.error)]
35
+
36
+
37
class LINEWISE_TER(BaseMetric[Completion]):
    """Minimum Line-level TER (Translation Edit Rate) score."""

    NAME = "Linewise TER"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Score every non-empty completion line against every ground truth and keep the minimum.

        Returns a None value when the response carries an error, and 100 when
        no score could be computed at all (no ground truths, or no non-empty
        line in the completion).

        Raises:
            LogicError: If a ground truth is None or the empty string. The
                check runs per scored line, so it never fires for a
                completion without non-empty lines.
        """
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=False, error=response.error)]

        scores = []
        for ground_truth in response.ground_truth_list:
            # filter(None, ...) skips the empty lines.
            for line in filter(None, response.completion.split("\n")):
                if ground_truth is None or ground_truth == "":
                    raise LogicError("When calculating TER we need a ground truth.")

                # sacrebleu works on corpora: one hypothesis, one reference stream.
                hypotheses = [line]
                references = [[ground_truth]]
                scores.append(sacrebleu.corpus_ter(hypotheses, references).score)

        best = float(min(scores, default=100))
        return [
            MetricResult(
                metric_name=self.NAME,
                value=best,
                higher_is_better=False,
                error=response.error,
            )
        ]