eval-framework 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +177 -0
- eval_framework/context/eval.py +121 -0
- eval_framework/context/local.py +78 -0
- eval_framework/evaluation_generator.py +234 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +432 -0
- eval_framework/llm/base.py +180 -0
- eval_framework/llm/huggingface.py +418 -0
- eval_framework/llm/mistral.py +88 -0
- eval_framework/llm/models.py +28 -0
- eval_framework/llm/openai.py +400 -0
- eval_framework/llm/vllm.py +554 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +166 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/aidanbench.py +28 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +179 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +34 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/llm/utils.py +20 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/base.py +50 -0
- eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
- eval_framework/metrics/loglikelihood/dcs.py +43 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
- eval_framework/metrics/loglikelihood/ternary.py +42 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +351 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +88 -0
- eval_framework/result_processors/hf_uploader.py +75 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/result_processors/wandb_uploader.py +137 -0
- eval_framework/run.py +369 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +392 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/aidanbench.py +211 -0
- eval_framework/tasks/benchmarks/arc.py +70 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +64 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +133 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +201 -0
- eval_framework/tasks/benchmarks/gsm8k.py +150 -0
- eval_framework/tasks/benchmarks/hellaswag.py +69 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +215 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +85 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +64 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +110 -0
- eval_framework/tasks/benchmarks/sphyr.py +79 -0
- eval_framework/tasks/benchmarks/squad.py +211 -0
- eval_framework/tasks/benchmarks/struct_eval.py +116 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
- eval_framework/tasks/benchmarks/winogender.py +64 -0
- eval_framework/tasks/benchmarks/winogrande.py +69 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +136 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +81 -0
- eval_framework/tasks/task_names.py +324 -0
- eval_framework/tasks/utils.py +584 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/file_ops.py +245 -0
- eval_framework/utils/generate_task_docs.py +244 -0
- eval_framework/utils/helpers.py +32 -0
- eval_framework/utils/logging.py +62 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework/utils/tqdm_handler.py +14 -0
- eval_framework-0.2.7.dist-info/METADATA +548 -0
- eval_framework-0.2.7.dist-info/RECORD +170 -0
- eval_framework-0.2.7.dist-info/WHEEL +4 -0
- eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +537 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
import atexit
|
|
2
|
+
import logging
|
|
3
|
+
import random
|
|
4
|
+
import re
|
|
5
|
+
import signal
|
|
6
|
+
import sqlite3
|
|
7
|
+
import threading
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from time import sleep
|
|
10
|
+
from typing import Any
|
|
11
|
+
from uuid import uuid4
|
|
12
|
+
|
|
13
|
+
import docker
|
|
14
|
+
import mysql.connector
|
|
15
|
+
import mysql.connector.abstracts
|
|
16
|
+
import psycopg2 # type: ignore
|
|
17
|
+
from pydantic import BaseModel
|
|
18
|
+
|
|
19
|
+
from eval_framework.llm.base import BaseLLM
|
|
20
|
+
from eval_framework.metrics.base import MetricResult
|
|
21
|
+
from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
|
|
22
|
+
from eval_framework.metrics.llm.graders.language import Language
|
|
23
|
+
from eval_framework.metrics.llm.graders.sql_quality_grader import SqlQualityGrader
|
|
24
|
+
from eval_framework.shared.types import Completion, LanguageMetricContext, extract_context_metric
|
|
25
|
+
from eval_framework.tasks.utils import get_docker_address
|
|
26
|
+
|
|
27
|
+
# Module-level logger; schema/query validation failures in this module are reported through it.
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SqlDialects(Enum):
    """SQL dialects a task may declare; looked up by value from ``LLMJudgeSqlMetricContext.dialect``."""

    sqlite = "sqlite"
    postgres = "postgresql"
    mysql = "mysql"
    # No dedicated engine: LLMJudgeSql.validate_query runs this dialect on SQLite.
    standard_sql = "standard_sql"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class SqlOutputComparison(BaseModel):
    """Comparison of the completion query's fetched rows against the golden query's fetched rows."""

    matches_results_count: bool  # both queries returned the same number of rows
    matches_column_count: bool  # same column count, judged from the first row (0 when there are no rows)
    results_equal: bool  # the row lists compare equal with ``==`` (so row order matters)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SqlValidationResult(BaseModel):
    """Outcome of running a schema script followed by a query script against one SQL engine."""

    success: bool  # True only when both the schema and the query statements executed and rows were fetched
    schema_error: str | None = None  # message of the exception raised while executing the schema statements
    query_error: str | None = None  # message of the exception raised while executing the query / fetching rows
    results: list[Any] = []  # rows fetched after the last query statement; left empty on failure
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class LLMJudgeSqlMetricContext(LanguageMetricContext):
    """Per-sample context consumed by ``LLMJudgeSql`` (read via ``extract_context_metric``)."""

    dialect: str  # must be one of the ``SqlDialects`` values, e.g. "postgresql"
    db_schema: str  # SQL script executed to build the database (passed as ``create_db_statements``)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Process-wide state shared by every LLMJudgeSql instance.
