eval-framework 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. eval_framework/__init__.py +7 -0
  2. eval_framework/base_config.py +36 -0
  3. eval_framework/context/__init__.py +0 -0
  4. eval_framework/context/determined.py +177 -0
  5. eval_framework/context/eval.py +121 -0
  6. eval_framework/context/local.py +78 -0
  7. eval_framework/evaluation_generator.py +234 -0
  8. eval_framework/exceptions.py +2 -0
  9. eval_framework/external/ifeval_impl/README.md +5 -0
  10. eval_framework/external/ifeval_impl/instructions.py +1523 -0
  11. eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
  12. eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
  13. eval_framework/external/ifeval_impl/utils.py +135 -0
  14. eval_framework/llm/__init__.py +0 -0
  15. eval_framework/llm/aleph_alpha.py +432 -0
  16. eval_framework/llm/base.py +180 -0
  17. eval_framework/llm/huggingface.py +418 -0
  18. eval_framework/llm/mistral.py +88 -0
  19. eval_framework/llm/models.py +28 -0
  20. eval_framework/llm/openai.py +400 -0
  21. eval_framework/llm/vllm.py +554 -0
  22. eval_framework/logger.py +3 -0
  23. eval_framework/main.py +166 -0
  24. eval_framework/metrics/__init__.py +0 -0
  25. eval_framework/metrics/base.py +40 -0
  26. eval_framework/metrics/completion/__init__.py +1 -0
  27. eval_framework/metrics/completion/accuracy_completion.py +16 -0
  28. eval_framework/metrics/completion/aidanbench.py +28 -0
  29. eval_framework/metrics/completion/bleu.py +76 -0
  30. eval_framework/metrics/completion/chrf.py +62 -0
  31. eval_framework/metrics/completion/code_assertion.py +44 -0
  32. eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
  33. eval_framework/metrics/completion/comet.py +56 -0
  34. eval_framework/metrics/completion/concordance_index.py +38 -0
  35. eval_framework/metrics/completion/csv_format.py +102 -0
  36. eval_framework/metrics/completion/cwe_accuracy.py +49 -0
  37. eval_framework/metrics/completion/exponential_similarity.py +65 -0
  38. eval_framework/metrics/completion/f1.py +42 -0
  39. eval_framework/metrics/completion/format_checker.py +56 -0
  40. eval_framework/metrics/completion/grid_difference.py +77 -0
  41. eval_framework/metrics/completion/ifeval.py +73 -0
  42. eval_framework/metrics/completion/json_format.py +179 -0
  43. eval_framework/metrics/completion/language_checker.py +74 -0
  44. eval_framework/metrics/completion/length_control.py +83 -0
  45. eval_framework/metrics/completion/math_reasoning_completion.py +307 -0
  46. eval_framework/metrics/completion/niah_accuracy.py +163 -0
  47. eval_framework/metrics/completion/placeholder_checker.py +27 -0
  48. eval_framework/metrics/completion/repetition.py +88 -0
  49. eval_framework/metrics/completion/rouge_1.py +35 -0
  50. eval_framework/metrics/completion/rouge_2.py +45 -0
  51. eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
  52. eval_framework/metrics/completion/rouge_l.py +52 -0
  53. eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
  54. eval_framework/metrics/completion/ter.py +67 -0
  55. eval_framework/metrics/completion/text_counter.py +182 -0
  56. eval_framework/metrics/efficiency/__init__.py +0 -0
  57. eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
  58. eval_framework/metrics/llm/__init__.py +0 -0
  59. eval_framework/metrics/llm/base.py +34 -0
  60. eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
  61. eval_framework/metrics/llm/graders/coherence_grader.py +115 -0
  62. eval_framework/metrics/llm/graders/comparison_grader.py +198 -0
  63. eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
  64. eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
  65. eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
  66. eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
  67. eval_framework/metrics/llm/graders/language.py +56 -0
  68. eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
  69. eval_framework/metrics/llm/graders/models.py +74 -0
  70. eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
  71. eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
  72. eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
  73. eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
  74. eval_framework/metrics/llm/llm_judge_coherence.py +44 -0
  75. eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
  76. eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
  77. eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
  78. eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
  79. eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
  80. eval_framework/metrics/llm/llm_judge_mtbench_pair.py +306 -0
  81. eval_framework/metrics/llm/llm_judge_mtbench_single.py +210 -0
  82. eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
  83. eval_framework/metrics/llm/llm_judge_sql.py +394 -0
  84. eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
  85. eval_framework/metrics/llm/utils.py +20 -0
  86. eval_framework/metrics/loglikelihood/__init__.py +0 -0
  87. eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
  88. eval_framework/metrics/loglikelihood/base.py +50 -0
  89. eval_framework/metrics/loglikelihood/confidence_weighted_accuracy.py +25 -0
  90. eval_framework/metrics/loglikelihood/dcs.py +43 -0
  91. eval_framework/metrics/loglikelihood/probability_mass.py +53 -0
  92. eval_framework/metrics/loglikelihood/ternary.py +42 -0
  93. eval_framework/py.typed +0 -0
  94. eval_framework/response_generator.py +351 -0
  95. eval_framework/result_processors/__init__.py +0 -0
  96. eval_framework/result_processors/base.py +88 -0
  97. eval_framework/result_processors/hf_uploader.py +75 -0
  98. eval_framework/result_processors/result_processor.py +129 -0
  99. eval_framework/result_processors/wandb_uploader.py +137 -0
  100. eval_framework/run.py +369 -0
  101. eval_framework/run_direct.py +42 -0
  102. eval_framework/shared/types.py +227 -0
  103. eval_framework/tasks/__init__.py +6 -0
  104. eval_framework/tasks/base.py +392 -0
  105. eval_framework/tasks/benchmarks/__init__.py +0 -0
  106. eval_framework/tasks/benchmarks/aidanbench.py +211 -0
  107. eval_framework/tasks/benchmarks/arc.py +70 -0
  108. eval_framework/tasks/benchmarks/arc_de.py +46 -0
  109. eval_framework/tasks/benchmarks/arc_fi.py +46 -0
  110. eval_framework/tasks/benchmarks/belebele.py +60 -0
  111. eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
  112. eval_framework/tasks/benchmarks/casehold.py +47 -0
  113. eval_framework/tasks/benchmarks/chembench.py +85 -0
  114. eval_framework/tasks/benchmarks/copa.py +64 -0
  115. eval_framework/tasks/benchmarks/duc.py +91 -0
  116. eval_framework/tasks/benchmarks/flores200.py +133 -0
  117. eval_framework/tasks/benchmarks/flores_plus.py +84 -0
  118. eval_framework/tasks/benchmarks/gpqa.py +201 -0
  119. eval_framework/tasks/benchmarks/gsm8k.py +150 -0
  120. eval_framework/tasks/benchmarks/hellaswag.py +69 -0
  121. eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
  122. eval_framework/tasks/benchmarks/humaneval.py +97 -0
  123. eval_framework/tasks/benchmarks/ifeval.py +78 -0
  124. eval_framework/tasks/benchmarks/include.py +119 -0
  125. eval_framework/tasks/benchmarks/infinitebench.py +302 -0
  126. eval_framework/tasks/benchmarks/math_reasoning.py +580 -0
  127. eval_framework/tasks/benchmarks/mbpp.py +192 -0
  128. eval_framework/tasks/benchmarks/mmlu.py +215 -0
  129. eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
  130. eval_framework/tasks/benchmarks/mmlu_pro.py +164 -0
  131. eval_framework/tasks/benchmarks/mmmlu.py +529 -0
  132. eval_framework/tasks/benchmarks/openbookqa.py +85 -0
  133. eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
  134. eval_framework/tasks/benchmarks/pawsx.py +65 -0
  135. eval_framework/tasks/benchmarks/piqa.py +64 -0
  136. eval_framework/tasks/benchmarks/quality.py +56 -0
  137. eval_framework/tasks/benchmarks/sciq.py +110 -0
  138. eval_framework/tasks/benchmarks/sphyr.py +79 -0
  139. eval_framework/tasks/benchmarks/squad.py +211 -0
  140. eval_framework/tasks/benchmarks/struct_eval.py +116 -0
  141. eval_framework/tasks/benchmarks/tablebench.py +117 -0
  142. eval_framework/tasks/benchmarks/triviaqa.py +42 -0
  143. eval_framework/tasks/benchmarks/truthfulqa.py +119 -0
  144. eval_framework/tasks/benchmarks/winogender.py +64 -0
  145. eval_framework/tasks/benchmarks/winogrande.py +69 -0
  146. eval_framework/tasks/benchmarks/winox.py +57 -0
  147. eval_framework/tasks/benchmarks/wmt.py +160 -0
  148. eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
  149. eval_framework/tasks/eval_config.py +136 -0
  150. eval_framework/tasks/perturbation.py +83 -0
  151. eval_framework/tasks/registry.py +186 -0
  152. eval_framework/tasks/task_loader.py +81 -0
  153. eval_framework/tasks/task_names.py +324 -0
  154. eval_framework/tasks/utils.py +584 -0
  155. eval_framework/utils/constants.py +9 -0
  156. eval_framework/utils/file_ops.py +245 -0
  157. eval_framework/utils/generate_task_docs.py +244 -0
  158. eval_framework/utils/helpers.py +32 -0
  159. eval_framework/utils/logging.py +62 -0
  160. eval_framework/utils/packaging.py +52 -0
  161. eval_framework/utils/tqdm_handler.py +14 -0
  162. eval_framework-0.2.7.dist-info/METADATA +548 -0
  163. eval_framework-0.2.7.dist-info/RECORD +170 -0
  164. eval_framework-0.2.7.dist-info/WHEEL +4 -0
  165. eval_framework-0.2.7.dist-info/entry_points.txt +3 -0
  166. template_formatting/README.md +83 -0
  167. template_formatting/__init__.py +0 -0
  168. template_formatting/formatter.py +537 -0
  169. template_formatting/mistral_formatter.py +159 -0
  170. template_formatting/py.typed +0 -0
