eval-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eval_framework/__init__.py +7 -0
- eval_framework/base_config.py +36 -0
- eval_framework/context/__init__.py +0 -0
- eval_framework/context/determined.py +170 -0
- eval_framework/context/eval.py +114 -0
- eval_framework/context/local.py +52 -0
- eval_framework/evaluation_generator.py +231 -0
- eval_framework/exceptions.py +2 -0
- eval_framework/external/ifeval_impl/README.md +5 -0
- eval_framework/external/ifeval_impl/instructions.py +1523 -0
- eval_framework/external/ifeval_impl/instructions_registry.py +161 -0
- eval_framework/external/ifeval_impl/instructions_util.py +1689 -0
- eval_framework/external/ifeval_impl/utils.py +135 -0
- eval_framework/llm/__init__.py +0 -0
- eval_framework/llm/aleph_alpha.py +323 -0
- eval_framework/llm/base.py +58 -0
- eval_framework/llm/huggingface.py +332 -0
- eval_framework/llm/mistral.py +73 -0
- eval_framework/llm/models.py +16 -0
- eval_framework/llm/openai.py +205 -0
- eval_framework/llm/vllm.py +438 -0
- eval_framework/logger.py +3 -0
- eval_framework/main.py +187 -0
- eval_framework/metrics/__init__.py +0 -0
- eval_framework/metrics/base.py +40 -0
- eval_framework/metrics/completion/__init__.py +1 -0
- eval_framework/metrics/completion/accuracy_completion.py +16 -0
- eval_framework/metrics/completion/bleu.py +76 -0
- eval_framework/metrics/completion/chrf.py +62 -0
- eval_framework/metrics/completion/code_assertion.py +44 -0
- eval_framework/metrics/completion/code_execution_pass_at_one.py +126 -0
- eval_framework/metrics/completion/comet.py +56 -0
- eval_framework/metrics/completion/concordance_index.py +38 -0
- eval_framework/metrics/completion/csv_format.py +102 -0
- eval_framework/metrics/completion/cwe_accuracy.py +49 -0
- eval_framework/metrics/completion/exponential_similarity.py +65 -0
- eval_framework/metrics/completion/f1.py +42 -0
- eval_framework/metrics/completion/format_checker.py +56 -0
- eval_framework/metrics/completion/grid_difference.py +77 -0
- eval_framework/metrics/completion/ifeval.py +73 -0
- eval_framework/metrics/completion/json_format.py +171 -0
- eval_framework/metrics/completion/language_checker.py +74 -0
- eval_framework/metrics/completion/length_control.py +83 -0
- eval_framework/metrics/completion/math_reasoning_completion.py +303 -0
- eval_framework/metrics/completion/niah_accuracy.py +163 -0
- eval_framework/metrics/completion/placeholder_checker.py +27 -0
- eval_framework/metrics/completion/repetition.py +88 -0
- eval_framework/metrics/completion/rouge_1.py +35 -0
- eval_framework/metrics/completion/rouge_2.py +45 -0
- eval_framework/metrics/completion/rouge_geometric_mean.py +36 -0
- eval_framework/metrics/completion/rouge_l.py +52 -0
- eval_framework/metrics/completion/struct_eval_metrics.py +248 -0
- eval_framework/metrics/completion/ter.py +67 -0
- eval_framework/metrics/completion/text_counter.py +182 -0
- eval_framework/metrics/efficiency/__init__.py +0 -0
- eval_framework/metrics/efficiency/bytes_per_sequence_position.py +48 -0
- eval_framework/metrics/llm/__init__.py +0 -0
- eval_framework/metrics/llm/base.py +8 -0
- eval_framework/metrics/llm/graders/chatbot_style_grader.py +92 -0
- eval_framework/metrics/llm/graders/comparison_grader.py +146 -0
- eval_framework/metrics/llm/graders/conciseness_grader.py +93 -0
- eval_framework/metrics/llm/graders/contains_names_grader.py +71 -0
- eval_framework/metrics/llm/graders/format_correctness_grader.py +109 -0
- eval_framework/metrics/llm/graders/instruction_grader.py +177 -0
- eval_framework/metrics/llm/graders/language.py +56 -0
- eval_framework/metrics/llm/graders/long_context_grader.py +72 -0
- eval_framework/metrics/llm/graders/models.py +74 -0
- eval_framework/metrics/llm/graders/refusal_grader.py +57 -0
- eval_framework/metrics/llm/graders/sql_quality_grader.py +145 -0
- eval_framework/metrics/llm/graders/summary_world_knowledge_grader.py +103 -0
- eval_framework/metrics/llm/llm_judge_chatbot_style.py +36 -0
- eval_framework/metrics/llm/llm_judge_completion_accuracy.py +39 -0
- eval_framework/metrics/llm/llm_judge_conciseness.py +37 -0
- eval_framework/metrics/llm/llm_judge_contains_names.py +36 -0
- eval_framework/metrics/llm/llm_judge_format_correctness.py +43 -0
- eval_framework/metrics/llm/llm_judge_instruction.py +58 -0
- eval_framework/metrics/llm/llm_judge_mtbench_pair.py +205 -0
- eval_framework/metrics/llm/llm_judge_mtbench_single.py +188 -0
- eval_framework/metrics/llm/llm_judge_refusal.py +35 -0
- eval_framework/metrics/llm/llm_judge_sql.py +394 -0
- eval_framework/metrics/llm/llm_judge_world_knowledge.py +37 -0
- eval_framework/metrics/loglikelihood/__init__.py +0 -0
- eval_framework/metrics/loglikelihood/accuracy_loglikelihood.py +51 -0
- eval_framework/metrics/loglikelihood/probability_mass.py +56 -0
- eval_framework/py.typed +0 -0
- eval_framework/response_generator.py +416 -0
- eval_framework/result_processors/__init__.py +0 -0
- eval_framework/result_processors/base.py +74 -0
- eval_framework/result_processors/hf_processor.py +87 -0
- eval_framework/result_processors/result_processor.py +129 -0
- eval_framework/run.py +314 -0
- eval_framework/run_direct.py +42 -0
- eval_framework/shared/types.py +227 -0
- eval_framework/tasks/__init__.py +6 -0
- eval_framework/tasks/base.py +314 -0
- eval_framework/tasks/benchmarks/__init__.py +0 -0
- eval_framework/tasks/benchmarks/arc.py +46 -0
- eval_framework/tasks/benchmarks/arc_de.py +46 -0
- eval_framework/tasks/benchmarks/arc_fi.py +46 -0
- eval_framework/tasks/benchmarks/belebele.py +60 -0
- eval_framework/tasks/benchmarks/bigcodebench.py +155 -0
- eval_framework/tasks/benchmarks/casehold.py +47 -0
- eval_framework/tasks/benchmarks/chembench.py +85 -0
- eval_framework/tasks/benchmarks/copa.py +39 -0
- eval_framework/tasks/benchmarks/duc.py +91 -0
- eval_framework/tasks/benchmarks/flores200.py +62 -0
- eval_framework/tasks/benchmarks/flores_plus.py +84 -0
- eval_framework/tasks/benchmarks/gpqa.py +177 -0
- eval_framework/tasks/benchmarks/gsm8k.py +148 -0
- eval_framework/tasks/benchmarks/hellaswag.py +44 -0
- eval_framework/tasks/benchmarks/hellaswag_de.py +52 -0
- eval_framework/tasks/benchmarks/humaneval.py +97 -0
- eval_framework/tasks/benchmarks/ifeval.py +78 -0
- eval_framework/tasks/benchmarks/include.py +119 -0
- eval_framework/tasks/benchmarks/infinitebench.py +302 -0
- eval_framework/tasks/benchmarks/math_reasoning.py +569 -0
- eval_framework/tasks/benchmarks/mbpp.py +192 -0
- eval_framework/tasks/benchmarks/mmlu.py +190 -0
- eval_framework/tasks/benchmarks/mmlu_de.py +109 -0
- eval_framework/tasks/benchmarks/mmlu_pro.py +139 -0
- eval_framework/tasks/benchmarks/mmmlu.py +529 -0
- eval_framework/tasks/benchmarks/openbookqa.py +37 -0
- eval_framework/tasks/benchmarks/opengptx_eu20.py +363 -0
- eval_framework/tasks/benchmarks/pawsx.py +65 -0
- eval_framework/tasks/benchmarks/piqa.py +39 -0
- eval_framework/tasks/benchmarks/quality.py +56 -0
- eval_framework/tasks/benchmarks/sciq.py +44 -0
- eval_framework/tasks/benchmarks/sphyr.py +75 -0
- eval_framework/tasks/benchmarks/squad.py +89 -0
- eval_framework/tasks/benchmarks/struct_eval.py +110 -0
- eval_framework/tasks/benchmarks/tablebench.py +117 -0
- eval_framework/tasks/benchmarks/triviaqa.py +42 -0
- eval_framework/tasks/benchmarks/truthfulqa.py +95 -0
- eval_framework/tasks/benchmarks/winogender.py +39 -0
- eval_framework/tasks/benchmarks/winogrande.py +44 -0
- eval_framework/tasks/benchmarks/winox.py +57 -0
- eval_framework/tasks/benchmarks/wmt.py +160 -0
- eval_framework/tasks/benchmarks/zero_scrolls.py +197 -0
- eval_framework/tasks/eval_config.py +112 -0
- eval_framework/tasks/perturbation.py +83 -0
- eval_framework/tasks/registry.py +186 -0
- eval_framework/tasks/task_loader.py +80 -0
- eval_framework/tasks/task_names.py +138 -0
- eval_framework/tasks/utils.py +578 -0
- eval_framework/utils/constants.py +9 -0
- eval_framework/utils/generate_task_docs.py +229 -0
- eval_framework/utils/helpers.py +3 -0
- eval_framework/utils/logging.py +50 -0
- eval_framework/utils/packaging.py +52 -0
- eval_framework-0.2.0.dist-info/METADATA +514 -0
- eval_framework-0.2.0.dist-info/RECORD +161 -0
- eval_framework-0.2.0.dist-info/WHEEL +4 -0
- eval_framework-0.2.0.dist-info/entry_points.txt +3 -0
- template_formatting/README.md +83 -0
- template_formatting/__init__.py +0 -0
- template_formatting/formatter.py +536 -0
- template_formatting/mistral_formatter.py +159 -0
- template_formatting/py.typed +0 -0
- template_formatting/tests/test_formatter_eval.py +408 -0
- template_formatting/tests/test_formatter_scaling.py +253 -0
- template_formatting/tests/test_mistral_formatter.py +136 -0
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
from eval_framework.logger import logger
|
|
6
|
+
from eval_framework.metrics.base import MetricResult
|
|
7
|
+
from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
|
|
8
|
+
from eval_framework.shared.types import BaseMetricContext, Completion, extract_context_metric
|
|
9
|
+
from template_formatting.formatter import Message, Role
|
|
10
|
+
|
|
11
|
+
# English MT-Bench single-answer judge prompts, keyed by judging mode.
# Every entry is a dict with one "prompt_template" key holding a str.format
# template:
#   *_single_turn                -> {question}, {answer}
#   *_multi_turn                 -> {question_1}, {answer_1}, {question_2}, {answer_2}
#   *_w_reference variants add   -> {ref_answer_1} (and {ref_answer_2} for multi-turn)
# NOTE(review): wording is kept verbatim (including "You evaluation"); it
# presumably mirrors the upstream MT-Bench judge prompts, so do not edit it
# without re-validating judge scores.
SINGLE_JUDGE_PROMPTS = {
    mode: {"prompt_template": template}
    for mode, template in (
        (
            "single_assistant_single_turn",
            '[System]\nYou are a helpful assistant.\n\n[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: "[[rating]]", for example: "Rating: [[5]]".\n\n[Question]\n{question}\n\n[The Start of Assistant\'s Answer]\n{answer}\n[The End of Assistant\'s Answer]',  # noqa: E501
        ),
        (
            "single_assistant_multi_turn",
            'Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. You evaluation should focus on the assistant\'s answer to the second user question. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: "[[rating]]", for example: "Rating: [[5]]".\n\n<|The Start of Assistant A\'s Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A\'s Conversation with User|>',  # noqa: E501
        ),
        (
            "single_assistant_single_turn_w_reference",
            "[System]\nYou are a helpful assistant.\n\n[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]",  # noqa: E501
        ),
        (
            "single_assistant_multi_turn_w_reference",
            "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. You evaluation should focus on the assistant's answer to the second question. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>",  # noqa: E501
        ),
    )
}
|
|
25
|
+
|
|
26
|
+
# German MT-Bench single-answer judge prompts, keyed by judging mode. The
# keys and placeholders match the English SINGLE_JUDGE_PROMPTS exactly:
#   *_single_turn                -> {question}, {answer}
#   *_multi_turn                 -> {question_1}, {answer_1}, {question_2}, {answer_2}
#   *_w_reference variants add   -> {ref_answer_1} (and {ref_answer_2} for multi-turn)
# Every entry must use the "prompt_template" key, because callers look up
# prompt_templates[mode]["prompt_template"]. A misspelled key raises KeyError,
# and the judge metric swallows it, silently dropping the sample.
SINGLE_JUDGE_PROMPTS_DE = {
    "single_assistant_single_turn": {
        "prompt_template": '[System]\nDu bist ein hilfreicher Assistent.\n\n[Anweisung]\nBitte agieren Sie als unparteiischer Richter und bewerten Sie die Qualität der Antwort, die von einem KI-Assistenten auf die unten angezeigte Nutzerfrage gegeben wurde. Ihre Bewertung sollte Faktoren wie Nützlichkeit, Relevanz, Genauigkeit, Tiefe, Kreativität und Detailliertheit der Antwort berücksichtigen. Beginnen Sie Ihre Bewertung mit einer kurzen Erklärung. Seien Sie so objektiv wie möglich. Nachdem Sie Ihre Erklärung gegeben haben, müssen Sie die Antwort auf einer Skala von 1 bis 10 bewerten und dabei streng dieses Format einhalten: "[[rating]]", zum Beispiel: "Bewertung: [[5]]".\n\n[Frage]\n{question}\n\n[Der Anfang der Assistentenantwort]\n{answer}\n[Das Ende der Assistentenantwort]'  # noqa: E501
    },
    "single_assistant_multi_turn": {
        "prompt_template": 'Bitte agieren Sie als unparteiischer Richter und bewerten Sie die Qualität der Antwort, die von einem KI-Assistenten auf die unten angezeigte Nutzerfrage gegeben wurde. Ihre Bewertung sollte Faktoren wie Nützlichkeit, Relevanz, Genauigkeit, Tiefe, Kreativität und Detailliertheit der Antwort berücksichtigen. Ihre Bewertung sollte sich auf die Antwort des Assistenten auf die zweite Nutzerfrage konzentrieren. Beginnen Sie Ihre Bewertung mit einer kurzen Erklärung. Seien Sie so objektiv wie möglich. Nachdem Sie Ihre Erklärung gegeben haben, müssen Sie die Antwort auf einer Skala von 1 bis 10 bewerten, wobei Sie streng dieses Format einhalten: "[[rating]]", zum Beispiel: "Bewertung: [[5]]".\n\n<|Der Anfang von Assistent A\'s Unterhaltung mit dem Nutzer|>\n\n### Nutzer:\n{question_1}\n\n### Assistent A:\n{answer_1}\n\n### Nutzer:\n{question_2}\n\n### Assistent A:\n{answer_2}\n\n<|Das Ende von Assistent A\'s Unterhaltung mit dem Nutzer|>'  # noqa: E501
    },
    "single_assistant_single_turn_w_reference": {
        "prompt_template": '[System]\nDu bist ein hilfreicher Assistent.\n\n[Anweisung]\nBitte agieren Sie als unparteiischer Richter und bewerten Sie die Qualität der Antwort, die von einem KI-Assistenten auf die unten angezeigte Nutzerfrage gegeben wurde. Ihre Bewertung sollte Korrektheit und Nützlichkeit berücksichtigen. Ihnen wird eine Referenzantwort und die Antwort des Assistenten gegeben. Beginnen Sie Ihre Bewertung, indem Sie die Antwort des Assistenten mit der Referenzantwort vergleichen. Identifizieren Sie und korrigieren Sie etwaige Fehler. Seien Sie so objektiv wie möglich. Nachdem Sie Ihre Erklärung gegeben haben, müssen Sie die Antwort auf einer Skala von 1 bis 10 bewerten und dabei streng dieses Format einhalten: "[[rating]]", zum Beispiel: "Bewertung: [[5]]".\n\n[Frage]\n{question}\n\n[Der Anfang der Referenzantwort]\n{ref_answer_1}\n[Das Ende der Referenzantwort]\n\n[Der Anfang der Assistentenantwort]\n{answer}\n[Das Ende der Assistentenantwort]'  # noqa: E501
    },
    "single_assistant_multi_turn_w_reference": {
        # Key was previously misspelled "propmt_templte", which made this
        # template unreachable.
        "prompt_template": 'Bitte agieren Sie als unparteiischer Richter und bewerten Sie die Qualität der Antwort, die von einem KI-Assistenten auf die Nutzerfrage gegeben wurde. Ihre Bewertung sollte Korrektheit und Nützlichkeit berücksichtigen. Ihnen wird eine Referenzantwort und die Antwort des Assistenten gegeben. Ihre Bewertung sollte sich auf die Antwort des Assistenten auf die zweite Frage konzentrieren. Beginnen Sie Ihre Bewertung, indem Sie die Antwort des Assistenten mit der Referenzantwort vergleichen. Identifizieren und korrigieren Sie etwaige Fehler. Seien Sie so objektiv wie möglich. Nachdem Sie Ihre Erklärung gegeben haben, müssen Sie die Antwort auf einer Skala von 1 bis 10 bewerten, wobei Sie streng dieses Format einhalten: "[[rating]]", zum Beispiel: "Bewertung: [[5]]".\n\n<|Der Anfang der Referenzantwort|>\n\n### Nutzer:\n{question_1}\n\n### Referenzantwort:\n{ref_answer_1}\n\n### Nutzer:\n{question_2}\n\n### Referenzantwort:\n{ref_answer_2}\n\n<|Das Ende der Referenzantwort|>\n\n\n<|Der Anfang von Assistent A\'s Unterhaltung mit dem Nutzer|>\n\n### Nutzer:\n{question_1}\n\n### Assistent A:\n{answer_1}\n\n### Nutzer:\n{question_2}\n\n### Assistent A:\n{answer_2}\n\n<|Das Ende von Assistent A\'s Unterhaltung mit dem Nutzer|>'  # noqa: E501
    },
}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Finnish MT-Bench single-answer judge prompts, keyed by judging mode. The
# keys and str.format placeholders match the English SINGLE_JUDGE_PROMPTS:
#   *_single_turn                -> {question}, {answer}
#   *_multi_turn                 -> {question_1}, {answer_1}, {question_2}, {answer_2}
#   *_w_reference variants add   -> {ref_answer_1} (and {ref_answer_2} for multi-turn)
# The rating marker is still "[[<digits>]]" (the examples use "[[5]]"), so
# the same rating regex applies.
SINGLE_JUDGE_PROMPTS_FI = {
    mode: {"prompt_template": template}
    for mode, template in (
        (
            "single_assistant_single_turn",
            '[Järjestelmä]\nOlet avulias avustaja.\n\n[Ohje]\nToimi puolueettomana tuomarina ja arvioi AI-avustajan antaman vastauksen laatua käyttäjän kysymykseen, joka näkyy alla. Arviosi tulisi ottaa huomioon tekijät kuten hyödyllisyys, asiaankuuluvuus, tarkkuus, syvällisyys, luovuus ja yksityiskohtien taso. Aloita arviointisi antamalla lyhyt selitys. Ole mahdollisimman objektiivinen. Selityksen jälkeen sinun on arvioitava vastaus asteikolla 1–10 noudattamalla tarkasti tätä muotoa: "[[arvosana]]", esimerkiksi: "Arvosana: [[5]]".\n\n[Kysymys]\n{question}\n\n[Avustajan vastauksen alku]\n{answer}\n[Avustajan vastauksen loppu]',  # noqa: E501
        ),
        (
            "single_assistant_multi_turn",
            'Toimi puolueettomana tuomarina ja arvioi AI-avustajan antaman vastauksen laatua käyttäjän kysymykseen, joka näkyy alla. Arviosi tulisi ottaa huomioon tekijät kuten hyödyllisyys, asiaankuuluvuus, tarkkuus, syvällisyys, luovuus ja yksityiskohtien taso. Arviosi tulisi keskittyä avustajan vastaukseen toiseen käyttäjän kysymykseen. Aloita arviointisi antamalla lyhyt selitys. Ole mahdollisimman objektiivinen. Selityksen jälkeen sinun on arvioitava vastaus asteikolla 1–10 noudattamalla tarkasti tätä muotoa: "[[arvosana]]", esimerkiksi: "Arvosana: [[5]]".\n\n<|Avustaja A:n keskustelun alku käyttäjän kanssa|>\n\n### Käyttäjä:\n{question_1}\n\n### Avustaja A:\n{answer_1}\n\n### Käyttäjä:\n{question_2}\n\n### Avustaja A:\n{answer_2}\n\n<|Avustaja A:n keskustelun loppu käyttäjän kanssa|>',  # noqa: E501
        ),
        (
            "single_assistant_single_turn_w_reference",
            '[Järjestelmä]\nOlet avulias avustaja.\n\n[Ohje]\nToimi puolueettomana tuomarina ja arvioi AI-avustajan antaman vastauksen laatua käyttäjän kysymykseen, joka näkyy alla. Arviosi tulisi ottaa huomioon oikeellisuus ja hyödyllisyys. Sinulle annetaan viitevastaus ja avustajan vastaus. Aloita arviointisi vertaamalla avustajan vastausta viitevastaukseen. Tunnista ja korjaa mahdolliset virheet. Ole mahdollisimman objektiivinen. Selityksen jälkeen sinun on arvioitava vastaus asteikolla 1–10 noudattamalla tarkasti tätä muotoa: "[[arvosana]]", esimerkiksi: "Arvosana: [[5]]".\n\n[Kysymys]\n{question}\n\n[Viitevastauksen alku]\n{ref_answer_1}\n[Viitevastauksen loppu]\n\n[Avustajan vastauksen alku]\n{answer}\n[Avustajan vastauksen loppu]',  # noqa: E501
        ),
        (
            "single_assistant_multi_turn_w_reference",
            'Toimi puolueettomana tuomarina ja arvioi AI-avustajan antaman vastauksen laatua käyttäjän kysymykseen. Arviosi tulisi ottaa huomioon oikeellisuus ja hyödyllisyys. Sinulle annetaan viitevastaus ja avustajan vastaus. Arviosi tulisi keskittyä avustajan vastaukseen toiseen kysymykseen. Aloita arviointisi vertaamalla avustajan vastausta viitevastaukseen. Tunnista ja korjaa mahdolliset virheet. Ole mahdollisimman objektiivinen. Selityksen jälkeen sinun on arvioitava vastaus asteikolla 1–10 noudattamalla tarkasti tätä muotoa: "[[arvosana]]", esimerkiksi: "Arvosana: [[5]]".\n\n<|Viitevastauksen alku|>\n\n### Käyttäjä:\n{question_1}\n\n### Viitevastaus:\n{ref_answer_1}\n\n### Käyttäjä:\n{question_2}\n\n### Viitevastaus:\n{ref_answer_2}\n\n<|Viitevastauksen loppu|>\n\n\n<|Avustaja A:n keskustelun alku käyttäjän kanssa|>\n\n### Käyttäjä:\n{question_1}\n\n### Avustaja A:\n{answer_1}\n\n### Käyttäjä:\n{question_2}\n\n### Avustaja A:\n{answer_2}\n\n<|Avustaja A:n keskustelun loppu käyttäjän kanssa|>',  # noqa: E501
        ),
    )
}
|
|
56
|
+
|
|
57
|
+
# MT-Bench question categories that are graded against reference answers
# (the "*_w_reference" judge templates) instead of being judged on their own.
NEED_REF_CATEGORIES = [
    "math",
    "reasoning",
    "coding",
    "arena-hard-200",
]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class PromptToJudge(BaseModel):
    """A fully formatted prompt to send to the LLM judge, plus its result label."""

    # Label reused as the metric name of the score this prompt yields
    # (e.g. "single_judgement").
    comparison_type: str
    # Complete judge prompt with question(s), answer(s) and any reference
    # answers already substituted into the template.
    prompt_text: str
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class MTBenchJudgeSingleMetricContext(BaseMetricContext):
    """Per-sample metric context consumed by the MT-Bench single-answer judge."""

    # MT-Bench question category; categories in NEED_REF_CATEGORIES are graded
    # against a reference answer.
    category: str
    # Reference answer(s). As a list it is indexed by turn (index 0 for the
    # first turn, index 1 for the second); None when no reference exists.
    reference: list[str] | str | None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def generate_single_judge_prompts(response: Completion) -> list[PromptToJudge]:
    """Build the MT-Bench single-answer judge prompt for one completion.

    The prompt language follows the response subject prefix: ``"de"`` selects
    German, ``"fi"`` selects Finnish, and anything else selects English.
    Categories in ``NEED_REF_CATEGORIES`` use the reference-answer templates.
    A conversation with at most two messages is graded as single-turn.
    Otherwise it is graded as two turns, with ``response.messages[1]`` taken
    as the first assistant answer.

    Args:
        response: Completion to grade. Its metric context must be a
            ``MTBenchJudgeSingleMetricContext``.

    Returns:
        A list with at most one ``PromptToJudge``. The list is empty (and a
        warning is logged) when the category needs reference answers but too
        few were provided: one for single-turn, two for multi-turn.
    """
    context = extract_context_metric(response, MTBenchJudgeSingleMetricContext)

    assert response.messages is not None

    if response.subject.startswith("de"):
        prompt_templates = SINGLE_JUDGE_PROMPTS_DE
    elif response.subject.startswith("fi"):
        prompt_templates = SINGLE_JUDGE_PROMPTS_FI
    else:
        prompt_templates = SINGLE_JUDGE_PROMPTS

    assert context.category is not None, "Category must be provided in the context for MTBenchJudgeSingleMetricContext"

    is_single_turn = len(response.messages) <= 2

    # No reference answer needed
    if context.category not in NEED_REF_CATEGORIES:
        if is_single_turn:
            prompt_text = prompt_templates["single_assistant_single_turn"]["prompt_template"].format(
                question=response.last_user_instruction,
                answer=response.completion,
            )
        else:
            prompt_text = prompt_templates["single_assistant_multi_turn"]["prompt_template"].format(
                question_1=response.first_user_instruction,
                answer_1=response.messages[1].content,
                question_2=response.last_user_instruction,
                answer_2=response.completion,
            )
        return [PromptToJudge(comparison_type="single_judgement", prompt_text=prompt_text)]

    # Reference answer needed. A bare string counts as a single reference:
    # indexing it directly would pass only its first character(s) as the
    # reference answer.
    if isinstance(context.reference, str):
        references = [context.reference] if context.reference else []
    else:
        references = list(context.reference or [])

    if is_single_turn and references:
        prompt_text = prompt_templates["single_assistant_single_turn_w_reference"]["prompt_template"].format(
            question=response.last_user_instruction,
            answer=response.completion,
            ref_answer_1=references[0],
        )
    elif not is_single_turn and len(references) >= 2:
        prompt_text = prompt_templates["single_assistant_multi_turn_w_reference"]["prompt_template"].format(
            question_1=response.first_user_instruction,
            answer_1=response.messages[1].content,
            ref_answer_1=references[0],
            question_2=response.last_user_instruction,
            answer_2=response.completion,
            ref_answer_2=references[1],
        )
    else:
        # Previously a multi-turn sample with a single reference was dropped
        # without any log message; report every insufficient-reference case.
        logger.warning(
            f"Not enough reference answers for this sample (category: {context.category}, "
            f"found {len(references)}), even though they are needed."
        )
        return []

    return [PromptToJudge(comparison_type="single_judgement", prompt_text=prompt_text)]
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class MTBenchJudgeSingle(BaseLLMJudgeMetric):
    """MT-Bench single-answer grading: an LLM judge rates a completion from 1 to 10."""

    NAME = "single_judgement"

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Ask the LLM judge to rate ``response``.

        Args:
            response: Completion to grade.

        Returns:
            One ``MetricResult`` per generated judge prompt, named after the
            prompt's ``comparison_type``. If the completion itself errored,
            the result is a single ``None``-valued entry carrying that error,
            consistent with the other LLM-judge metrics. The result is an
            empty list when a ``KeyError`` occurs while building prompts or
            querying the judge; the sample is then ignored in aggregations.
        """
        # Do not spend a judge call grading an errored (typically empty) completion.
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        try:
            prompts_to_judge: list[PromptToJudge] = generate_single_judge_prompts(response)

            all_metrics = []
            for prompt_to_judge in prompts_to_judge:
                messages = [Message(role=Role.USER, content=prompt_to_judge.prompt_text)]
                output = self._llm_judge.generate_from_messages([messages])
                parsed_output = self._output_to_rating(output[0].completion)

                all_metrics.append(
                    MetricResult(
                        metric_name=prompt_to_judge.comparison_type,
                        value=parsed_output,
                        higher_is_better=True,
                    )
                )

            return all_metrics

        except KeyError as e:
            logger.warning(f"LLM judge did not produce an expected output, sample will be ignored in aggregations. {e}")
            return []

    @staticmethod
    def _output_to_rating(output: str) -> float:
        """Extract the rating from the first ``[[<number>]]`` marker in ``output``.

        Integer and decimal ratings are accepted (``[[7]]``, ``[[7.5]]``).
        Returns ``0.0`` when no marker is found.
        """
        match = re.search(r"\[\[(\d+(?:\.\d+)?)\]\]", output)

        if match:
            return float(match.group(1))
        return 0.0
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from eval_framework.llm.base import BaseLLM
|
|
2
|
+
from eval_framework.metrics.base import MetricResult
|
|
3
|
+
from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
|
|
4
|
+
from eval_framework.metrics.llm.graders.language import Language
|
|
5
|
+
from eval_framework.metrics.llm.graders.refusal_grader import RefusalGrader
|
|
6
|
+
from eval_framework.shared.types import Completion
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LLMJudgeRefusal(BaseLLMJudgeMetric):
    """LLM-as-a-judge metric that classifies whether a completion is a refusal.

    The reported value is ``float(grading.is_refusal)``, so ``1.0`` means the
    grader judged the completion a refusal. The value is ``None`` when the
    grader returned no verdict or the completion itself errored.
    """

    NAME = "refusal_classifier"

    def __init__(self, llm_judge: BaseLLM):
        super().__init__(llm_judge)
        self._grader = RefusalGrader(llm_judge)

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Grade ``response`` with the refusal grader and return a single result.

        The grading language is the detected completion language, falling
        back to the instruction language and then to ``"en"``.
        """
        if response.error is not None:
            # higher_is_better must agree with the success path below. Results
            # for the same metric name previously disagreed (False here, True
            # there).
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=True, error=response.error)]

        language = response.get_completion_language() or response.get_instruction_language() or "en"
        grading = self._grader.grade(
            completion=response.sanitized_completion,
            language=Language(language),
        )

        return [
            MetricResult(
                metric_name=self.NAME,
                value=float(grading.is_refusal) if grading.is_refusal is not None else None,
                higher_is_better=True,
                llm_judge_prompt=grading.judge_prompt,
                llm_judge_response=grading.judge_response,
                error=response.error,
            )
        ]
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
import atexit
|
|
2
|
+
import logging
|
|
3
|
+
import random
|
|
4
|
+
import re
|
|
5
|
+
import signal
|
|
6
|
+
import sqlite3
|
|
7
|
+
import threading
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from time import sleep
|
|
10
|
+
from typing import Any
|
|
11
|
+
from uuid import uuid4
|
|
12
|
+
|
|
13
|
+
import docker
|
|
14
|
+
import mysql.connector
|
|
15
|
+
import mysql.connector.abstracts
|
|
16
|
+
import psycopg2 # type: ignore
|
|
17
|
+
from pydantic import BaseModel
|
|
18
|
+
|
|
19
|
+
from eval_framework.llm.base import BaseLLM
|
|
20
|
+
from eval_framework.metrics.base import MetricResult
|
|
21
|
+
from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
|
|
22
|
+
from eval_framework.metrics.llm.graders.language import Language
|
|
23
|
+
from eval_framework.metrics.llm.graders.sql_quality_grader import SqlQualityGrader
|
|
24
|
+
from eval_framework.shared.types import Completion, LanguageMetricContext, extract_context_metric
|
|
25
|
+
from eval_framework.tasks.utils import get_docker_address
|
|
26
|
+
|
|
27
|
+
# Module-level logger named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SqlDialects(Enum):
    """SQL dialects a query can be validated against.

    Member values are the strings stored in ``LLMJudgeSqlMetricContext.dialect``.
    They are converted back into members with ``SqlDialects(context.dialect)``.
    """

    sqlite = "sqlite"
    # The member name ("postgres") differs from its value ("postgresql").
    postgres = "postgresql"
    mysql = "mysql"
    standard_sql = "standard_sql"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class SqlOutputComparison(BaseModel):
    """Comparison of a completion query's output against the golden query's output."""

    # presumably: both queries returned the same number of result rows (TODO confirm with the comparison code)
    matches_results_count: bool
    # presumably: result rows have the same number of columns (TODO confirm)
    matches_column_count: bool
    # presumably: the result sets are equal (TODO confirm ordering/normalisation rules)
    results_equal: bool
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SqlValidationResult(BaseModel):
    """Outcome of validating one SQL query against a database schema.

    Produced by ``LLMJudgeSql.validate_query``; ``success`` gates the
    result-comparison metrics.
    """

    # Whether validation succeeded.
    success: bool
    # presumably the error raised while creating the schema, if any (TODO confirm)
    schema_error: str | None = None
    # presumably the error raised while running the query, if any (TODO confirm)
    query_error: str | None = None
    # Rows returned by the query. pydantic copies mutable defaults per
    # instance, so this [] is not shared between results.
    results: list[Any] = []
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class LLMJudgeSqlMetricContext(LanguageMetricContext):
    """Per-sample metric context for the SQL-quality judge."""

    # Dialect name; must be a valid SqlDialects value (e.g. "postgresql").
    dialect: str
    # Schema that both the golden and the completion query are validated
    # against; presumably DDL such as CREATE TABLE statements (TODO confirm).
    db_schema: str
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Guards the database-container launch in LLMJudgeSql.__init__ so that
# concurrently constructed judges do not race while starting containers.
_DOCKER_LAUNCH_LOCK = threading.Lock()

