eval_framework-0.2.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,102 @@
+ import json
+
+ from pydantic import BaseModel
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion
+
+ SEPARATOR_MAP = {"comma": ",", "semicolon": ";", "space": " ", "tab": "\t"}
+
+
+ class CSVFormatEvaluation(BaseModel):
+     implicit: bool = False
+     has_csv: bool = False
+     is_separator_respected: bool = False
+     is_column_count_respected: bool = False
+
+
+ class CSVFormat(BaseMetric[Completion]):
+     NAME = "CSV Format"
+     KEYS = ["has_csv", "is_separator_respected", "is_column_count_respected"]
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [
+                 MetricResult(metric_name=f"{self.NAME}/{k}", value=None, higher_is_better=True, error=response.error)
+                 for k in self.KEYS
+             ]
+
+         if response.completion == "":
+             return [
+                 MetricResult(metric_name=f"{self.NAME}/{k}", value=0.0, higher_is_better=True, error=response.error)
+                 for k in self.KEYS
+             ]
+
+         grading = evaluate_csv_format(response)
+
+         results = []
+         for key in self.KEYS:
+             result = MetricResult(
+                 metric_name=f"{self.NAME}/{key}",
+                 value=float(getattr(grading, key)),
+                 higher_is_better=True,
+                 error=response.error,
+             )
+             results.append(result)
+         return results
+
+
+ def extract_csv_from_text(text: str, min_rows: int = 2, min_columns: int = 2) -> tuple[list[str] | None, str | None]:
+     lines = text.split("\n")
+     delimiters = set(SEPARATOR_MAP.values())
+     best_delimiter = None
+     csv_lines: list[str] = []
+
+     # Iterate over lines to find potential delimiters and consistent substring counts
+     for i, line in enumerate(lines):
+         for delimiter in delimiters:
+             substrings = line.split(delimiter)
+             if len(substrings) < min_columns:
+                 continue
+
+             current_csv_lines = [line]
+             for j in range(i + 1, len(lines)):
+                 next_line = lines[j]
+                 next_substrings = next_line.split(delimiter)
+                 if len(next_substrings) != len(substrings):
+                     break
+                 current_csv_lines.append(next_line)
+             if len(current_csv_lines) >= min_rows and len(current_csv_lines) > len(csv_lines):
+                 best_delimiter = delimiter
+                 csv_lines = current_csv_lines
+
+     if not csv_lines:
+         return None, None
+
+     return csv_lines, best_delimiter
+
+
+ def evaluate_csv_format(response: Completion) -> CSVFormatEvaluation:
+     expected_output = json.loads(str(response.ground_truth))
+
+     expected_separator_code = expected_output["separator"]
+     csv_lines, separator = extract_csv_from_text(response.completion)
+
+     if not csv_lines:
+         return CSVFormatEvaluation(has_csv=False, implicit=not expected_separator_code)
+
+     csv_format_evaluation = CSVFormatEvaluation(has_csv=True, implicit=not expected_separator_code)
+
+     if not expected_separator_code:
+         csv_format_evaluation.is_separator_respected = separator in SEPARATOR_MAP.values()
+     else:
+         csv_format_evaluation.is_separator_respected = separator == SEPARATOR_MAP.get(expected_separator_code)
+
+     expected_column_count = len(expected_output["columns"])
+     column_counts = [len(line.split(separator)) for line in csv_lines]
+
+     csv_format_evaluation.is_column_count_respected = all(
+         column_count == expected_column_count for column_count in column_counts
+     )
+
+     return csv_format_evaluation
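For reference, a minimal usage sketch of extract_csv_from_text (the import path matches the file listing above; the sample text is illustrative, not from the package's tests):

    from eval_framework.metrics.completion.csv_format import extract_csv_from_text

    text = "Here are the results:\nname;age\nalice;30\nbob;25\nThanks!"
    csv_lines, delimiter = extract_csv_from_text(text)
    # Only the three ";"-separated rows share a consistent column count, so
    # csv_lines == ["name;age", "alice;30", "bob;25"] and delimiter == ";".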
@@ -0,0 +1,49 @@
+ import re
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion, Error
+
+
+ class CWEAccuracy(BaseMetric[Completion]):
+     """Metric for Common Word Extraction tasks"""
+
+     NAME = "CWEAccuracy"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         ground_truths = [gt for gt in response.ground_truth_list if gt is not None]
+         if not ground_truths:
+             return [MetricResult(metric_name=self.NAME, value=0.0, higher_is_better=True, error=response.error)]
+
+         try:
+             # Get model's answer
+             model_answer = response.completion
+
+             # Check if all words in the correct answer are present in the model's answer
+             is_correct = self._is_answer_correct(ground_truths, model_answer)
+
+             return [
+                 MetricResult(
+                     metric_name=self.NAME, value=1.0 if is_correct else 0.0, higher_is_better=True, error=response.error
+                 )
+             ]
+         except Exception as e:
+             error = Error(error_class=e.__class__.__name__, message=str(e), traceback="")
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=error)]
+
+     def _is_answer_correct(self, correct_answer: list[str], model_answer: str) -> bool:
+         """Check if all words in correct_answer are present in model_answer as whole words"""
+         model_answer = model_answer.strip().lower()
+         correct_answer = [correct.strip().lower() for correct in correct_answer]
+
+         # For each word in the correct answer, check if it exists as a whole word in the model answer
+         for word in correct_answer:
+             # Create a regex pattern that matches the word as a whole word
+             # \b represents a word boundary
+             pattern = r"\b" + re.escape(word) + r"\b"
+             if not re.search(pattern, model_answer):
+                 return False
+
+         return True
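A standalone sketch of the whole-word matching used by _is_answer_correct (values are illustrative):

    import re

    # "cat" must occur as a whole word, so it does not match inside "category".
    pattern = r"\b" + re.escape("cat") + r"\b"
    bool(re.search(pattern, "the category is animals"))  # False
    bool(re.search(pattern, "the cat sat on the mat"))   # True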
@@ -0,0 +1,65 @@
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion, Error
+
+
+ class ExponentialSimilarity(BaseMetric[Completion]):
+     NAME = "ExponentialSimilarity"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         ground_truths = [gt for gt in response.ground_truth_list if gt is not None]
+         if not ground_truths:
+             return [MetricResult(metric_name=self.NAME, value=0.0, higher_is_better=True, error=response.error)]
+
+         try:
+             # Try to calculate exponential similarity for each ground truth
+             similarities = []
+             for gt in ground_truths:
+                 try:
+                     gt_float = float(gt)
+                     completion_float = float(response.completion)
+                     similarities.append(calculate_exponential_similarity(gt_float, completion_float))
+                 except (ValueError, TypeError):
+                     # Skip this ground truth if conversion fails
+                     continue
+
+             # If we have any valid similarities, return the max
+             if similarities:
+                 return [
+                     MetricResult(
+                         metric_name=self.NAME, value=max(similarities), higher_is_better=True, error=response.error
+                     )
+                 ]
+             else:
+                 # If all conversions failed, return an error
+                 error = Error(
+                     error_class="ValueError",
+                     message="Could not convert ground truth or completion to float",
+                     traceback="",
+                 )
+                 return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=error)]
+         except Exception as e:
+             error = Error(error_class=e.__class__.__name__, message=str(e), traceback="")
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=error)]
+
+
+ def calculate_exponential_similarity(p_true: float, p_pred: float) -> float:
+     """
+     Compute the exponential similarity (SpaceDigest version) between
+     the gold percentage and predicted value.
+
+     Parameters:
+     - p_true (float): The gold/reference percentage.
+     - p_pred (float): The predicted scalar.
+
+     d = 2 (base of the exponent) and c = 10 (exponent coefficient) are fixed constants here, not parameters.
+
+     Returns:
+     - float: Similarity score between 0 and 1.
+     """
+     d = 2
+     c = 10
+
+     return d ** (-c * abs(p_true / 100 - p_pred / 100))
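A worked example of the formula d ** (-c * |p_true/100 - p_pred/100|) with d = 2 and c = 10 (values are illustrative):

    from eval_framework.metrics.completion.exponential_similarity import calculate_exponential_similarity

    calculate_exponential_similarity(75.0, 75.0)  # 2 ** 0 = 1.0 (exact match)
    calculate_exponential_similarity(75.0, 70.0)  # 2 ** (-10 * 0.05) ≈ 0.707
    calculate_exponential_similarity(75.0, 25.0)  # 2 ** (-10 * 0.50) = 0.03125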
@@ -0,0 +1,42 @@
+ from collections import Counter
+ from typing import Any
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion
+
+
+ class F1(BaseMetric[Completion]):
+     NAME = "F1"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         ground_truths = [gt for gt in response.ground_truth_list if gt is not None]
+         if not ground_truths:
+             return [MetricResult(metric_name=self.NAME, value=0.0, higher_is_better=True, error=response.error)]
+
+         hyp_tokens = response.completion.lower().split()
+         f1_scores = [calculate_f1(gt.lower().split(), hyp_tokens) for gt in ground_truths]
+         max_f1 = max(f1_scores)
+
+         return [MetricResult(metric_name=self.NAME, value=max_f1, higher_is_better=True, error=response.error)]
+
+
+ def calculate_f1(ref_tokens: list[Any], hyp_tokens: list[Any]) -> float:
+     """Calculate F1 score between two texts based on token overlap."""
+     if not ref_tokens and not hyp_tokens:
+         return 1.0
+     if not ref_tokens or not hyp_tokens:
+         return 0.0
+
+     common = Counter(ref_tokens) & Counter(hyp_tokens)
+     num_same = sum(common.values())
+
+     if num_same == 0:
+         return 0.0
+
+     precision = num_same / len(hyp_tokens)
+     recall = num_same / len(ref_tokens)
+
+     return 2 * precision * recall / (precision + recall)
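A worked example of the token-overlap F1 computed by calculate_f1 (sample strings are illustrative):

    from eval_framework.metrics.completion.f1 import calculate_f1

    ref = "the cat sat".split()
    hyp = "the cat is sitting".split()
    calculate_f1(ref, hyp)
    # 2 shared tokens -> precision = 2/4, recall = 2/3,
    # F1 = 2 * (1/2) * (2/3) / (1/2 + 2/3) ≈ 0.571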
@@ -0,0 +1,56 @@
+ import json
+ import re
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion
+
+
+ class CheckJsonFormat(BaseMetric[Completion]):
+     NAME = "JSON Format"
+
+     def _preprocess(self, completion: str) -> str:
+         completion = completion.strip()
+         for prefix in ["```json", "```Json", "```JSON", "```"]:
+             completion = completion.removeprefix(prefix)
+         completion = completion.removesuffix("```")
+         completion = completion.strip()
+         return completion
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         json_text = self._preprocess(response.completion)
+
+         try:
+             json.loads(json_text)
+             is_valid_json = True
+         except ValueError as _:
+             is_valid_json = False
+
+         return [
+             MetricResult(metric_name=self.NAME, value=float(is_valid_json), higher_is_better=True, error=response.error)
+         ]
+
+
+ class CheckPostScriptFormat(BaseMetric[Completion]):
+     """
+     This metric is honestly not that great.
+     In the original IFEval implementation it just checks whether the
+     text contains the string (P.)P.S. or variants thereof such as "p. s.";
+     it doesn't attempt any actual parsing.
+     """
+
+     NAME = "Postscript Format"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         postscript_pattern = r"\s*(P\.S\.|P\.P\.S\.)"
+         postscript = re.findall(postscript_pattern, response.completion, flags=re.MULTILINE)
+         return [
+             MetricResult(
+                 metric_name=self.NAME, value=1.0 if postscript else 0.0, higher_is_better=True, error=response.error
+             )
+         ]
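A quick illustration of how narrow the postscript pattern is; it is case-sensitive and does not match spaced variants such as "p. s." (sample strings are illustrative):

    import re

    postscript_pattern = r"\s*(P\.S\.|P\.P\.S\.)"
    bool(re.findall(postscript_pattern, "Best regards.\nP.S. Bring snacks."))  # True
    bool(re.findall(postscript_pattern, "Best regards.\np.s. bring snacks."))  # False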
@@ -0,0 +1,77 @@
+ import re
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion
+
+
+ class GridDifference(BaseMetric[Completion]):
+     NAME = "grid_difference"
+
+     def count_differences(self, character_list_1: list[str], character_list_2: list[str]) -> int:
+         count = 0
+         for character_1, character_2 in zip(character_list_1, character_list_2):
+             if character_1 != character_2:
+                 count += 1
+         return count
+
+     def calculate_score(
+         self, output_ground_truth_difference_count: int, input_ground_truth_difference_count: int
+     ) -> float:
+         if output_ground_truth_difference_count == 0 and input_ground_truth_difference_count == 0:
+             return 1.0
+         score = 1.0 - (float(output_ground_truth_difference_count) / float(input_ground_truth_difference_count))
+         return score
+
+     def extract_grid_from_prompt(self, prompt: str) -> str:
+         # Extract grid between known markers
+         start_marker = "Below is the input grid with masked regions:"
+         end_marker = "Please output the completed grid"
+
+         # Use regex with DOTALL flag to match across newlines
+         match = re.search(re.escape(start_marker) + r"(.*?)" + re.escape(end_marker), prompt, re.DOTALL)
+
+         if match:
+             grid = match.group(1).strip()
+             return grid
+
+         return ""
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         if response.error is not None:
+             return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
+
+         input_grid = self.extract_grid_from_prompt(prompt=response.last_user_instruction).split()
+         output_grid = response.completion.split()
+
+         assert response.ground_truth_list[0], "Ground truth list is empty or not provided in the response."
+         ground_truth_grid = response.ground_truth_list[0].split()
+
+         input_ground_truth_differences_count = self.count_differences(input_grid, ground_truth_grid)
+         output_ground_truth_differences_count = self.count_differences(output_grid, ground_truth_grid)
+
+         exact_match = True
+         score = 1.0
+         normalized_score = 1.0
+         if output_ground_truth_differences_count != 0:
+             exact_match = False
+             score = self.calculate_score(
+                 output_ground_truth_differences_count,
+                 input_ground_truth_differences_count,
+             )
+             normalized_score = max(score, 0.0)
+
+         return [
+             MetricResult(
+                 metric_name=f"{self.NAME}_exact_match",
+                 value=float(exact_match),
+                 higher_is_better=True,
+                 error=response.error,
+             ),
+             MetricResult(metric_name=f"{self.NAME}_score", value=score, higher_is_better=True, error=response.error),
+             MetricResult(
+                 metric_name=f"{self.NAME}_normalized_score",
+                 value=normalized_score,
+                 higher_is_better=True,
+                 error=response.error,
+             ),
+         ]
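A worked example of the GridDifference scoring (the counts below are illustrative):

    # Suppose the masked input grid differs from the ground truth in 4 cells
    # and the model's completion still differs from the ground truth in 1 cell.
    input_ground_truth_difference_count = 4
    output_ground_truth_difference_count = 1

    score = 1.0 - (1 / 4)               # 0.75
    normalized_score = max(score, 0.0)  # 0.75; a negative score would be clamped to 0.0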
@@ -0,0 +1,73 @@
+ from typing import Any
+
+ from eval_framework.external.ifeval_impl.utils import process_results
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import BaseMetricContext, Completion, extract_context_metric
+
+
+ class IFEvalMetricContext(BaseMetricContext):
+     key: int
+     instruction_id_list: list[str]
+     prompt: str
+     additional_kwargs: list[dict[str, Any]]
+
+
+ class IFEvalMetric(BaseMetric[Completion]):
+     NAME = "IFEval"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         context = extract_context_metric(response, IFEvalMetricContext)
+
+         if response.error is not None:
+             return [
+                 MetricResult(
+                     metric_name=f"{self.NAME}/prompt_level_strict_acc",
+                     value=None,
+                     higher_is_better=True,
+                     error=response.error,
+                 ),
+                 MetricResult(
+                     metric_name=f"{self.NAME}/prompt_level_loose_acc",
+                     value=None,
+                     higher_is_better=True,
+                     error=response.error,
+                 ),
+             ]
+
+         grading = process_results(context, [response.completion])
+
+         results = [
+             MetricResult(
+                 metric_name=f"{self.NAME}/prompt_level_strict_acc",
+                 value=float(grading["prompt_level_strict_acc"]),
+                 higher_is_better=True,
+                 error=response.error,
+             ),
+             MetricResult(
+                 metric_name=f"{self.NAME}/prompt_level_loose_acc",
+                 value=float(grading["prompt_level_loose_acc"]),
+                 higher_is_better=True,
+                 error=response.error,
+             ),
+         ]
+         # this framework does not support a custom aggregation step (see agg_inst_level_acc()) so work around
+         # by returning the result for each instruction as a separate MetricResult
+         results += [
+             MetricResult(
+                 metric_name=f"{self.NAME}/inst_level_strict_acc",
+                 value=float(v),
+                 higher_is_better=True,
+                 error=response.error,
+             )
+             for v in grading["inst_level_strict_acc"]
+         ]
+         results += [
+             MetricResult(
+                 metric_name=f"{self.NAME}/inst_level_loose_acc",
+                 value=float(v),
+                 higher_is_better=True,
+                 error=response.error,
+             )
+             for v in grading["inst_level_loose_acc"]
+         ]
+         return results
@@ -0,0 +1,179 @@
+ import json
+ from collections.abc import Mapping
+ from typing import Any
+
+ import jsonschema  # type: ignore
+ from pydantic import BaseModel
+
+ from eval_framework.metrics.base import BaseMetric, MetricResult
+ from eval_framework.shared.types import Completion
+
+
+ class JsonFormatEvaluation(BaseModel):
+     is_just_json: bool = False
+     is_valid_json: bool = False
+     fulfills_schema: bool | None = None
+     exact_match: bool | None = None
+     json_parsing_error: str | None = None
+     schema_validation_error: str | None = None
+
+
+ class JsonFormat(BaseMetric[Completion]):
+     NAME = "JSON Format"
+
+     def calculate(self, response: Completion) -> list[MetricResult]:
+         keys = [
+             "is_just_json",
+             "is_valid_json",
+             "fulfills_schema",
+             "exact_match",
+         ]
+
+         if response.error is not None:
+             return [
+                 MetricResult(metric_name=f"{self.NAME}/{k}", value=None, higher_is_better=True, error=response.error)
+                 for k in keys
+             ]
+
+         if response.completion == "":
+             return [
+                 MetricResult(metric_name=f"{self.NAME}/{k}", value=0.0, higher_is_better=True, error=response.error)
+                 for k in keys
+             ]
+
+         json_dict, grading = self._extract_and_parse_json(response.completion)
+
+         ground_truth_dict = json.loads(str(response.ground_truth))
+         schema = ground_truth_dict["json_schema"]
+         expected_object = ground_truth_dict.get("expected_output", None)
+
+         if schema and json_dict is None:
+             grading.fulfills_schema = False
+         if schema and json_dict is not None:
+             grading = self._validate_json_against_schema(json_dict, schema, grading)
+         if expected_object is not None and json_dict is not None:
+             grading.exact_match = json_dict == expected_object
+
+         results = []
+         for key in keys:
+             result = MetricResult(
+                 metric_name=f"{self.NAME}/{key}",
+                 value=float(getattr(grading, key)) if getattr(grading, key) is not None else None,
+                 higher_is_better=True,
+                 error=response.error,
+                 code_execution_trace=(grading.json_parsing_error or "") + (grading.schema_validation_error or ""),
+             )
+             results.append(result)
+         return results
+
+     @staticmethod
+     def _validate_json_against_schema(
+         json_obj: object, schema: Mapping[str, Any], evaluation_result: JsonFormatEvaluation
+     ) -> JsonFormatEvaluation:
+         evaluation_result = evaluation_result.model_copy(deep=True)
+         try:
+             jsonschema.validate(json_obj, schema)
+             evaluation_result.fulfills_schema = True
+         except jsonschema.exceptions.ValidationError as e:
+             evaluation_result.fulfills_schema = False
+             evaluation_result.schema_validation_error = type(e).__name__
+         except jsonschema.exceptions.SchemaError as e:
+             evaluation_result.schema_validation_error = type(e).__name__
+         return evaluation_result
+
+     @staticmethod
+     def _extract_and_parse_json(completion: str) -> tuple[object, JsonFormatEvaluation]:
+         evaluation_result = JsonFormatEvaluation()
+         json_dict = None
+         try:
+             json_dict = json.loads(remove_comments(completion.strip("`")))
+             evaluation_result.is_just_json = True
+             evaluation_result.is_valid_json = True
+         except Exception as _:
+             try:
+                 json_string = remove_comments(get_json_object(completion))
+                 json_dict = json.loads(json_string)
+                 evaluation_result.is_valid_json = True
+             except Exception as e:
+                 evaluation_result.json_parsing_error = type(e).__name__
+         return json_dict, evaluation_result
+
+
+ def get_json_object(text: str) -> str:
+     """
+     Extract the first valid JSON object or array from text.
+
+     This function handles nested brackets properly by using a bracket counting
+     approach to find complete JSON structures, rather than using regex which
+     can incorrectly match outer brackets containing non-JSON content.
+     """
+
+     def find_json_at_position(text: str, start_pos: int, open_char: str, close_char: str) -> str | None:
+         """Find a complete JSON object/array starting at the given position."""
+         if start_pos >= len(text) or text[start_pos] != open_char:
+             return None
+
+         bracket_count = 0
+         in_string = False
+         escaped = False
+
+         for i in range(start_pos, len(text)):
+             char = text[i]
+
+             if escaped:
+                 escaped = False
+                 continue
+
+             if char == "\\" and in_string:
+                 escaped = True
+                 continue
+
+             if char == '"' and not escaped:
+                 in_string = not in_string
+                 continue
+
+             if not in_string:
+                 if char == open_char:
+                     bracket_count += 1
+                 elif char == close_char:
+                     bracket_count -= 1
+                     if bracket_count == 0:
+                         # Found complete JSON structure
+                         candidate = text[start_pos : i + 1]
+                         # Test if it's valid JSON
+                         try:
+                             json.loads(candidate)
+                             return candidate
+                         except json.JSONDecodeError:
+                             return None
+
+         return None
+
+     # Look for JSON objects {} and arrays []
+     json_candidates = []
+
+     # Search for objects starting with {
+     for i in range(len(text)):
+         if text[i] == "{":
+             candidate = find_json_at_position(text, i, "{", "}")
+             if candidate:
+                 json_candidates.append(candidate)
+
+     # Search for arrays starting with [
+     for i in range(len(text)):
+         if text[i] == "[":
+             candidate = find_json_at_position(text, i, "[", "]")
+             if candidate:
+                 json_candidates.append(candidate)
+
+     if not json_candidates:
+         raise RuntimeError(f"No valid JSON object found in {text}.")
+
+     # Return the longest valid JSON (most likely to be the main content)
+     return max(json_candidates, key=len)
+
+
+ def remove_comments(text: str, comment_indicator: str = "//") -> str:
+     lines = text.splitlines()
+     lines = [line.split(comment_indicator)[0] for line in lines]
+     return "\n".join([line for line in lines if line.strip()])
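A minimal usage sketch for the two helpers above (the import path matches the file listing; sample strings are illustrative):

    from eval_framework.metrics.completion.json_format import get_json_object, remove_comments

    text = 'Sure, here is the result: {"name": "Ada", "scores": [1, 2, 3]} Hope this helps!'
    get_json_object(text)
    # '{"name": "Ada", "scores": [1, 2, 3]}' (the longest valid JSON structure found in the text)

    remove_comments('{\n  "a": 1  // inline comment\n}')
    # '{\n  "a": 1  \n}'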