@@ -0,0 +1,394 @@
1
+ import atexit
2
+ import logging
3
+ import random
4
+ import re
5
+ import signal
6
+ import sqlite3
7
+ import threading
8
+ from enum import Enum
9
+ from time import sleep
10
+ from typing import Any
11
+ from uuid import uuid4
12
+
13
+ import docker
14
+ import mysql.connector
15
+ import mysql.connector.abstracts
16
+ import psycopg2 # type: ignore
17
+ from pydantic import BaseModel
18
+
19
+ from eval_framework.llm.base import BaseLLM
20
+ from eval_framework.metrics.base import MetricResult
21
+ from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
22
+ from eval_framework.metrics.llm.graders.language import Language
23
+ from eval_framework.metrics.llm.graders.sql_quality_grader import SqlQualityGrader
24
+ from eval_framework.shared.types import Completion, LanguageMetricContext, extract_context_metric
25
+ from eval_framework.tasks.utils import get_docker_address
26
+
27
# Logger named after this module; the SQL validators below report failures through it.
logger: logging.Logger = logging.getLogger(__name__)
28
+
29
+
30
class SqlDialects(Enum):
    """SQL dialects a text-to-SQL sample can target.

    Each member's value is the string expected in the metric context's ``dialect`` field.
    """

    sqlite = "sqlite"
    postgres = "postgresql"
    mysql = "mysql"
    # Generic SQL; validated with SQLite.
    standard_sql = "standard_sql"