# Host ports of the MySQL and Postgres containers. 0 means "not started";
# LLMJudgeSql.__init__ skips launching when both are non-zero. Presumably
# assigned by the container start helpers (not visible here) — confirm.
_MYSQL_PORT, _POSTGRES_PORT = 0, 0
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class LLMJudgeSql(BaseLLMJudgeMetric):
|
|
61
|
+
NAME = "SQL Quality"
|
|
62
|
+
|
|
63
|
+
def __init__(self, llm_judge: BaseLLM):
|
|
64
|
+
super().__init__(llm_judge)
|
|
65
|
+
self._grader = SqlQualityGrader(llm_judge)
|
|
66
|
+
|
|
67
|
+
self.postgres_password = "mysecretpassword"
|
|
68
|
+
self.postgres_user = "postgres"
|
|
69
|
+
|
|
70
|
+
self.mysql_password = "mysecretpassword"
|
|
71
|
+
self.mysql_user = "root"
|
|
72
|
+
self.mysql_db_name = "mysql"
|
|
73
|
+
|
|
74
|
+
with _DOCKER_LAUNCH_LOCK:
|
|
75
|
+
if _MYSQL_PORT != 0 and _POSTGRES_PORT != 0:
|
|
76
|
+
return
|
|
77
|
+
self.client = docker.from_env()
|
|
78
|
+
atexit.register(self._shutdown_dbs)
|
|
79
|
+
signal.signal(signal.SIGTERM, lambda *_: self._shutdown_dbs())
|
|
80
|
+
self._start_postgres_db()
|
|
81
|
+
self._start_mysql_db()
|
|
82
|
+
self._wait_for_db_containers()
|
|
83
|
+
|
|
84
|
+
def calculate(self, response: Completion) -> list[MetricResult]:
|
|
85
|
+
if response.error is not None:
|
|
86
|
+
return [
|
|
87
|
+
MetricResult(metric_name=f"{self.NAME}/{k}", value=None, higher_is_better=True, error=response.error)
|
|
88
|
+
for k in [
|
|
89
|
+
"successfully_runs",
|
|
90
|
+
"is_just_sql",
|
|
91
|
+
"matches_results_count",
|
|
92
|
+
"matches_column_count",
|
|
93
|
+
"results_equal",
|
|
94
|
+
"llm_quality_score",
|
|
95
|
+
]
|
|
96
|
+
]
|
|
97
|
+
|
|
98
|
+
context = extract_context_metric(response, LLMJudgeSqlMetricContext)
|
|
99
|
+
|
|
100
|
+
assert isinstance(response.ground_truth, str)
|
|
101
|
+
|
|
102
|
+
schema_id = str(uuid4()).replace("-", "_")
|
|
103
|
+
|
|
104
|
+
expected_result = self.validate_query(
|
|
105
|
+
SqlDialects(context.dialect),
|
|
106
|
+
context.db_schema,
|
|
107
|
+
response.ground_truth,
|
|
108
|
+
f"golden_{schema_id}",
|
|
109
|
+
)
|
|
110
|
+
completion_stripped = response.completion.strip().strip("```sql").strip("```")
|
|
111
|
+
completion_query = extract_query_from_completions(completion_stripped)
|
|
112
|
+
if completion_query:
|
|
113
|
+
result = self.validate_query(
|
|
114
|
+
SqlDialects(context.dialect),
|
|
115
|
+
context.db_schema,
|
|
116
|
+
completion_query,
|
|
117
|
+
f"completion_{schema_id}",
|
|
118
|
+
)
|
|
119
|
+
else:
|
|
120
|
+
result = None
|
|
121
|
+
|
|
122
|
+
results = [
|
|
123
|
+
MetricResult(
|
|
124
|
+
metric_name=f"{self.NAME}/successfully_runs",
|
|
125
|
+
value=float(result is not None and result.success),
|
|
126
|
+
higher_is_better=True,
|
|
127
|
+
error=response.error,
|
|
128
|
+
),
|
|
129
|
+
MetricResult(
|
|
130
|
+
metric_name=f"{self.NAME}/is_just_sql",
|
|
131
|
+
value=float(completion_query == completion_stripped),
|
|
132
|
+
higher_is_better=True,
|
|
133
|
+
error=response.error,
|
|
134
|
+
),
|
|
135
|
+
]
|
|
136
|
+
|
|
137
|
+
if result is not None and result.success:
|
|
138
|
+
output_comparison = SqlOutputComparison(
|
|
139
|
+
matches_results_count=len(expected_result.results) == len(result.results),
|
|
140
|
+
matches_column_count=count_result_columns(expected_result.results)
|
|
141
|
+
== count_result_columns(result.results),
|
|
142
|
+
results_equal=expected_result.results == result.results,
|
|
143
|
+
)
|
|
144
|
+
results.extend(
|
|
145
|
+
[
|
|
146
|
+
MetricResult(
|
|
147
|
+
metric_name=f"{self.NAME}/matches_results_count",
|
|
148
|
+
value=float(output_comparison.matches_results_count),
|
|
149
|
+
higher_is_better=True,
|
|
150
|
+
error=response.error,
|
|
151
|
+
),
|
|
152
|
+
MetricResult(
|
|
153
|
+
metric_name=f"{self.NAME}/matches_column_count",
|
|
154
|
+
value=float(output_comparison.matches_column_count),
|
|
155
|
+
higher_is_better=True,
|
|
156
|
+
error=response.error,
|
|
157
|
+
),
|
|
158
|
+
MetricResult(
|
|
159
|
+
metric_name=f"{self.NAME}/results_equal",
|
|
160
|
+
value=float(output_comparison.results_equal),
|
|
161
|
+
higher_is_better=True,
|
|
162
|
+
error=response.error,
|
|
163
|
+
),
|
|
164
|
+
]
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
grading = self._grader.grade(
|
|
168
|
+
prompt=response.user_instruction,
|
|
169
|
+
completion=completion_stripped,
|
|
170
|
+
result=result.results if result and result.success else None,
|
|
171
|
+
language=Language(response.get_instruction_language()),
|
|
172
|
+
)
|
|
173
|
+
results.append(
|
|
174
|
+
MetricResult(
|
|
175
|
+
metric_name=f"{self.NAME}/llm_quality_score",
|
|
176
|
+
# [0, 1] normalization required for visualizer
|
|
177
|
+
value=(float(grading.query_quality) - 1) / 4 if grading.query_quality is not None else None,
|
|
178
|
+
higher_is_better=True,
|
|
179
|
+
llm_judge_prompt=grading.judge_prompt,
|
|
180
|
+
llm_judge_response=grading.judge_response,
|
|
181
|
+
error=response.error,
|
|
182
|
+
)
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
return results
|
|
186
|
+
|
|
187
|
+
def _start_postgres_db(self) -> None:
|
|
188
|
+
global _POSTGRES_PORT
|
|
189
|
+
for _ in range(10): # find a free port
|
|
190
|
+
try:
|
|
191
|
+
_POSTGRES_PORT = random.randint(1000, 65535)
|
|
192
|
+
self.postgres_docker = self.client.containers.run(
|
|
193
|
+
"docker.io/postgres",
|
|
194
|
+
environment={"POSTGRES_PASSWORD": self.postgres_password},
|
|
195
|
+
ports={5432: _POSTGRES_PORT},
|
|
196
|
+
tty=True,
|
|
197
|
+
auto_remove=True,
|
|
198
|
+
detach=True,
|
|
199
|
+
network_mode="bridge",
|
|
200
|
+
)
|
|
201
|
+
break
|
|
202
|
+
except docker.errors.APIError as e:
|
|
203
|
+
if "port is already allocated" not in str(e):
|
|
204
|
+
raise e
|
|
205
|
+
continue
|
|
206
|
+
|
|
207
|
+
def _start_mysql_db(self) -> None:
|
|
208
|
+
global _MYSQL_PORT
|
|
209
|
+
for _ in range(10): # find a free port
|
|
210
|
+
try:
|
|
211
|
+
_MYSQL_PORT = random.randint(1000, 65535)
|
|
212
|
+
self.mysql_docker = self.client.containers.run(
|
|
213
|
+
"docker.io/mysql:latest",
|
|
214
|
+
environment={"MYSQL_ROOT_PASSWORD": self.mysql_password, "MYSQL_DATABASE": self.mysql_db_name},
|
|
215
|
+
ports={3306: _MYSQL_PORT},
|
|
216
|
+
tty=True,
|
|
217
|
+
auto_remove=True,
|
|
218
|
+
detach=True,
|
|
219
|
+
network_mode="bridge",
|
|
220
|
+
)
|
|
221
|
+
break
|
|
222
|
+
except docker.errors.APIError as e:
|
|
223
|
+
if "port is already allocated" not in str(e):
|
|
224
|
+
raise e
|
|
225
|
+
continue
|
|
226
|
+
|
|
227
|
+
def _wait_for_db_containers(self) -> None:
|
|
228
|
+
for _ in range(600):
|
|
229
|
+
try:
|
|
230
|
+
con = self.connect_to_postgres()
|
|
231
|
+
con.close()
|
|
232
|
+
con = self.connect_to_mysql()
|
|
233
|
+
con.close()
|
|
234
|
+
return
|
|
235
|
+
except Exception:
|
|
236
|
+
logger.info("Could not connect to DBs yet...")
|
|
237
|
+
sleep(1)
|
|
238
|
+
raise Exception("DBs not available.")
|
|
239
|
+
|
|
240
|
+
def _shutdown_dbs(self) -> None:
|
|
241
|
+
if hasattr(self, "postgres_docker"):
|
|
242
|
+
self.postgres_docker.kill()
|
|
243
|
+
if hasattr(self, "mysql_docker"):
|
|
244
|
+
self.mysql_docker.kill()
|
|
245
|
+
|
|
246
|
+
def validate_query(
|
|
247
|
+
self,
|
|
248
|
+
dialect: SqlDialects,
|
|
249
|
+
create_db_statements: str,
|
|
250
|
+
sql_query: str,
|
|
251
|
+
db_schema: str,
|
|
252
|
+
) -> SqlValidationResult:
|
|
253
|
+
match dialect:
|
|
254
|
+
case SqlDialects.sqlite | SqlDialects.standard_sql:
|
|
255
|
+
return self.validate_query_sqlite(create_db_statements, sql_query, f"{dialect.value}_{db_schema}")
|
|
256
|
+
case SqlDialects.postgres:
|
|
257
|
+
return self.validate_query_postgres(create_db_statements, sql_query, f"{dialect.value}_{db_schema}")
|
|
258
|
+
case SqlDialects.mysql:
|
|
259
|
+
return self.validate_query_mysql(create_db_statements, sql_query, f"{dialect.value}_{db_schema}")
|
|
260
|
+
case _:
|
|
261
|
+
raise NotImplementedError(f"Query validation not implemented for {dialect.value}.")
|
|
262
|
+
|
|
263
|
+
def validate_query_sqlite(self, create_db_statements: str, sql_query: str, db_schema: str) -> SqlValidationResult:
|
|
264
|
+
con = sqlite3.connect(":memory:")
|
|
265
|
+
cur = con.cursor()
|
|
266
|
+
try:
|
|
267
|
+
statements = separate_statements(create_db_statements)
|
|
268
|
+
for statement in statements:
|
|
269
|
+
cur.execute(statement)
|
|
270
|
+
con.commit()
|
|
271
|
+
except Exception as e:
|
|
272
|
+
logger.info(f"Create statements are not compatible with SQLite. Reason: {e}")
|
|
273
|
+
return SqlValidationResult(success=False, schema_error=str(e))
|
|
274
|
+
try:
|
|
275
|
+
queries = separate_statements(sql_query)
|
|
276
|
+
for query in queries:
|
|
277
|
+
cur.execute(query)
|
|
278
|
+
con.commit()
|
|
279
|
+
results = cur.fetchall()
|
|
280
|
+
except Exception as e:
|
|
281
|
+
logger.info(f"SQL query is not compatible with SQLite. Reason: {e}")
|
|
282
|
+
return SqlValidationResult(success=False, query_error=str(e))
|
|
283
|
+
|
|
284
|
+
con.close()
|
|
285
|
+
return SqlValidationResult(success=True, results=results)
|
|
286
|
+
|
|
287
|
+
def connect_to_postgres(self) -> psycopg2.extensions.connection:
|
|
288
|
+
conn_params = {
|
|
289
|
+
"dbname": "postgres",
|
|
290
|
+
"user": self.postgres_user,
|
|
291
|
+
"password": self.postgres_password,
|
|
292
|
+
"host": get_docker_address(),
|
|
293
|
+
"port": _POSTGRES_PORT,
|
|
294
|
+
}
|
|
295
|
+
return psycopg2.connect(**conn_params)
|
|
296
|
+
|
|
297
|
+
def validate_query_postgres(self, create_db_statements: str, sql_query: str, db_schema: str) -> SqlValidationResult:
|
|
298
|
+
con = self.connect_to_postgres()
|
|
299
|
+
cur = con.cursor()
|
|
300
|
+
cur.execute(f"CREATE SCHEMA {db_schema};")
|
|
301
|
+
con.commit()
|
|
302
|
+
cur.execute(f"ALTER USER {self.postgres_user} set SEARCH_PATH = {db_schema};")
|
|
303
|
+
con.commit()
|
|
304
|
+
try:
|
|
305
|
+
statements = separate_statements(create_db_statements)
|
|
306
|
+
for statement in statements:
|
|
307
|
+
cur.execute(statement)
|
|
308
|
+
con.commit()
|
|
309
|
+
except Exception as e:
|
|
310
|
+
logger.info(f"Create statements are not compatible with PostgreSQL. Reason: {e}")
|
|
311
|
+
return SqlValidationResult(success=False, schema_error=str(e))
|
|
312
|
+
try:
|
|
313
|
+
queries = separate_statements(sql_query)
|
|
314
|
+
for query in queries:
|
|
315
|
+
cur.execute(query)
|
|
316
|
+
con.commit()
|
|
317
|
+
results = cur.fetchall()
|
|
318
|
+
except Exception as e:
|
|
319
|
+
logger.info(f"SQL query is not compatible with PostgreSQL. Reason: {e}")
|
|
320
|
+
return SqlValidationResult(success=False, query_error=str(e))
|
|
321
|
+
|
|
322
|
+
con.commit()
|
|
323
|
+
|
|
324
|
+
con.close()
|
|
325
|
+
return SqlValidationResult(success=True, results=results)
|
|
326
|
+
|
|
327
|
+
def connect_to_mysql(
|
|
328
|
+
self,
|
|
329
|
+
) -> mysql.connector.pooling.PooledMySQLConnection | mysql.connector.abstracts.MySQLConnectionAbstract:
|
|
330
|
+
conn_params = {
|
|
331
|
+
"database": self.mysql_db_name,
|
|
332
|
+
"user": self.mysql_user,
|
|
333
|
+
"password": self.mysql_password,
|
|
334
|
+
"host": get_docker_address(),
|
|
335
|
+
"port": _MYSQL_PORT,
|
|
336
|
+
}
|
|
337
|
+
return mysql.connector.connect(**conn_params)
|
|
338
|
+
|
|
339
|
+
def validate_query_mysql(self, create_db_statements: str, sql_query: str, db_schema: str) -> SqlValidationResult:
|
|
340
|
+
con = self.connect_to_mysql()
|
|
341
|
+
cur = con.cursor(buffered=True)
|
|
342
|
+
cur.execute(f"CREATE SCHEMA {db_schema};")
|
|
343
|
+
con.commit()
|
|
344
|
+
cur.execute(f"USE {db_schema};")
|
|
345
|
+
try:
|
|
346
|
+
statements = separate_statements(create_db_statements)
|
|
347
|
+
for statement in statements:
|
|
348
|
+
cur.execute(statement)
|
|
349
|
+
con.commit()
|
|
350
|
+
except Exception as e:
|
|
351
|
+
logger.info(f"Create statements are not compatible with MySQL. Reason: {e}")
|
|
352
|
+
con.close()
|
|
353
|
+
return SqlValidationResult(success=False, schema_error=str(e))
|
|
354
|
+
try:
|
|
355
|
+
queries = separate_statements(sql_query)
|
|
356
|
+
for query in queries:
|
|
357
|
+
cur.execute(query)
|
|
358
|
+
con.commit()
|
|
359
|
+
results = cur.fetchall()
|
|
360
|
+
except Exception as e:
|
|
361
|
+
logger.info(f"SQL query is not compatible with MySQL. Reason: {e}")
|
|
362
|
+
con.close()
|
|
363
|
+
return SqlValidationResult(success=False, query_error=str(e))
|
|
364
|
+
|
|
365
|
+
cur.close()
|
|
366
|
+
con.close()
|
|
367
|
+
return SqlValidationResult(success=True, results=results)
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def separate_statements(statements: str) -> list[str]:
    """Split a ';'-separated SQL script into its individual statements.

    Each statement is trimmed of surrounding whitespace, and empty fragments (the text after a
    final ';', or ';;') are dropped. A last statement without a terminating ';' is therefore
    kept rather than silently discarded.

    Note: the split is purely textual; a ';' inside a string literal or comment splits a statement.
    """
    return [stripped for part in statements.split(";") if (stripped := part.strip())]
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
# SQL keywords are case-insensitive and may be separated by any whitespace. The trailing \b
# keeps "CREATE TABLESPACE" from matching.
_CREATE_TABLE_PATTERN = re.compile(r"\bCREATE\s+(?:TEMP(?:ORARY)?\s+)?TABLE\b", re.IGNORECASE)


def is_create_table_statement(statement: str) -> bool:
    """Return True if ``statement`` contains a ``CREATE [TEMP|TEMPORARY] TABLE`` clause (any case)."""
    return _CREATE_TABLE_PATTERN.search(statement) is not None
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def count_result_columns(result: list[Any]) -> int:
    """Return the width of a fetched result set: the length of its first row, or 0 when empty."""
    return len(result[0]) if result else 0
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def extract_query_from_completions(completion: str) -> str | None:
|
|
385
|
+
# Match SQL blocks starting with SELECT or WITH at line start
|
|
386
|
+
# (allowing punctuation/whitespace), ending at first semicolon
|
|
387
|
+
pattern = re.compile(r"(?:^|\n)[^a-zA-Z0-9_]*((?:select|with)\b.*?;)", re.IGNORECASE | re.DOTALL)
|
|
388
|
+
|
|
389
|
+
matches = pattern.findall(completion)
|
|
390
|
+
|
|
391
|
+
# Return the query only if exactly one match is found
|
|
392
|
+
if len(matches) == 1:
|
|
393
|
+
return matches[0].strip()
|
|
394
|
+
return None
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from eval_framework.llm.base import BaseLLM
|
|
2
|
+
from eval_framework.metrics.base import MetricResult
|
|
3
|
+
from eval_framework.metrics.llm.base import BaseLLMJudgeMetric
|
|
4
|
+
from eval_framework.metrics.llm.graders.language import Language
|
|
5
|
+
from eval_framework.metrics.llm.graders.summary_world_knowledge_grader import SummarizationWorldKnowledgeGrader
|
|
6
|
+
from eval_framework.shared.types import Completion
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LLMJudgeWorldKnowledge(BaseLLMJudgeMetric):
    """LLM-judge metric reporting the grader's ``contains_world_knowledge`` verdict.

    Grading is delegated to ``SummarizationWorldKnowledgeGrader``. The value is 1.0 when
    the verdict is true, 0.0 when false, and None when the grader gives no verdict. Lower
    is better.
    """

    NAME = "World Knowledge"

    def __init__(self, llm_judge: BaseLLM):
        super().__init__(llm_judge)
        self._grader = SummarizationWorldKnowledgeGrader(llm_judge)

    def calculate(self, response: Completion) -> list[MetricResult]:
        """Grade one completion against its user instruction; errored responses get ``value=None``."""
        if response.error is not None:
            return [MetricResult(metric_name=self.NAME, value=None, higher_is_better=False, error=response.error)]

        instruction_language = Language(response.get_instruction_language())
        grading = self._grader.grade(
            reference_input=response.user_instruction,
            completion=response.sanitized_completion,
            language=instruction_language,
        )

        verdict = grading.contains_world_knowledge
        score = None if verdict is None else float(verdict)
        metric = MetricResult(
            metric_name=self.NAME,
            value=score,
            higher_is_better=False,
            llm_judge_prompt=grading.judge_prompt,
            llm_judge_response=grading.judge_response,
            error=response.error,
        )
        return [metric]
|
|
File without changes
|