# Guards the one-time launch of the PostgreSQL / MySQL Docker containers.
_DOCKER_LAUNCH_LOCK = threading.Lock()
# Host ports the containers are published on; 0 means "not launched yet".
_MYSQL_PORT = 0
_POSTGRES_PORT = 0
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class LLMJudgeSql(BaseLLMJudgeMetric):
|
|
61
|
+
NAME = "SQL Quality"
|
|
62
|
+
|
|
63
|
+
def __init__(self, llm_judge: BaseLLM):
|
|
64
|
+
super().__init__(llm_judge)
|
|
65
|
+
self._grader = SqlQualityGrader(llm_judge)
|
|
66
|
+
|
|
67
|
+
self.postgres_password = "mysecretpassword"
|
|
68
|
+
self.postgres_user = "postgres"
|
|
69
|
+
|
|
70
|
+
self.mysql_password = "mysecretpassword"
|
|
71
|
+
self.mysql_user = "root"
|
|
72
|
+
self.mysql_db_name = "mysql"
|
|
73
|
+
|
|
74
|
+
with _DOCKER_LAUNCH_LOCK:
|
|
75
|
+
if _MYSQL_PORT != 0 and _POSTGRES_PORT != 0:
|
|
76
|
+
return
|
|
77
|
+
self.client = docker.from_env()
|
|
78
|
+
atexit.register(self._shutdown_dbs)
|
|
79
|
+
signal.signal(signal.SIGTERM, lambda *_: self._shutdown_dbs())
|
|
80
|
+
self._start_postgres_db()
|
|
81
|
+
self._start_mysql_db()
|
|
82
|
+
self._wait_for_db_containers()
|
|
83
|
+
|
|
84
|
+
def calculate(self, response: Completion) -> list[MetricResult]:
|
|
85
|
+
if response.error is not None:
|
|
86
|
+
return [
|
|
87
|
+
MetricResult(metric_name=f"{self.NAME}/{k}", value=None, higher_is_better=True, error=response.error)
|
|
88
|
+
for k in [
|
|
89
|
+
"successfully_runs",
|
|
90
|
+
"is_just_sql",
|
|
91
|
+
"matches_results_count",
|
|
92
|
+
"matches_column_count",
|
|
93
|
+
"results_equal",
|
|
94
|
+
"llm_quality_score",
|
|
95
|
+
]
|
|
96
|
+
]
|
|
97
|
+
|
|
98
|
+
context = extract_context_metric(response, LLMJudgeSqlMetricContext)
|
|
99
|
+
|
|
100
|
+
assert isinstance(response.ground_truth, str)
|
|
101
|
+
|
|
102
|
+
schema_id = str(uuid4()).replace("-", "_")
|
|
103
|
+
|
|
104
|
+
expected_result = self.validate_query(
|
|
105
|
+
SqlDialects(context.dialect),
|
|
106
|
+
context.db_schema,
|
|
107
|
+
response.ground_truth,
|
|
108
|
+
f"golden_{schema_id}",
|
|
109
|
+
)
|
|
110
|
+
completion_stripped = response.completion.strip().strip("```sql").strip("```")
|
|
111
|
+
completion_query = extract_query_from_completions(completion_stripped)
|
|
112
|
+
if completion_query:
|
|
113
|
+
result = self.validate_query(
|
|
114
|
+
SqlDialects(context.dialect),
|
|
115
|
+
context.db_schema,
|
|
116
|
+
completion_query,
|
|
117
|
+
f"completion_{schema_id}",
|
|
118
|
+
)
|
|
119
|
+
else:
|
|
120
|
+
result = None
|
|
121
|
+
|
|
122
|
+
results = [
|
|
123
|
+
MetricResult(
|
|
124
|
+
metric_name=f"{self.NAME}/successfully_runs",
|
|
125
|
+
value=float(result is not None and result.success),
|
|
126
|
+
higher_is_better=True,
|
|
127
|
+
error=response.error,
|
|
128
|
+
),
|
|
129
|
+
MetricResult(
|
|
130
|
+
metric_name=f"{self.NAME}/is_just_sql",
|
|
131
|
+
value=float(completion_query == completion_stripped),
|
|
132
|
+
higher_is_better=True,
|
|
133
|
+
error=response.error,
|
|
134
|
+
),
|
|
135
|
+
]
|
|
136
|
+
|
|
137
|
+
if result is not None and result.success:
|
|
138
|
+
output_comparison = SqlOutputComparison(
|
|
139
|
+
matches_results_count=len(expected_result.results) == len(result.results),
|
|
140
|
+
matches_column_count=count_result_columns(expected_result.results)
|
|
141
|
+
== count_result_columns(result.results),
|
|
142
|
+
results_equal=expected_result.results == result.results,
|
|
143
|
+
)
|
|
144
|
+
results.extend(
|
|
145
|
+
[
|
|
146
|
+
MetricResult(
|
|
147
|
+
metric_name=f"{self.NAME}/matches_results_count",
|
|
148
|
+
value=float(output_comparison.matches_results_count),
|
|
149
|
+
higher_is_better=True,
|
|
150
|
+
error=response.error,
|
|
151
|
+
),
|
|
152
|
+
MetricResult(
|
|
153
|
+
metric_name=f"{self.NAME}/matches_column_count",
|
|
154
|
+
value=float(output_comparison.matches_column_count),
|
|
155
|
+
higher_is_better=True,
|
|
156
|
+
error=response.error,
|
|
157
|
+
),
|
|
158
|
+
MetricResult(
|
|
159
|
+
metric_name=f"{self.NAME}/results_equal",
|
|
160
|
+
value=float(output_comparison.results_equal),
|
|
161
|
+
higher_is_better=True,
|
|
162
|
+
error=response.error,
|
|
163
|
+
),
|
|
164
|
+
]
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
grading = self._grader.grade(
|
|
168
|
+
prompt=response.user_instruction,
|
|
169
|
+
completion=completion_stripped,
|
|
170
|
+
result=result.results if result and result.success else None,
|
|
171
|
+
language=Language(response.get_instruction_language()),
|
|
172
|
+
)
|
|
173
|
+
results.append(
|
|
174
|
+
MetricResult(
|
|
175
|
+
metric_name=f"{self.NAME}/llm_quality_score",
|
|
176
|
+
# [0, 1] normalization required for visualizer
|
|
177
|
+
value=(float(grading.query_quality) - 1) / 4 if grading.query_quality is not None else None,
|
|
178
|
+
higher_is_better=True,
|
|
179
|
+
llm_judge_prompt=grading.judge_prompt,
|
|
180
|
+
llm_judge_response=grading.judge_response,
|
|
181
|
+
error=response.error,
|
|
182
|
+
)
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
return results
|
|
186
|
+
|
|
187
|
+
def _start_postgres_db(self) -> None:
|
|
188
|
+
global _POSTGRES_PORT
|
|
189
|
+
for _ in range(10): # find a free port
|
|
190
|
+
try:
|
|
191
|
+
_POSTGRES_PORT = random.randint(1000, 65535)
|
|
192
|
+
self.postgres_docker = self.client.containers.run(
|
|
193
|
+
"docker.io/postgres",
|
|
194
|
+
environment={"POSTGRES_PASSWORD": self.postgres_password},
|
|
195
|
+
ports={5432: _POSTGRES_PORT},
|
|
196
|
+
tty=True,
|
|
197
|
+
auto_remove=True,
|
|
198
|
+
detach=True,
|
|
199
|
+
network_mode="bridge",
|
|
200
|
+
)
|
|
201
|
+
break
|
|
202
|
+
except docker.errors.APIError as e:
|
|
203
|
+
if "port is already allocated" not in str(e):
|
|
204
|
+
raise e
|
|
205
|
+
continue
|
|
206
|
+
|
|
207
|
+
def _start_mysql_db(self) -> None:
|
|
208
|
+
global _MYSQL_PORT
|
|
209
|
+
for _ in range(10): # find a free port
|
|
210
|
+
try:
|
|
211
|
+
_MYSQL_PORT = random.randint(1000, 65535)
|
|
212
|
+
self.mysql_docker = self.client.containers.run(
|
|
213
|
+
"docker.io/mysql:latest",
|
|
214
|
+
environment={"MYSQL_ROOT_PASSWORD": self.mysql_password, "MYSQL_DATABASE": self.mysql_db_name},
|
|
215
|
+
ports={3306: _MYSQL_PORT},
|
|
216
|
+
tty=True,
|
|
217
|
+
auto_remove=True,
|
|
218
|
+
detach=True,
|
|
219
|
+
network_mode="bridge",
|
|
220
|
+
)
|
|
221
|
+
break
|
|
222
|
+
except docker.errors.APIError as e:
|
|
223
|
+
if "port is already allocated" not in str(e):
|
|
224
|
+
raise e
|
|
225
|
+
continue
|
|
226
|
+
|
|
227
|
+
def _wait_for_db_containers(self) -> None:
|
|
228
|
+
for _ in range(600):
|
|
229
|
+
try:
|
|
230
|
+
con = self.connect_to_postgres()
|
|
231
|
+
con.close()
|
|
232
|
+
con = self.connect_to_mysql()
|
|
233
|
+
con.close()
|
|
234
|
+
return
|
|
235
|
+
except Exception:
|
|
236
|
+
logger.info("Could not connect to DBs yet...")
|
|
237
|
+
sleep(1)
|
|
238
|
+
raise Exception("DBs not available.")
|
|
239
|
+
|
|
240
|
+
def _shutdown_dbs(self) -> None:
|
|
241
|
+
if hasattr(self, "postgres_docker"):
|
|
242
|
+
self.postgres_docker.kill()
|
|
243
|
+
if hasattr(self, "mysql_docker"):
|
|
244
|
+
self.mysql_docker.kill()
|
|
245
|
+
|
|
246
|
+
def validate_query(
|
|
247
|
+
self,
|
|
248
|
+
dialect: SqlDialects,
|
|
249
|
+
create_db_statements: str,
|
|
250
|
+
sql_query: str,
|
|
251
|
+
db_schema: str,
|
|
252
|
+
) -> SqlValidationResult:
|
|
253
|
+
match dialect:
|
|
254
|
+
case SqlDialects.sqlite | SqlDialects.standard_sql:
|
|
255
|
+
return self.validate_query_sqlite(create_db_statements, sql_query, f"{dialect.value}_{db_schema}")
|
|
256
|
+
case SqlDialects.postgres:
|
|
257
|
+
return self.validate_query_postgres(create_db_statements, sql_query, f"{dialect.value}_{db_schema}")
|
|
258
|
+
case SqlDialects.mysql:
|
|
259
|
+
return self.validate_query_mysql(create_db_statements, sql_query, f"{dialect.value}_{db_schema}")
|
|
260
|
+
case _:
|
|
261
|
+
raise NotImplementedError(f"Query validation not implemented for {dialect.value}.")
|
|
262
|
+
|
|
263
|
+
def validate_query_sqlite(self, create_db_statements: str, sql_query: str, db_schema: str) -> SqlValidationResult:
|
|
264
|
+
con = sqlite3.connect(":memory:")
|
|
265
|
+
cur = con.cursor()
|
|
266
|
+
try:
|
|
267
|
+
statements = separate_statements(create_db_statements)
|
|
268
|
+
for statement in statements:
|
|
269
|
+
cur.execute(statement)
|
|
270
|
+
con.commit()
|
|
271
|
+
except Exception as e:
|
|
272
|
+
logger.info(f"Create statements are not compatible with SQLite. Reason: {e}")
|
|
273
|
+
return SqlValidationResult(success=False, schema_error=str(e))
|
|
274
|
+
try:
|
|
275
|
+
queries = separate_statements(sql_query)
|
|
276
|
+
for query in queries:
|
|
277
|
+
cur.execute(query)
|
|
278
|
+
con.commit()
|
|
279
|
+
results = cur.fetchall()
|
|
280
|
+
except Exception as e:
|
|
281
|
+
logger.info(f"SQL query is not compatible with SQLite. Reason: {e}")
|
|
282
|
+
return SqlValidationResult(success=False, query_error=str(e))
|
|
283
|
+
|
|
284
|
+
con.close()
|
|
285
|
+
return SqlValidationResult(success=True, results=results)
|
|
286
|
+
|
|
287
|
+
def connect_to_postgres(self) -> psycopg2.extensions.connection:
|
|
288
|
+
conn_params = {
|
|
289
|
+
"dbname": "postgres",
|
|
290
|
+
"user": self.postgres_user,
|
|
291
|
+
"password": self.postgres_password,
|
|
292
|
+
"host": get_docker_address(),
|
|
293
|
+
"port": _POSTGRES_PORT,
|
|
294
|
+
}
|
|
295
|
+
return psycopg2.connect(**conn_params)
|
|
296
|
+
|
|
297
|
+
def validate_query_postgres(self, create_db_statements: str, sql_query: str, db_schema: str) -> SqlValidationResult:
|
|
298
|
+
con = self.connect_to_postgres()
|
|
299
|
+
cur = con.cursor()
|
|
300
|
+
cur.execute(f"CREATE SCHEMA {db_schema};")
|
|
301
|
+
con.commit()
|
|
302
|
+
cur.execute(f"ALTER USER {self.postgres_user} set SEARCH_PATH = {db_schema};")
|
|
303
|
+
con.commit()
|
|
304
|
+
try:
|
|
305
|
+
statements = separate_statements(create_db_statements)
|
|
306
|
+
for statement in statements:
|
|
307
|
+
cur.execute(statement)
|
|
308
|
+
con.commit()
|
|
309
|
+
except Exception as e:
|
|
310
|
+
logger.info(f"Create statements are not compatible with PostgreSQL. Reason: {e}")
|
|
311
|
+
return SqlValidationResult(success=False, schema_error=str(e))
|
|
312
|
+
try:
|
|
313
|
+
queries = separate_statements(sql_query)
|
|
314
|
+
for query in queries:
|
|
315
|
+
cur.execute(query)
|
|
316
|
+
con.commit()
|
|
317
|
+
results = cur.fetchall()
|
|
318
|
+
except Exception as e:
|
|
319
|
+
logger.info(f"SQL query is not compatible with PostgreSQL. Reason: {e}")
|
|
320
|
+
return SqlValidationResult(success=False, query_error=str(e))
|
|
321
|
+
|
|
322
|
+
con.commit()
|
|
323
|
+
|
|
324
|
+
con.close()
|
|
325
|
+
return SqlValidationResult(success=True, results=results)
|
|
326
|
+
|
|
327
|
+
def connect_to_mysql(
|
|
328
|
+
self,
|
|
329
|
+
) -> mysql.connector.pooling.PooledMySQLConnection | mysql.connector.abstracts.MySQLConnectionAbstract:
|
|
330
|
+
conn_params = {
|
|
331
|
+
"database": self.mysql_db_name,
|
|
332
|
+
"user": self.mysql_user,
|
|
333
|
+
"password": self.mysql_password,
|
|
334
|
+
"host": get_docker_address(),
|
|
335
|
+
"port": _MYSQL_PORT,
|
|
336
|
+
}
|
|
337
|
+
return mysql.connector.connect(**conn_params)
|
|
338
|
+
|
|
339
|
+
def validate_query_mysql(self, create_db_statements: str, sql_query: str, db_schema: str) -> SqlValidationResult:
|
|
340
|
+
con = self.connect_to_mysql()
|
|
341
|
+
cur = con.cursor(buffered=True)
|
|
342
|
+
cur.execute(f"CREATE SCHEMA {db_schema};")
|
|
343
|
+
con.commit()
|
|
344
|
+
cur.execute(f"USE {db_schema};")
|
|
345
|
+
try:
|
|
346
|
+
statements = separate_statements(create_db_statements)
|
|
347
|
+
for statement in statements:
|
|
348
|
+
cur.execute(statement)
|
|
349
|
+
con.commit()
|
|
350
|
+
except Exception as e:
|
|
351
|
+
logger.info(f"Create statements are not compatible with MySQL. Reason: {e}")
|
|
352
|
+
con.close()
|
|
353
|
+
return SqlValidationResult(success=False, schema_error=str(e))
|
|
354
|
+
try:
|
|
355
|
+
queries = separate_statements(sql_query)
|
|
356
|
+
for query in queries:
|
|
357
|
+
cur.execute(query)
|
|
358
|
+
con.commit()
|
|
359
|
+
results = cur.fetchall()
|
|
360
|
+
except Exception as e:
|
|
361
|
+
logger.info(f"SQL query is not compatible with MySQL. Reason: {e}")
|
|
362
|
+
con.close()
|
|
363
|
+
return SqlValidationResult(success=False, query_error=str(e))
|
|
364
|
+
|
|
365
|
+
cur.close()
|
|
366
|
+
con.close()
|
|
367
|
+
return SqlValidationResult(success=True, results=results)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def separate_statements(statements: str) -> list[str]:
    """Split an SQL script on ``;`` into individual, whitespace-trimmed statements.

    Fragments that contain nothing but whitespace or SQL comments (``-- ...`` / ``/* ... */``) are dropped.
    As a result, a trailing ``;`` or ``;;`` yields no empty statement. Some drivers reject empty
    statements, e.g. psycopg2 with "can't execute an empty query".

    A final statement *without* a trailing semicolon is kept. Previously ``split(";")[:-1]`` discarded it
    silently, so an unterminated golden query executed nothing.

    Note: this is a plain textual split, so a ``;`` inside a string literal also splits.
    """
    fragments = (fragment.strip() for fragment in statements.split(";"))
    return [
        fragment
        for fragment in fragments
        if re.sub(r"--[^\n]*|/\*.*?\*/", "", fragment, flags=re.DOTALL).strip()
    ]
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def is_create_table_statement(statement: str) -> bool:
    """Return True if ``statement`` contains a ``CREATE TABLE`` clause.

    SQL keywords are case-insensitive, so ``create table`` matches as well. Any whitespace, including
    newlines, may separate the two keywords.
    """
    return re.search(r"\bCREATE\s+TABLE\b", statement, flags=re.IGNORECASE) is not None
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def count_result_columns(result: list[Any]) -> int:
    """Number of columns in a fetched result set, judged by its first row; 0 when there are no rows."""
    if not result:
        return 0
    first_row = next(iter(result))
    return len(first_row)
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def extract_query_from_completions(completion: str) -> str | None:
|
|
385
|
+
# Match SQL blocks starting with SELECT or WITH at line start
|
|
386
|
+
# (allowing punctuation/whitespace), ending at first semicolon
|
|
387
|
+
pattern = re.compile(r"(?:^|\n)[^a-zA-Z0-9_]*((?:select|with)\b.*?;)", re.IGNORECASE | re.DOTALL)
|
|
388
|
+
|
|
389
|
+
matches = pattern.findall(completion)
|
|
390
|
+
|
|
391
|
+
# Return the query only if exactly one match is found
|
|
392
|
+
if len(matches) == 1:
|
|
393
|
+
return matches[0].strip()
|
|
394
|
+
return None
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from eval_framework.llm.base import BaseLLM
|
|
2
|
+
from eval_framework.metrics.base import MetricResult
|
|
3
|
+
from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
|
|
4
|
+
from eval_framework.metrics.llm.graders.language import Language
|
|
5
|
+
from eval_framework.metrics.llm.graders.summary_world_knowledge_grader import SummarizationWorldKnowledgeGrader
|
|
6
|
+
from eval_framework.shared.types import Completion
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LLMJudgeWorldKnowledge(BaseLLMJudgeMetric):
    """LLM-judged check of whether a completion contains world knowledge.

    The verdict comes from ``SummarizationWorldKnowledgeGrader``. A single ``World Knowledge`` result is
    emitted, and lower is better (``higher_is_better=False``).
    """

    NAME = "World Knowledge"

    def __init__(self, llm_judge: BaseLLM):
        super().__init__(llm_judge)
        self._grader = SummarizationWorldKnowledgeGrader(llm_judge)

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Grade one completion; the value is ``None`` when the response errored or the judge gave no verdict."""
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=False, error=response.error)]

        instruction_language = Language(response.get_instruction_language())
        verdict = self._grader.grade(
            reference_input=response.user_instruction,
            completion=response.sanitized_completion,
            language=instruction_language,
        )

        flagged = verdict.contains_world_knowledge
        score = None if flagged is None else float(flagged)
        return [
            MetricResult(
                metric_name=self.NAME,
                value=score,
                higher_is_better=False,
                llm_judge_prompt=verdict.judge_prompt,
                llm_judge_response=verdict.judge_response,
                error=response.error,
            )
        ]
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Utility functions for LLM-based metrics."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def order_answers_for_comparison(candidate: str, reference: str, swap: bool) -> tuple[str, str]:
    """Decide which answer is shown as "A" and which as "B" in a pairwise LLM-judge prompt.

    Swapping the presentation order lets callers counter the judge's position bias.

    Args:
        candidate: The completion under evaluation.
        reference: The reference/baseline completion.
        swap: When True, present the reference first (as A) and the candidate second (as B).

    Returns:
        ``(answer_a, answer_b)``.
    """
    answer_a, answer_b = (reference, candidate) if swap else (candidate, reference)
    return answer_a, answer_b
|
|
File without changes
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
2
|
+
from eval_framework.shared.types import Loglikelihood
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class AccuracyLoglikelihood(BaseMetric[Loglikelihood]):
    """1.0 when the candidate with the highest raw loglikelihood is in ``ground_truth_list``, else 0.0."""

    NAME = "Accuracy Loglikelihood"

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        accepted_answers = response.ground_truth_list
        scores = response.loglikelihoods
        top_candidate = max(scores, key=scores.get)  # type: ignore[arg-type]
        is_hit = top_candidate in accepted_answers

        return [
            MetricResult(
                metric_name=self.NAME,
                value=1.0 if is_hit else 0.0,
                higher_is_better=True,
                error=response.error,
            )
        ]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AccuracyNormLoglikelihood(BaseMetric[Loglikelihood]):
    """Accuracy after dividing each candidate's loglikelihood by the candidate's character length.

    The value is 1.0 when the best length-normalised candidate is in ``ground_truth_list``, else 0.0.
    Empty candidates keep their raw score.
    """

    NAME = "Accuracy Normalized Loglikelihood"

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        accepted_answers = response.ground_truth_list

        per_char_scores = {
            text: (score / len(text) if len(text) != 0 else score)
            for text, score in response.loglikelihoods.items()
        }
        top_candidate = max(per_char_scores, key=per_char_scores.get)  # type:ignore

        return [
            MetricResult(
                metric_name=self.NAME,
                value=float(top_candidate in accepted_answers),
                higher_is_better=True,
                error=response.error,
            )
        ]
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.base import BaseMetric
|
|
4
|
+
from eval_framework.shared.types import Loglikelihood
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BaseLoglikelihoodMetric(BaseMetric[Loglikelihood]):
    """Base class for metrics that operate on loglikelihood responses.

    Provides text normalisation for answer matching and optional per-character length normalisation of
    candidate loglikelihoods. It also provides a numerically stable softmax that turns loglikelihoods
    into probabilities.
    """

    def __init__(
        self,
        *,
        len_normalised: bool = True,
    ) -> None:
        # When True, loglikelihoods are divided by the candidate's character length before the softmax.
        self.len_normalised = len_normalised

    def _normalise_text(self, text: str) -> str:
        """Normalise text for comparison: strip surrounding whitespace and lowercase."""
        return text.strip().lower()

    def _length_normalise_loglikelihoods(self, loglikelihoods: dict) -> dict:
        """Return a dict of length-normalised loglikelihoods (divided by key length; empty keys unchanged)."""
        return {k: (v / len(k) if len(k) > 0 else v) for k, v in loglikelihoods.items()}

    def _compute_probabilities(self, loglikelihoods: dict) -> tuple[dict, dict]:
        """Compute probabilities from loglikelihoods, with optional length normalisation.

        Returns:
            ``(scores, probabilities)``. ``scores`` are the (possibly length-normalised) loglikelihoods
            and ``probabilities`` is their softmax.
        """
        if self.len_normalised:
            loglikelihoods = self._length_normalise_loglikelihoods(loglikelihoods)
        return loglikelihoods, self._softmax(loglikelihoods)

    def _gather_ground_truths(self, response: Loglikelihood) -> set[str]:
        """Extract and normalize ground truth completions from a Loglikelihood response."""
        raw = response.ground_truth if isinstance(response.ground_truth, list) else [response.ground_truth]
        return {self._normalise_text(gt) for gt in raw}

    def _softmax(self, log_probs: dict) -> dict:
        """Convert log-likelihoods to probabilities with a max-shifted (overflow-safe) softmax.

        An empty input gives an empty dict. If the maximum is infinite, ``x - max`` would be
        ``inf - inf = nan``. In that case the limiting distribution is returned: the probability mass is
        split equally among the entries equal to the maximum. This covers every entry being ``-inf``, which
        gives a uniform distribution.
        """
        if not log_probs:  # no valid entries
            return {}
        m = max(log_probs.values())
        if math.isinf(m):
            top_count = sum(1 for v in log_probs.values() if v == m)
            return {k: (1.0 / top_count if v == m else 0.0) for k, v in log_probs.items()}
        exp_vals = {k: math.exp(v - m) for k, v in log_probs.items()}
        total = sum(exp_vals.values())
        return {k: ev / total for k, ev in exp_vals.items()}
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from eval_framework.metrics.base import MetricResult
|
|
2
|
+
from eval_framework.metrics.loglikelihood.base import BaseLoglikelihoodMetric
|
|
3
|
+
from eval_framework.shared.types import Loglikelihood
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ConfidenceWeightedAccuracy(BaseLoglikelihoodMetric):
|
|
7
|
+
NAME = "Confidence-weighted Accuracy"
|
|
8
|
+
|
|
9
|
+
def __init__(self, *, len_normalised: bool = True) -> None:
|
|
10
|
+
super().__init__(len_normalised=len_normalised)
|
|
11
|
+
|
|
12
|
+
def calculate(self, response: Loglikelihood) -> list[MetricResult]:
|
|
13
|
+
if response.error is not None:
|
|
14
|
+
return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
|
|
15
|
+
|
|
16
|
+
loglikelihoods, probs = self._compute_probabilities(response.loglikelihoods)
|
|
17
|
+
ground_truths = self._gather_ground_truths(response)
|
|
18
|
+
|
|
19
|
+
best_key = max(loglikelihoods, key=loglikelihoods.get) # type: ignore[arg-type]
|
|
20
|
+
best_key_norm = self._normalise_text(best_key)
|
|
21
|
+
p_c = probs.get(best_key, 0.0)
|
|
22
|
+
|
|
23
|
+
accuracy = p_c if best_key_norm in ground_truths else 0.0
|
|
24
|
+
|
|
25
|
+
return [MetricResult(metric_name=self.NAME, value=accuracy, higher_is_better=True, error=response.error)]
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from eval_framework.metrics.base import MetricResult
|
|
2
|
+
from eval_framework.metrics.loglikelihood.base import BaseLoglikelihoodMetric
|
|
3
|
+
from eval_framework.shared.types import Loglikelihood
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DistributionalCorrectnessScore(BaseLoglikelihoodMetric):
|
|
7
|
+
"""Based on Burns (2025) Measuring Language Model Hallucinations Through Distributional Correctness."""
|
|
8
|
+
|
|
9
|
+
NAME = "Distributional Correctness Score"
|
|
10
|
+
|
|
11
|
+
def __init__(
|
|
12
|
+
self,
|
|
13
|
+
*,
|
|
14
|
+
lc: float = 1.0, # Default reward weight for correct answers
|
|
15
|
+
lw: float = 1.0, # Default penalty weight for wrong answers
|
|
16
|
+
len_normalised: bool = True,
|
|
17
|
+
) -> None:
|
|
18
|
+
super().__init__(len_normalised=len_normalised)
|
|
19
|
+
self._lc = float(lc)
|
|
20
|
+
self._lw = float(lw)
|
|
21
|
+
if not (self._lc >= 0 and self._lw >= 0 and self._lc >= self._lw):
|
|
22
|
+
raise ValueError(f"Invalid DCS loadings: lc={self._lc}, lw={self._lw}. Require lc>=0, lw>=0, and lc>=lw.")
|
|
23
|
+
|
|
24
|
+
def calculate(self, response: Loglikelihood) -> list[MetricResult]:
|
|
25
|
+
if response.error is not None:
|
|
26
|
+
return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
|
|
27
|
+
|
|
28
|
+
loglikelihoods, probs = self._compute_probabilities(response.loglikelihoods)
|
|
29
|
+
ground_truths = self._gather_ground_truths(response)
|
|
30
|
+
|
|
31
|
+
idk_key = self._normalise_text(list(response.loglikelihoods.keys())[-1]) # assumes last key is "IDK" option
|
|
32
|
+
|
|
33
|
+
p_c = sum(p for k, p in probs.items() if self._normalise_text(k) in ground_truths)
|
|
34
|
+
p_idk = probs.get(idk_key, 0.0)
|
|
35
|
+
p_w = sum(
|
|
36
|
+
p
|
|
37
|
+
for k, p in probs.items()
|
|
38
|
+
if (self._normalise_text(k) not in ground_truths and self._normalise_text(k) != idk_key)
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
dcs = (self._lc * p_c - self._lw * p_w) * (1.0 - p_idk)
|
|
42
|
+
|
|
43
|
+
return [MetricResult(metric_name=self.NAME, value=float(dcs), higher_is_better=True, error=response.error)]
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
from eval_framework.metrics.base import BaseMetric, MetricResult
|
|
4
|
+
from eval_framework.shared.types import Loglikelihood
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ProbabilityMass(BaseMetric[Loglikelihood]):
|
|
8
|
+
NAME = "Probability Mass"
|
|
9
|
+
|
|
10
|
+
def calculate(self, response: Loglikelihood) -> list[MetricResult]:
|
|
11
|
+
if response.error is not None:
|
|
12
|
+
return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
|
|
13
|
+
|
|
14
|
+
assert isinstance(response.ground_truth, list)
|
|
15
|
+
# https://docs.python.org/3.10/library/stdtypes.html?highlight=dictview#dictionary-view-objects
|
|
16
|
+
in_ground_truths = [completion in response.ground_truth for completion in response.loglikelihoods]
|
|
17
|
+
log_probs = list(response.loglikelihoods.values())
|
|
18
|
+
|
|
19
|
+
probs = np.exp(log_probs) / np.sum(np.exp(log_probs))
|
|
20
|
+
prob_mass = np.sum(probs[in_ground_truths])
|
|
21
|
+
|
|
22
|
+
return [
|
|
23
|
+
MetricResult(metric_name=self.NAME, value=float(prob_mass), higher_is_better=True, error=response.error)
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ProbabilityMassNorm(BaseMetric[Loglikelihood]):
|
|
28
|
+
NAME = "Probability Mass Normalized"
|
|
29
|
+
|
|
30
|
+
def calculate(self, response: Loglikelihood) -> list[MetricResult]:
|
|
31
|
+
if response.error is not None:
|
|
32
|
+
return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]
|
|
33
|
+
|
|
34
|
+
assert isinstance(response.ground_truth, list)
|
|
35
|
+
# len normalized
|
|
36
|
+
|
|
37
|
+
output_len_normalized = {}
|
|
38
|
+
for k, v in response.loglikelihoods.items():
|
|
39
|
+
completion_length = len(k)
|
|
40
|
+
|
|
41
|
+
if completion_length != 0:
|
|
42
|
+
output_len_normalized[k] = v / completion_length
|
|
43
|
+
else:
|
|
44
|
+
output_len_normalized[k] = v
|
|
45
|
+
|
|
46
|
+
log_probs = list(output_len_normalized.values())
|
|
47
|
+
in_ground_truths = [completion in response.ground_truth for completion in response.loglikelihoods]
|
|
48
|
+
log_probs = list(output_len_normalized.values())
|
|
49
|
+
|
|
50
|
+
probs = np.exp(log_probs) / np.sum(np.exp(log_probs))
|
|
51
|
+
prob_mass_norm = np.sum(probs[in_ground_truths])
|
|
52
|
+
|
|
53
|
+
return [MetricResult(metric_name=self.NAME, value=prob_mass_norm, higher_is_better=True, error=response.error)]
|