35
+
36
+
37
class SqlOutputComparison(BaseModel):
    """How the completion query's result set compares with the ground-truth query's."""

    # Both queries returned the same number of rows.
    matches_results_count: bool
    # Both result sets have the same number of columns (measured on the first row).
    matches_column_count: bool
    # The fetched rows are identical, including their order.
    results_equal: bool
41
+
42
+
43
class SqlValidationResult(BaseModel):
    """Outcome of running schema-creation statements followed by a query."""

    # True only when both the schema statements and the query executed without error.
    success: bool
    # Error message from the schema-creation step, if that step failed.
    schema_error: str | None = None
    # Error message from the query step, if that step failed.
    query_error: str | None = None
    # Rows fetched after the last query statement; left at the default when validation failed.
    results: list[Any] = []
48
+
49
+
50
class LLMJudgeSqlMetricContext(LanguageMetricContext):
    """Per-sample metric context consumed by ``LLMJudgeSql``."""

    # Dialect name; converted with ``SqlDialects(...)``, so it must be one of its values.
    dialect: str
    # Semicolon-separated statements that create the tables the queries run against.
    db_schema: str
53
+
54
+
55
# Serialises the one-time launch of the shared PostgreSQL/MySQL containers across instances.
_DOCKER_LAUNCH_LOCK = threading.Lock()
# Host ports the shared containers are published on. While either is 0, the next
# LLMJudgeSql instance (re)launches the containers.
_MYSQL_PORT: int = 0
_POSTGRES_PORT: int = 0
58
+
59
+
60
+ class LLMJudgeSql(BaseLLMJudgeMetric):
61
+ NAME = "SQL Quality"
62
+
63
+ def __init__(self, llm_judge: BaseLLM):
64
+ super().__init__(llm_judge)
65
+ self._grader = SqlQualityGrader(llm_judge)
66
+
67
+ self.postgres_password = "mysecretpassword"
68
+ self.postgres_user = "postgres"
69
+
70
+ self.mysql_password = "mysecretpassword"
71
+ self.mysql_user = "root"
72
+ self.mysql_db_name = "mysql"
73
+
74
+ with _DOCKER_LAUNCH_LOCK:
75
+ if _MYSQL_PORT != 0 and _POSTGRES_PORT != 0:
76
+ return
77
+ self.client = docker.from_env()
78
+ atexit.register(self._shutdown_dbs)
79
+ signal.signal(signal.SIGTERM, lambda *_: self._shutdown_dbs())
80
+ self._start_postgres_db()
81
+ self._start_mysql_db()
82
+ self._wait_for_db_containers()
83
+
84
+ def calculate(self, response: Completion) -> list[MetricResult]:
85
+ if response.error is not None:
86
+ return [
87
+ MetricResult(metric_name=f"{self.NAME}/{k}", value=None, higher_is_better=True, error=response.error)
88
+ for k in [
89
+ "successfully_runs",
90
+ "is_just_sql",
91
+ "matches_results_count",
92
+ "matches_column_count",
93
+ "results_equal",
94
+ "llm_quality_score",
95
+ ]
96
+ ]
97
+
98
+ context = extract_context_metric(response, LLMJudgeSqlMetricContext)
99
+
100
+ assert isinstance(response.ground_truth, str)
101
+
102
+ schema_id = str(uuid4()).replace("-", "_")
103
+
104
+ expected_result = self.validate_query(
105
+ SqlDialects(context.dialect),
106
+ context.db_schema,
107
+ response.ground_truth,
108
+ f"golden_{schema_id}",
109
+ )
110
+ completion_stripped = response.completion.strip().strip("```sql").strip("```")
111
+ completion_query = extract_query_from_completions(completion_stripped)
112
+ if completion_query:
113
+ result = self.validate_query(
114
+ SqlDialects(context.dialect),
115
+ context.db_schema,
116
+ completion_query,
117
+ f"completion_{schema_id}",
118
+ )
119
+ else:
120
+ result = None
121
+
122
+ results = [
123
+ MetricResult(
124
+ metric_name=f"{self.NAME}/successfully_runs",
125
+ value=float(result is not None and result.success),
126
+ higher_is_better=True,
127
+ error=response.error,
128
+ ),
129
+ MetricResult(
130
+ metric_name=f"{self.NAME}/is_just_sql",
131
+ value=float(completion_query == completion_stripped),
132
+ higher_is_better=True,
133
+ error=response.error,
134
+ ),
135
+ ]
136
+
137
+ if result is not None and result.success:
138
+ output_comparison = SqlOutputComparison(
139
+ matches_results_count=len(expected_result.results) == len(result.results),
140
+ matches_column_count=count_result_columns(expected_result.results)
141
+ == count_result_columns(result.results),
142
+ results_equal=expected_result.results == result.results,
143
+ )
144
+ results.extend(
145
+ [
146
+ MetricResult(
147
+ metric_name=f"{self.NAME}/matches_results_count",
148
+ value=float(output_comparison.matches_results_count),
149
+ higher_is_better=True,
150
+ error=response.error,
151
+ ),
152
+ MetricResult(
153
+ metric_name=f"{self.NAME}/matches_column_count",
154
+ value=float(output_comparison.matches_column_count),
155
+ higher_is_better=True,
156
+ error=response.error,
157
+ ),
158
+ MetricResult(
159
+ metric_name=f"{self.NAME}/results_equal",
160
+ value=float(output_comparison.results_equal),
161
+ higher_is_better=True,
162
+ error=response.error,
163
+ ),
164
+ ]
165
+ )
166
+
167
+ grading = self._grader.grade(
168
+ prompt=response.user_instruction,
169
+ completion=completion_stripped,
170
+ result=result.results if result and result.success else None,
171
+ language=Language(response.get_instruction_language()),
172
+ )
173
+ results.append(
174
+ MetricResult(
175
+ metric_name=f"{self.NAME}/llm_quality_score",
176
+ # [0, 1] normalization required for visualizer
177
+ value=(float(grading.query_quality) - 1) / 4 if grading.query_quality is not None else None,
178
+ higher_is_better=True,
179
+ llm_judge_prompt=grading.judge_prompt,
180
+ llm_judge_response=grading.judge_response,
181
+ error=response.error,
182
+ )
183
+ )
184
+
185
+ return results
186
+
187
+ def _start_postgres_db(self) -> None:
188
+ global _POSTGRES_PORT
189
+ for _ in range(10): # find a free port
190
+ try:
191
+ _POSTGRES_PORT = random.randint(1000, 65535)
192
+ self.postgres_docker = self.client.containers.run(
193
+ "docker.io/postgres",
194
+ environment={"POSTGRES_PASSWORD": self.postgres_password},
195
+ ports={5432: _POSTGRES_PORT},
196
+ tty=True,
197
+ auto_remove=True,
198
+ detach=True,
199
+ network_mode="bridge",
200
+ )
201
+ break
202
+ except docker.errors.APIError as e:
203
+ if "port is already allocated" not in str(e):
204
+ raise e
205
+ continue
206
+
207
+ def _start_mysql_db(self) -> None:
208
+ global _MYSQL_PORT
209
+ for _ in range(10): # find a free port
210
+ try:
211
+ _MYSQL_PORT = random.randint(1000, 65535)
212
+ self.mysql_docker = self.client.containers.run(
213
+ "docker.io/mysql:latest",
214
+ environment={"MYSQL_ROOT_PASSWORD": self.mysql_password, "MYSQL_DATABASE": self.mysql_db_name},
215
+ ports={3306: _MYSQL_PORT},
216
+ tty=True,
217
+ auto_remove=True,
218
+ detach=True,
219
+ network_mode="bridge",
220
+ )
221
+ break
222
+ except docker.errors.APIError as e:
223
+ if "port is already allocated" not in str(e):
224
+ raise e
225
+ continue
226
+
227
+ def _wait_for_db_containers(self) -> None:
228
+ for _ in range(600):
229
+ try:
230
+ con = self.connect_to_postgres()
231
+ con.close()
232
+ con = self.connect_to_mysql()
233
+ con.close()
234
+ return
235
+ except Exception:
236
+ logger.info("Could not connect to DBs yet...")
237
+ sleep(1)
238
+ raise Exception("DBs not available.")
239
+
240
+ def _shutdown_dbs(self) -> None:
241
+ if hasattr(self, "postgres_docker"):
242
+ self.postgres_docker.kill()
243
+ if hasattr(self, "mysql_docker"):
244
+ self.mysql_docker.kill()
245
+
246
+ def validate_query(
247
+ self,
248
+ dialect: SqlDialects,
249
+ create_db_statements: str,
250
+ sql_query: str,
251
+ db_schema: str,
252
+ ) -> SqlValidationResult:
253
+ match dialect:
254
+ case SqlDialects.sqlite | SqlDialects.standard_sql:
255
+ return self.validate_query_sqlite(create_db_statements, sql_query, f"{dialect.value}_{db_schema}")
256
+ case SqlDialects.postgres:
257
+ return self.validate_query_postgres(create_db_statements, sql_query, f"{dialect.value}_{db_schema}")
258
+ case SqlDialects.mysql:
259
+ return self.validate_query_mysql(create_db_statements, sql_query, f"{dialect.value}_{db_schema}")
260
+ case _:
261
+ raise NotImplementedError(f"Query validation not implemented for {dialect.value}.")
262
+
263
+ def validate_query_sqlite(self, create_db_statements: str, sql_query: str, db_schema: str) -> SqlValidationResult:
264
+ con = sqlite3.connect(":memory:")
265
+ cur = con.cursor()
266
+ try:
267
+ statements = separate_statements(create_db_statements)
268
+ for statement in statements:
269
+ cur.execute(statement)
270
+ con.commit()
271
+ except Exception as e:
272
+ logger.info(f"Create statements are not compatible with SQLite. Reason: {e}")
273
+ return SqlValidationResult(success=False, schema_error=str(e))
274
+ try:
275
+ queries = separate_statements(sql_query)
276
+ for query in queries:
277
+ cur.execute(query)
278
+ con.commit()
279
+ results = cur.fetchall()
280
+ except Exception as e:
281
+ logger.info(f"SQL query is not compatible with SQLite. Reason: {e}")
282
+ return SqlValidationResult(success=False, query_error=str(e))
283
+
284
+ con.close()
285
+ return SqlValidationResult(success=True, results=results)
286
+
287
+ def connect_to_postgres(self) -> psycopg2.extensions.connection:
288
+ conn_params = {
289
+ "dbname": "postgres",
290
+ "user": self.postgres_user,
291
+ "password": self.postgres_password,
292
+ "host": get_docker_address(),
293
+ "port": _POSTGRES_PORT,
294
+ }
295
+ return psycopg2.connect(**conn_params)
296
+
297
+ def validate_query_postgres(self, create_db_statements: str, sql_query: str, db_schema: str) -> SqlValidationResult:
298
+ con = self.connect_to_postgres()
299
+ cur = con.cursor()
300
+ cur.execute(f"CREATE SCHEMA {db_schema};")
301
+ con.commit()
302
+ cur.execute(f"ALTER USER {self.postgres_user} set SEARCH_PATH = {db_schema};")
303
+ con.commit()
304
+ try:
305
+ statements = separate_statements(create_db_statements)
306
+ for statement in statements:
307
+ cur.execute(statement)
308
+ con.commit()
309
+ except Exception as e:
310
+ logger.info(f"Create statements are not compatible with PostgreSQL. Reason: {e}")
311
+ return SqlValidationResult(success=False, schema_error=str(e))
312
+ try:
313
+ queries = separate_statements(sql_query)
314
+ for query in queries:
315
+ cur.execute(query)
316
+ con.commit()
317
+ results = cur.fetchall()
318
+ except Exception as e:
319
+ logger.info(f"SQL query is not compatible with PostgreSQL. Reason: {e}")
320
+ return SqlValidationResult(success=False, query_error=str(e))
321
+
322
+ con.commit()
323
+
324
+ con.close()
325
+ return SqlValidationResult(success=True, results=results)
326
+
327
+ def connect_to_mysql(
328
+ self,
329
+ ) -> mysql.connector.pooling.PooledMySQLConnection | mysql.connector.abstracts.MySQLConnectionAbstract:
330
+ conn_params = {
331
+ "database": self.mysql_db_name,
332
+ "user": self.mysql_user,
333
+ "password": self.mysql_password,
334
+ "host": get_docker_address(),
335
+ "port": _MYSQL_PORT,
336
+ }
337
+ return mysql.connector.connect(**conn_params)
338
+
339
+ def validate_query_mysql(self, create_db_statements: str, sql_query: str, db_schema: str) -> SqlValidationResult:
340
+ con = self.connect_to_mysql()
341
+ cur = con.cursor(buffered=True)
342
+ cur.execute(f"CREATE SCHEMA {db_schema};")
343
+ con.commit()
344
+ cur.execute(f"USE {db_schema};")
345
+ try:
346
+ statements = separate_statements(create_db_statements)
347
+ for statement in statements:
348
+ cur.execute(statement)
349
+ con.commit()
350
+ except Exception as e:
351
+ logger.info(f"Create statements are not compatible with MySQL. Reason: {e}")
352
+ con.close()
353
+ return SqlValidationResult(success=False, schema_error=str(e))
354
+ try:
355
+ queries = separate_statements(sql_query)
356
+ for query in queries:
357
+ cur.execute(query)
358
+ con.commit()
359
+ results = cur.fetchall()
360
+ except Exception as e:
361
+ logger.info(f"SQL query is not compatible with MySQL. Reason: {e}")
362
+ con.close()
363
+ return SqlValidationResult(success=False, query_error=str(e))
364
+
365
+ cur.close()
366
+ con.close()
367
+ return SqlValidationResult(success=True, results=results)
368
+
369
+
370
def separate_statements(statements: str) -> list[str]:
    """Split a semicolon-separated SQL script into individual statements.

    Whitespace-only segments (such as the text after the final ``;`` or between ``;;``) are
    dropped, since some drivers reject empty queries. A trailing statement that lacks a
    terminating ``;`` is kept; the previous ``[:-1]`` slice silently discarded it. The split is
    purely textual, so a ``;`` inside a string literal or comment also splits the script.

    Args:
        statements: SQL text containing zero or more statements.

    Returns:
        The statements in their original order, without the ``;`` terminators.
    """
    return [statement for statement in statements.split(";") if statement.strip()]
372
+
373
+
374
def is_create_table_statement(statement: str) -> bool:
    """Return True if ``statement`` contains the upper-case phrase ``CREATE TABLE``.

    This is a case-sensitive substring test, so ``create table`` is not detected.
    """
    return statement.find("CREATE TABLE") != -1
376
+
377
+
378
def count_result_columns(result: list[Any]) -> int:
    """Return the number of columns in a fetched result set.

    The width is read from the first row, so an empty result set counts as zero columns.
    """
    if not result:
        return 0
    first_row = result[0]
    return len(first_row)
382
+
383
+
384
+ def extract_query_from_completions(completion: str) -> str | None:
385
+ # Match SQL blocks starting with SELECT or WITH at line start
386
+ # (allowing punctuation/whitespace), ending at first semicolon
387
+ pattern = re.compile(r"(?:^|\n)[^a-zA-Z0-9_]*((?:select|with)\b.*?;)", re.IGNORECASE | re.DOTALL)
388
+
389
+ matches = pattern.findall(completion)
390
+
391
+ # Return the query only if exactly one match is found
392
+ if len(matches) == 1:
393
+ return matches[0].strip()
394
+ return None
@@ -0,0 +1,37 @@
1
+ from eval_framework.llm.base import BaseLLM
2
+ from eval_framework.metrics.base import MetricResult
3
+ from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
4
+ from eval_framework.metrics.llm.graders.language import Language
5
+ from eval_framework.metrics.llm.graders.summary_world_knowledge_grader import SummarizationWorldKnowledgeGrader
6
+ from eval_framework.shared.types import Completion
7
+
8
+
9
class LLMJudgeWorldKnowledge(BaseLLMJudgeMetric):
    """LLM-judge metric asking whether a summary contains world knowledge.

    The value is the judge's ``contains_world_knowledge`` verdict converted to a float (``None``
    when the judge gave no verdict). Lower is better.
    """

    NAME = "World Knowledge"

    def __init__(self, llm_judge: BaseLLM):
        super().__init__(llm_judge)
        self._grader = SummarizationWorldKnowledgeGrader(llm_judge)

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Grade ``response.sanitized_completion`` against the user instruction it summarises."""
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=False, error=response.error)]

        instruction_language = Language(response.get_instruction_language())
        grading = self._grader.grade(
            reference_input=response.user_instruction,
            completion=response.sanitized_completion,
            language=instruction_language,
        )

        verdict = grading.contains_world_knowledge
        score = None if verdict is None else float(verdict)
        metric = MetricResult(
            metric_name=self.NAME,
            value=score,
            higher_is_better=False,
            llm_judge_prompt=grading.judge_prompt,
            llm_judge_response=grading.judge_response,
            error=response.error,
        )
        return [metric]
@@ -0,0 +1,20 @@
1
+ """Utility functions for LLM-based metrics."""
2
+
3
+
4
def order_answers_for_comparison(candidate: str, reference: str, swap: bool) -> tuple[str, str]:
    """Return two answers in the order they are presented to an A/B judge.

    Optionally swapping the presentation order helps mitigate the position bias of
    LLM-as-judge evaluations.

    Args:
        candidate: The completion under evaluation.
        reference: The reference/baseline completion it is compared against.
        swap: When truthy, the reference is shown first (answer A) and the candidate second (B).

    Returns:
        ``(answer_a, answer_b)``.
    """
    answer_a, answer_b = (reference, candidate) if swap else (candidate, reference)
    return answer_a, answer_b
File without changes
@@ -0,0 +1,51 @@
1
+ from eval_framework.metrics.base import BaseMetric, MetricResult
2
+ from eval_framework.shared.types import Loglikelihood
3
+
4
+
5
class AccuracyLoglikelihood(BaseMetric[Loglikelihood]):
    """Accuracy of the answer choice with the highest raw loglikelihood.

    Scores 1.0 when that choice is in ``response.ground_truth_list`` and 0.0 otherwise. Ties are
    resolved in favour of the first choice.
    """

    NAME = "Accuracy Loglikelihood"

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        gold_choices = response.ground_truth_list
        scores = response.loglikelihoods
        predicted_choice = max(scores, key=scores.get)  # type: ignore[arg-type]
        is_correct = predicted_choice in gold_choices

        return [
            MetricResult(
                metric_name=self.NAME,
                value=float(is_correct),
                higher_is_better=True,
                error=response.error,
            )
        ]
23
+
24
+
25
class AccuracyNormLoglikelihood(BaseMetric[Loglikelihood]):
    """Accuracy of the answer choice with the highest character-length-normalised loglikelihood.

    Each choice's loglikelihood is divided by the number of characters in the choice text
    (empty choices keep their raw score). The prediction is scored 1.0 when the best choice is
    in ``response.ground_truth_list`` and 0.0 otherwise.
    """

    NAME = "Accuracy Normalized Loglikelihood"

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        gold_choices = response.ground_truth_list

        per_char_scores = {
            choice: (score / len(choice) if len(choice) != 0 else score)
            for choice, score in response.loglikelihoods.items()
        }

        best_choice = max(per_char_scores, key=per_char_scores.get)  # type:ignore
        return [
            MetricResult(
                metric_name=self.NAME,
                value=float(best_choice in gold_choices),
                higher_is_better=True,
                error=response.error,
            )
        ]
@@ -0,0 +1,50 @@
1
+ import math
2
+
3
+ from eval_framework.metrics.base import BaseMetric
4
+ from eval_framework.shared.types import Loglikelihood
5
+
6
+
7
class BaseLoglikelihoodMetric(BaseMetric[Loglikelihood]):
    """Shared helpers for metrics computed from per-choice loglikelihoods."""

    def __init__(
        self,
        *,
        len_normalised: bool = True,
    ) -> None:
        # When True, loglikelihoods are divided by the character length of their choice text
        # before probabilities are derived from them.
        self.len_normalised = len_normalised

    def _normalise_text(self, text: str) -> str:
        """Canonical form used to match choices against ground truths: trimmed and lower-cased."""
        return text.strip().lower()

    def _length_normalise_loglikelihoods(self, loglikelihoods: dict) -> dict:
        """Divide every loglikelihood by its key's length; zero-length keys keep their raw value."""
        return {
            choice: (score / len(choice) if len(choice) > 0 else score)
            for choice, score in loglikelihoods.items()
        }

    def _compute_probabilities(self, loglikelihoods: dict) -> tuple[dict, dict]:
        """Return ``(scores, probabilities)``.

        ``scores`` are the loglikelihoods, length-normalised when ``len_normalised`` is set.
        ``probabilities`` is their softmax. Both dicts keep the original choice keys.
        """
        scores = self._length_normalise_loglikelihoods(loglikelihoods) if self.len_normalised else loglikelihoods
        return scores, self._softmax(scores)

    def _gather_ground_truths(self, response: Loglikelihood) -> set[str]:
        """Return the normalised ground-truth text(s); a single string is treated as one answer."""
        raw = response.ground_truth
        candidates = raw if isinstance(raw, list) else [raw]
        return {self._normalise_text(candidate) for candidate in candidates}

    def _softmax(self, log_probs: dict) -> dict:
        """Numerically stable softmax over the dict's values (shifted by the maximum)."""
        if not log_probs:
            return {}
        peak = max(log_probs.values())
        weights = {key: math.exp(value - peak) for key, value in log_probs.items()}
        total = sum(weights.values())
        return {key: weight / total for key, weight in weights.items()}
@@ -0,0 +1,25 @@
1
+ from eval_framework.metrics.base import MetricResult
2
+ from eval_framework.metrics.loglikelihood.base import BaseLoglikelihoodMetric
3
+ from eval_framework.shared.types import Loglikelihood
4
+
5
+
6
class ConfidenceWeightedAccuracy(BaseLoglikelihoodMetric):
    """Accuracy weighted by the model's confidence in its top choice.

    When the highest-scoring choice is correct (after text normalisation), the score is that
    choice's softmax probability. Otherwise the score is 0.0.
    """

    NAME = "Confidence-weighted Accuracy"

    def __init__(self, *, len_normalised: bool = True) -> None:
        super().__init__(len_normalised=len_normalised)

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        scores, probabilities = self._compute_probabilities(response.loglikelihoods)
        gold = self._gather_ground_truths(response)

        top_choice = max(scores, key=scores.get)  # type: ignore[arg-type]
        if self._normalise_text(top_choice) in gold:
            accuracy = probabilities.get(top_choice, 0.0)
        else:
            accuracy = 0.0

        return [MetricResult(metric_name=self.NAME, value=accuracy, higher_is_better=True, error=response.error)]
@@ -0,0 +1,43 @@
1
+ from eval_framework.metrics.base import MetricResult
2
+ from eval_framework.metrics.loglikelihood.base import BaseLoglikelihoodMetric
3
+ from eval_framework.shared.types import Loglikelihood
4
+
5
+
6
class DistributionalCorrectnessScore(BaseLoglikelihoodMetric):
    """Based on Burns (2025) Measuring Language Model Hallucinations Through Distributional Correctness.

    The score is ``(lc * p_correct - lw * p_wrong) * (1 - p_idk)``, where the ``p_*`` are
    softmax masses:
        * ``p_correct`` is the mass on choices matching a ground truth;
        * ``p_idk`` is the mass on the "I don't know" option (assumed to be the last choice);
        * ``p_wrong`` is the mass on all other choices.
    """

    NAME = "Distributional Correctness Score"

    def __init__(
        self,
        *,
        lc: float = 1.0,  # Default reward weight for correct answers
        lw: float = 1.0,  # Default penalty weight for wrong answers
        len_normalised: bool = True,
    ) -> None:
        super().__init__(len_normalised=len_normalised)
        self._lc = float(lc)
        self._lw = float(lw)
        if not (self._lc >= 0 and self._lw >= 0 and self._lc >= self._lw):
            raise ValueError(f"Invalid DCS loadings: lc={self._lc}, lw={self._lw}. Require lc>=0, lw>=0, and lc>=lw.")

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        _, probs = self._compute_probabilities(response.loglikelihoods)
        ground_truths = self._gather_ground_truths(response)

        idk_key = self._normalise_text(list(response.loglikelihoods.keys())[-1])  # assumes last key is "IDK" option

        # ``probs`` is keyed by the *raw* choice text, so every bucket compares normalised text.
        # Looking ``idk_key`` up directly in ``probs`` lost the IDK mass whenever the raw key had
        # upper-case letters or surrounding whitespace; it was also excluded from p_w.
        normalised_probs = [(self._normalise_text(choice), p) for choice, p in probs.items()]
        p_c = sum(p for key, p in normalised_probs if key in ground_truths)
        p_idk = sum(p for key, p in normalised_probs if key == idk_key)
        p_w = sum(p for key, p in normalised_probs if key not in ground_truths and key != idk_key)

        dcs = (self._lc * p_c - self._lw * p_w) * (1.0 - p_idk)

        return [MetricResult(metric_name=self.NAME, value=float(dcs), higher_is_better=True, error=response.error)]
@@ -0,0 +1,53 @@
1
+ import numpy as np
2
+
3
+ from eval_framework.metrics.base import BaseMetric, MetricResult
4
+ from eval_framework.shared.types import Loglikelihood
5
+
6
+
7
class ProbabilityMass(BaseMetric[Loglikelihood]):
    """Total softmax probability assigned to the ground-truth answer choices.

    The raw loglikelihoods of all choices are turned into a distribution with a softmax. The
    probabilities of the choices listed in ``response.ground_truth`` are then summed. An empty
    set of choices yields 0.0.
    """

    NAME = "Probability Mass"

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        assert isinstance(response.ground_truth, list)
        # https://docs.python.org/3.10/library/stdtypes.html?highlight=dictview#dictionary-view-objects
        # Iterating a dict and its values() visits entries in the same order, so the mask lines up.
        in_ground_truths = np.array(
            [completion in response.ground_truth for completion in response.loglikelihoods], dtype=bool
        )
        log_probs = np.array(list(response.loglikelihoods.values()), dtype=float)

        prob_mass = 0.0
        if log_probs.size > 0:
            # Shifting by the max leaves the softmax unchanged but keeps exp() from underflowing to 0
            # for very negative loglikelihoods, which previously produced 0/0 = NaN.
            weights = np.exp(log_probs - log_probs.max())
            probs = weights / weights.sum()
            prob_mass = float(probs[in_ground_truths].sum())

        return [MetricResult(metric_name=self.NAME, value=prob_mass, higher_is_better=True, error=response.error)]
25
+
26
+
27
class ProbabilityMassNorm(BaseMetric[Loglikelihood]):
    """Total softmax probability on the ground-truth choices, using length-normalised loglikelihoods.

    Each choice's loglikelihood is first divided by the character length of the choice text
    (empty choices keep their raw value). An empty set of choices yields 0.0.
    """

    NAME = "Probability Mass Normalized"

    def calculate(self, response: Loglikelihood) -> list[MetricResult]:
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        assert isinstance(response.ground_truth, list)

        # len normalized
        output_len_normalized = {
            completion: (score / len(completion) if len(completion) != 0 else score)
            for completion, score in response.loglikelihoods.items()
        }

        in_ground_truths = np.array(
            [completion in response.ground_truth for completion in output_len_normalized], dtype=bool
        )
        log_probs = np.array(list(output_len_normalized.values()), dtype=float)

        prob_mass_norm = 0.0
        if log_probs.size > 0:
            # Max-shifted (numerically stable) softmax: same result, but no exp() underflow to 0/0.
            weights = np.exp(log_probs - log_probs.max())
            probs = weights / weights.sum()
            # Plain float, matching ProbabilityMass, rather than a numpy scalar.
            prob_mass_norm = float(probs[in_ground_truths].sum())

        return [MetricResult(metric_name=self.NAME, value=prob_mass_norm, higher_is_better=True, error=response.error)]
+ return [MetricResult(metric_name=self.NAME, value=prob_mass_norm, higher_is_better=True, error=response.error)]