graphrag-eval 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphrag_eval/__init__.py +4 -0
- graphrag_eval/aggregation.py +151 -0
- graphrag_eval/answer_correctness.py +162 -0
- graphrag_eval/answer_relevance.py +37 -0
- graphrag_eval/evaluation.py +62 -0
- graphrag_eval/steps/__init__.py +120 -0
- graphrag_eval/steps/retrieval.py +55 -0
- graphrag_eval/steps/sparql.py +139 -0
- graphrag_eval-4.0.0.dist-info/LICENSE +201 -0
- graphrag_eval-4.0.0.dist-info/METADATA +967 -0
- graphrag_eval-4.0.0.dist-info/RECORD +13 -0
- graphrag_eval-4.0.0.dist-info/WHEEL +4 -0
- graphrag_eval-4.0.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from statistics import mean, median
|
|
4
|
+
from typing import Any, Iterable
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Resource-usage metrics: their statistics are always reported (even when no
# sample supplied a value), unlike the quality metrics, which are reported
# only when at least one sample produced them.
PROTECTED_METRICS = [
    "input_tokens",
    "output_tokens",
    "total_tokens",
    "elapsed_sec",
]

# Every per-sample metric that is aggregated per template, micro and macro:
# the answer/step quality metrics followed by the resource-usage metrics.
METRICS = [
    "answer_recall",
    "answer_precision",
    "answer_relevance",
    "answer_relevance_cost",
    "answer_f1",
    "steps_score",
] + PROTECTED_METRICS
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def stats_for_series(values: Iterable[int | float]) -> dict[str, float]:
|
|
29
|
+
return {
|
|
30
|
+
"sum": sum(values),
|
|
31
|
+
"mean": mean(values) if values else 0,
|
|
32
|
+
"median": median(values) if values else 0,
|
|
33
|
+
"min": min(values) if values else 0,
|
|
34
|
+
"max": max(values) if values else 0,
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def update_stats_per_template(
    sample: dict,
    stats_per_template: dict,
    template_id: str
):
    """Append each metric value found in *sample* to its template's series.

    Metrics listed in ``METRICS`` that are absent from the sample, or set to
    ``None``, are skipped. ``stats_per_template[template_id][metric]`` is
    only accessed for metrics that have a value, and must then yield a list
    (e.g. a nested ``defaultdict(list)``).
    """
    present_values = [(name, sample.get(name)) for name in METRICS]
    for name, metric_value in present_values:
        if metric_value is None:
            continue
        stats_per_template[template_id][name].append(metric_value)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _has_empty_bindings(output) -> bool:
    """Return True if *output* is a SPARQL JSON result string with no bindings."""
    try:
        parsed = json.loads(output)
    except (TypeError, ValueError):
        # Not a JSON string at all (plain text, None, a list, ...): it cannot
        # be a SPARQL result document. JSONDecodeError is a ValueError.
        return False
    # JSON scalars such as 42, true or null parse fine but are not documents.
    results = parsed.get("results") if isinstance(parsed, dict) else None
    return (
        isinstance(results, dict)
        and "bindings" in results
        and not results["bindings"]
    )


def update_steps_summary_per_template(
    sample: dict,
    steps_summary_per_template: dict,
    template_id: str
):
    """Accumulate step-level counters for one sample into its template summary.

    For each step in ``sample["actual_steps"]``, these counters (keyed by
    step name) are incremented in ``steps_summary_per_template[template_id]``:

    * ``total``: every occurrence of the step;
    * ``errors``: occurrences whose ``status`` is ``"error"``;
    * ``once_per_sample``: at most once per sample, i.e. the number of
      samples in which the step appears;
    * ``empty_results``: non-error occurrences whose output is a SPARQL JSON
      result document with an empty ``results.bindings`` list.

    Outputs that are not JSON documents are simply not counted as empty,
    rather than raising. ``steps_summary_per_template`` must create missing
    entries on access (a nested ``defaultdict`` ending in ``int``). It is
    left untouched for samples without steps.
    """
    steps = sample.get("actual_steps", [])
    if not steps:
        return
    template_steps_summary = steps_summary_per_template[template_id]
    seen = set()
    for step in steps:
        name = step["name"]
        template_steps_summary["total"][name] += 1
        is_error = step["status"] == "error"
        if is_error:
            template_steps_summary["errors"][name] += 1
        if name not in seen:
            seen.add(name)
            template_steps_summary["once_per_sample"][name] += 1
        if not is_error and _has_empty_bindings(step.get("output")):
            template_steps_summary["empty_results"][name] += 1
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _template_summary(
    status_counts: dict,
    metric_series: dict,
    steps_summary: dict,
) -> dict[str, Any]:
    """Build the summary for one template: counts, steps and metric stats."""
    summary: dict[str, Any] = {
        "number_of_error_samples": status_counts["error"],
        "number_of_success_samples": status_counts["success"],
    }
    # Convert the nested defaultdicts into plain dicts for serialisation.
    steps = {counter: dict(by_name) for counter, by_name in steps_summary.items()}
    if steps:
        summary["steps"] = steps
    for metric in METRICS:
        series = metric_series.get(metric, [])
        if series or metric in PROTECTED_METRICS:
            summary[metric] = stats_for_series(series)
    return summary


def _micro_summary(status_counts_by_template: dict, stats_per_template: dict) -> dict[str, Any]:
    """Pool the metric values of all templates and summarise them together."""
    all_counts = list(status_counts_by_template.values())
    summary: dict[str, Any] = {
        "number_of_error_samples": sum(counts["error"] for counts in all_counts),
        "number_of_success_samples": sum(counts["success"] for counts in all_counts),
    }
    for metric in METRICS:
        series = [
            value
            for metric_series in stats_per_template.values()
            for value in metric_series.get(metric, [])
        ]
        if series or metric in PROTECTED_METRICS:
            summary[metric] = stats_for_series(series)
    return summary


def _macro_summary(per_template: dict) -> dict[str, Any]:
    """Average the per-template means, giving every template equal weight."""
    summary: dict[str, Any] = {}
    for metric in METRICS:
        means = [
            template_summary[metric]["mean"]
            for template_summary in per_template.values()
            if template_summary.get(metric) is not None
        ]
        if means or metric in PROTECTED_METRICS:
            summary[metric] = {"mean": mean(means) if means else 0}
    return summary


def compute_aggregates(samples: list[dict]) -> dict:
    """Aggregate per-sample evaluation results into summary statistics.

    Each sample must have a ``template_id``. A sample containing an
    ``"error"`` key is only counted. Every other sample contributes its
    metric values (see ``METRICS``) and its step counters.

    Returns:
        A dict with three sections:

        * ``per_template``: for each template id, in order of first
          appearance in *samples*: error/success sample counts, an optional
          ``steps`` summary, and ``stats_for_series`` statistics for every
          metric that has values (protected metrics are always present);
        * ``micro``: the same statistics over the values of all templates
          pooled together, plus the total error/success counts;
        * ``macro``: for each metric, the unweighted mean of the
          per-template means.
    """
    status_counts = defaultdict(lambda: defaultdict(int))
    stats_per_template = defaultdict(lambda: defaultdict(list))
    steps_summary_per_template = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

    # A dict used as an insertion-ordered set, so templates are reported
    # deterministically (a plain set's order varies with string hashing).
    template_ids: dict = {}
    for sample in samples:
        template_id = sample["template_id"]
        template_ids.setdefault(template_id, None)
        if "error" in sample:
            status_counts[template_id]["error"] += 1
            continue
        status_counts[template_id]["success"] += 1
        update_stats_per_template(sample, stats_per_template, template_id)
        update_steps_summary_per_template(
            sample,
            steps_summary_per_template,
            template_id
        )

    per_template = {
        template_id: _template_summary(
            status_counts[template_id],
            stats_per_template[template_id],
            steps_summary_per_template[template_id],
        )
        for template_id in template_ids
    }
    return {
        "per_template": per_template,
        "micro": _micro_summary(status_counts, stats_per_template),
        "macro": _macro_summary(per_template),
    }
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from openai import OpenAI
|
|
5
|
+
from tqdm import tqdm
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Default file locations. Relative paths are resolved by open() against the
# current working directory, not against this module.
IN_FILE_PATH: str = "../data/data-1.tsv"
PROMPT_FILE_PATH: str = "prompts/template.md"
OUT_FILE_PATH: str = "results/data-1.tsv"

# Header row of the TSV written by evaluate_and_write: one column per element
# of the tuple returned by extract_response_values.
OUT_FIELDS: list[str] = ["#Reference", "#PTarget", "#Matching", "Reasoning", "Error"]

# OpenAI chat model and sampling temperature used by the LLM judge.
LLM_MODEL: str = "gpt-4o-mini"
TEMPERATURE: float = 0.0
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def compute_recall_precision_f1(
|
|
18
|
+
n_pos: int | None,
|
|
19
|
+
n_pred_pos: int | None,
|
|
20
|
+
n_true_pos: int | None,
|
|
21
|
+
) -> tuple[float | None, float | None, float | None]:
|
|
22
|
+
recall = None
|
|
23
|
+
precision = None
|
|
24
|
+
f1 = None
|
|
25
|
+
if n_true_pos is not None and n_pos:
|
|
26
|
+
recall = n_true_pos / n_pos
|
|
27
|
+
if n_true_pos is not None and n_pred_pos:
|
|
28
|
+
precision = n_true_pos / n_pred_pos
|
|
29
|
+
if precision is not None and recall is not None and precision + recall > 0:
|
|
30
|
+
f1 = 2 * (precision * recall) / (precision + recall)
|
|
31
|
+
return recall, precision, f1
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def extract_response_values(
|
|
35
|
+
response: str
|
|
36
|
+
) -> tuple[int | None, int | None, int | None, str, str]:
|
|
37
|
+
vals = response.split("\t")
|
|
38
|
+
n = len(vals)
|
|
39
|
+
if n < 4:
|
|
40
|
+
msg = f"Expected 4 tab-separated values: {response}"
|
|
41
|
+
return None, None, None, "", msg
|
|
42
|
+
vals = vals[:4]
|
|
43
|
+
try:
|
|
44
|
+
n_ref, n_target, n_matching = map(int, vals[:3])
|
|
45
|
+
except ValueError:
|
|
46
|
+
msg = f"Non-int value: {response}"
|
|
47
|
+
return None, None, None, vals[3], msg
|
|
48
|
+
if any([
|
|
49
|
+
n_ref < 1,
|
|
50
|
+
n_target < 1,
|
|
51
|
+
n_matching < 0,
|
|
52
|
+
n_matching > n_ref,
|
|
53
|
+
n_matching > n_target
|
|
54
|
+
]):
|
|
55
|
+
msg = f"Invalid int values: {n_ref}\t{n_target}\t{n_matching}"
|
|
56
|
+
return None, None, None, vals[3], msg
|
|
57
|
+
return n_ref, n_target, n_matching, vals[3], ""
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class AnswerCorrectnessEvaluator:
    """LLM-as-judge scorer comparing a candidate answer with a reference answer.

    The prompt template is read once at construction. It is filled with
    ``str.format`` using the ``question``, ``reference_answer`` and
    ``candidate_answer`` fields. The LLM's reply is parsed by
    ``extract_response_values`` into claim counts, from which recall,
    precision and F1 are derived.
    """

    def __init__(
        self,
        prompt_file_path: str | Path = PROMPT_FILE_PATH,
        temperature: float = TEMPERATURE,
        model: str = LLM_MODEL,
    ):
        """Load the prompt template and create the OpenAI client.

        Args:
            prompt_file_path: Path to the prompt template. A relative path is
                resolved against the current working directory.
            temperature: Sampling temperature for the chat completion.
            model: Chat model name (defaults to ``LLM_MODEL``).

        Raises:
            OSError: If the prompt file cannot be read.
        """
        with open(prompt_file_path, encoding="utf-8") as f:
            self.prompt_template = f.read()
        self.openai_client = OpenAI()
        self.temperature = temperature
        self.model = model

    def call_llm(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply text.

        This is deliberately best-effort. Any exception raised while calling
        the API is converted to its (single-line) message and returned in
        place of the reply. Parsing then reports it as an evaluation error
        instead of aborting a whole run.
        """
        try:
            response = self.openai_client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                temperature=self.temperature
            )
            content = response.choices[0].message.content
            if content is None:
                # The message content is optional in the API response; report
                # it clearly instead of failing on None.strip().
                return "LLM returned no content"
            return content.strip("\n")
        except Exception as e:
            return str(e).replace("\n", " ")

    def evaluate_answer(
        self,
        question: str,
        reference_answer: str,
        actual_answer: str
    ) -> tuple[int | None, int | None, int | None, str, str]:
        """Ask the LLM to compare *actual_answer* with *reference_answer*.

        Returns:
            ``(n_reference_claims, n_candidate_claims, n_matching_claims,
            reasoning, error)`` as produced by ``extract_response_values``.
        """
        prompt = self.prompt_template.format(
            question=question,
            reference_answer=reference_answer,
            candidate_answer=actual_answer,
        )
        return extract_response_values(self.call_llm(prompt))

    def get_correctness_dict(
        self,
        reference: dict,
        target: dict,
    ) -> dict:
        """Evaluate one question's answer and return flat result fields.

        Args:
            reference: Question record with ``question_text`` and
                ``reference_answer``.
            target: System response with ``actual_answer``.

        Returns:
            Always ``reference_answer``. On an evaluation error only
            ``answer_eval_error`` is added. Otherwise the dict also holds the
            three claim counts, ``answer_correctness_reason``, and whichever
            of ``answer_recall``, ``answer_precision`` and ``answer_f1``
            could be computed.
        """
        result = {"reference_answer": reference["reference_answer"]}
        n_ref, n_actual, n_matching, reason, error = self.evaluate_answer(
            reference["question_text"],
            reference["reference_answer"],
            target["actual_answer"],
        )
        if error:
            # extract_response_values returns no counts alongside an error,
            # so no metric can be computed.
            result["answer_eval_error"] = error
            return result
        result.update({
            "answer_reference_claims_count": n_ref,
            "answer_actual_claims_count": n_actual,
            "answer_matching_claims_count": n_matching,
            "answer_correctness_reason": reason,
        })
        recall, precision, f1 = compute_recall_precision_f1(
            n_ref, n_actual, n_matching
        )
        metrics = {
            "answer_recall": recall,
            "answer_precision": precision,
            "answer_f1": f1,
        }
        result.update({key: value for key, value in metrics.items() if value is not None})
        return result
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def evaluate_and_write(
    in_file_path: str | Path,
    out_file_path: str | Path,
    prompt_file_path: str | Path = PROMPT_FILE_PATH,
) -> None:
    """Evaluate every row of a TSV file and write the LLM verdicts to a TSV.

    The input must have ``Question``, ``Reference answer`` and
    ``Actual answer`` columns. The output starts with the ``OUT_FIELDS``
    header, followed by one row per input row holding the tuple returned by
    ``AnswerCorrectnessEvaluator.evaluate_answer``. The file is flushed after
    every row, so partial results reach the file as the run progresses.

    Args:
        in_file_path: Tab-separated input file.
        out_file_path: Output file; missing parent directories are created.
        prompt_file_path: Prompt template passed to the evaluator.
    """
    evaluator = AnswerCorrectnessEvaluator(prompt_file_path)
    # The csv module requires files opened with newline="": otherwise quoted
    # fields with embedded newlines are mangled on read, and an extra "\r" is
    # written on platforms whose line ending is "\r\n".
    with open(in_file_path, encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f, delimiter="\t"))
    print(f"Writing results to {out_file_path}")
    Path(out_file_path).parent.mkdir(parents=True, exist_ok=True)
    with open(out_file_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(OUT_FIELDS)
        for row in tqdm(rows):
            vals = evaluator.evaluate_answer(
                row["Question"],
                row["Reference answer"],
                row["Actual answer"]
            )
            writer.writerow(vals)
            f.flush()
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def main(argv: list[str] | None = None):
    """Command-line entry point: judge answer correctness for a TSV file.

    Args:
        argv: Arguments to parse. ``None`` (the default, used by the console
            script) makes argparse read ``sys.argv[1:]``.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description="Evaluate answer correctness against reference answers with an LLM judge."
    )
    parser.add_argument(
        "-i", "--in-file", type=str, default=IN_FILE_PATH,
        help="input TSV with 'Question', 'Reference answer' and 'Actual answer' "
             "columns (default: %(default)s)",
    )
    parser.add_argument(
        "-o", "--out-file", type=str, default=OUT_FILE_PATH,
        help="output TSV for the evaluation results (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    evaluate_and_write(
        in_file_path=args.in_file,
        out_file_path=args.out_file,
    )
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from langevals_ragas.response_relevancy import (
|
|
2
|
+
RagasResponseRelevancyEvaluator,
|
|
3
|
+
RagasResponseRelevancyEntry
|
|
4
|
+
)
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_relevance_dict(
    question_text: str,
    actual_answer: str,
    model_name: str = 'openai/gpt-4o-mini',
    max_tokens: int = 65_536
) -> dict:
    """Score how relevant *actual_answer* is to *question_text* with RAGAS.

    Uses langevals' ``RagasResponseRelevancyEvaluator``. This function never
    raises: any failure, including while building the entry or the
    evaluator, is reported under ``answer_relevance_error``.

    Returns:
        For a ``"processed"`` result: ``answer_relevance`` (the score),
        ``answer_relevance_cost`` (only when the result carries a cost) and
        ``answer_relevance_reason`` (the evaluator details). Otherwise a dict
        with only ``answer_relevance_error``.
    """
    settings_dict = {
        'model': model_name,
        'max_tokens': max_tokens
    }
    try:
        entry = RagasResponseRelevancyEntry(
            input=question_text,
            output=actual_answer
        )
        evaluator = RagasResponseRelevancyEvaluator(settings=settings_dict)
        result = evaluator.evaluate(entry)
        if result.status != "processed":
            return {"answer_relevance_error": result.details}
        relevance = {"answer_relevance": result.score}
        # NOTE(review): the result's cost appears to be optional in langevals;
        # keep the score instead of discarding it when no cost is reported.
        cost = getattr(result, "cost", None)
        if cost is not None:
            relevance["answer_relevance_cost"] = cost.amount
        relevance["answer_relevance_reason"] = result.details
        return relevance
    except Exception as e:
        return {
            "answer_relevance_error": str(e),
        }
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from .steps import get_steps_evaluation_result_dict
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def run_evaluation(
    qa_dataset: list[dict],
    responses_dict: dict,
) -> list[dict]:
    """Evaluate a system's responses against a QA dataset.

    Args:
        qa_dataset: Templates, each with a ``template_id`` and a list of
            ``questions``. Each question has an ``id``, a ``question_text`` and
            optionally ``reference_answer`` / ``reference_steps``.
        responses_dict: The system's response for each question, keyed by
            question id.

    Returns:
        One flat dict per question, kept flat for simpler aggregation, with
        ``status`` set to ``"error"`` or ``"success"``. A successful result
        also holds:

        * answer relevance metrics, when the response has an ``actual_answer``;
        * answer correctness metrics, when the question also has a
          ``reference_answer``;
        * step metrics, when the response has ``steps``;
        * whichever usage fields (token counts, ``elapsed_sec``) the response
          reports.

    Raises:
        KeyError: If a question has no entry in ``responses_dict``.
    """
    usage_fields = ("input_tokens", "output_tokens", "total_tokens", "elapsed_sec")
    # Created on first use, so the prompt template is only read (and the LLM
    # client only built) when some question needs a correctness check.
    correctness_evaluator = None
    evaluation_results = []
    for template in qa_dataset:
        template_id = template["template_id"]
        for question in template["questions"]:
            actual_result = responses_dict[question["id"]]
            eval_result = {
                "template_id": template_id,
                "question_id": actual_result["question_id"],
                "question_text": question["question_text"]
            }
            if "reference_answer" in question:
                eval_result["reference_answer"] = question["reference_answer"]
            if "reference_steps" in question:
                eval_result["reference_steps"] = question["reference_steps"]
            if "error" in actual_result:
                eval_result.update({
                    "status": "error",
                    "error": actual_result["error"],
                })
                evaluation_results.append(eval_result)
                continue
            eval_result["status"] = "success"
            if "actual_answer" in actual_result:
                # Imported lazily: these modules depend on LLM client
                # libraries, which are only needed when answers are evaluated.
                from . import answer_relevance
                eval_result["actual_answer"] = actual_result["actual_answer"]
                eval_result.update(
                    answer_relevance.get_relevance_dict(
                        question["question_text"],
                        actual_result["actual_answer"],
                    )
                )
                if "reference_answer" in question:
                    from .answer_correctness import AnswerCorrectnessEvaluator
                    if correctness_evaluator is None:
                        correctness_evaluator = AnswerCorrectnessEvaluator()
                    eval_result.update(
                        correctness_evaluator.get_correctness_dict(
                            question,
                            actual_result,
                        )
                    )
            if "steps" in actual_result:
                eval_result.update(
                    get_steps_evaluation_result_dict(question, actual_result)
                )
            # Usage fields are optional in responses. Aggregation skips
            # missing values, so copy only those present instead of raising.
            eval_result.update({
                field: actual_result[field]
                for field in usage_fields
                if field in actual_result
            })
            evaluation_results.append(eval_result)
    return evaluation_results
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
|
|
4
|
+
from .retrieval import recall_at_k
|
|
5
|
+
from .sparql import compare_sparql_results
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def compare_steps_outputs(reference: dict, actual: dict) -> float:
    """Score how well an actual step's output matches a reference step's.

    Dispatch is on the reference step:

    * ``output_media_type == "application/sparql-results+json"``: both
      outputs are parsed as JSON and compared with ``compare_sparql_results``,
      using the reference's ``required_columns`` and optional ``ordered`` flag;
    * ``output_media_type == "application/json"``: 1.0 if the parsed JSON
      values are equal, else 0.0;
    * ``name == "retrieval"``: recall@k, with ``k`` taken from the
      reference's ``args``;
    * otherwise: 1.0 if the raw outputs are equal, else 0.0.

    Where JSON is expected, an actual output that is not valid JSON scores
    0.0 instead of raising, so one malformed step cannot abort an evaluation
    run. A malformed *reference* output still raises, because it is an error
    in the dataset.
    """
    ref_output = reference["output"]
    act_output = actual["output"]
    media_type = reference.get("output_media_type")
    if media_type in ("application/sparql-results+json", "application/json"):
        ref_parsed = json.loads(ref_output)
        try:
            act_parsed = json.loads(act_output)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            return 0.0
        if media_type == "application/json":
            return float(ref_parsed == act_parsed)
        return compare_sparql_results(
            ref_parsed,
            act_parsed,
            reference["required_columns"],
            reference.get("ordered", False),
        )
    if reference["name"] == "retrieval":
        k = reference["args"]["k"]
        return recall_at_k(ref_output, act_output, k)
    return float(ref_output == act_output)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def match_group_by_output(
    reference_steps: list[list[dict]],
    group_idx: int,
    actual_steps: list[dict],
    candidates_by_name: dict[str, list[int]],
) -> list[tuple[int, int, int, float]]:
    """Greedily pair the reference steps of one group with actual steps.

    Reference steps in ``reference_steps[group_idx]`` are processed in order.
    For each one, the candidate actual steps with the same name are tried
    from the last candidate to the first. The first candidate that is not yet
    claimed and whose output scores above 0.0 is taken. Each actual step is
    claimed at most once.

    Returns:
        A ``(group_idx, reference_idx, actual_idx, score)`` tuple for every
        reference step that found a match; unmatched steps are omitted.
    """
    claimed = set()
    matches = []
    for reference_idx, reference_step in enumerate(reference_steps[group_idx]):
        candidate_indices = candidates_by_name.get(reference_step["name"], [])
        for actual_idx in candidate_indices[::-1]:
            if actual_idx in claimed:
                continue
            score = compare_steps_outputs(reference_step, actual_steps[actual_idx])
            if score <= 0.0:
                continue
            claimed.add(actual_idx)
            matches.append((group_idx, reference_idx, actual_idx, score))
            break
    return matches
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def collect_possible_matches_by_name_and_status(
    group: list[dict],
    actual_steps: list[dict],
    search_upto: int,
) -> dict[str, list[int]]:
    """Index the successful actual steps that could match a reference group.

    Only the first ``search_upto`` actual steps are scanned.

    Returns:
        For each step name used in *group*, the ascending indices of the
        scanned actual steps with that name and ``status == "success"``.
        Names without any such step are left out.
    """
    indices_by_name: dict[str, list[int]] = {}
    for idx in range(search_upto):
        step = actual_steps[idx]
        step_name = step["name"]
        if step["status"] != "success":
            continue
        indices_by_name.setdefault(step_name, []).append(idx)

    wanted = {ref_step["name"] for ref_step in group}
    return {
        step_name: indices_by_name[step_name]
        for step_name in wanted
        if step_name in indices_by_name
    }
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def get_steps_matches(
    reference_steps: list[list[dict]],
    actual_steps: list[dict],
) -> list[tuple[int, int, int, float]]:
    """Match reference steps to actual steps by name, status and output.

    For now only the last reference group is matched. Its steps are paired
    with successful, same-named actual steps anywhere in ``actual_steps``
    (see ``match_group_by_output``), so the group index in every returned
    tuple is ``-1``.

    TODO: match the earlier groups too. Walk the reference groups backwards,
    limiting each group's candidates to actual steps before the earliest step
    matched by the following group. Stop (keeping partial matches) at the
    first group that cannot be fully matched.

    Returns:
        ``(group_idx, reference_idx, actual_idx, score)`` tuples. The list is
        empty when there are no reference groups.
    """
    if not reference_steps:
        return []
    last_group = reference_steps[-1]
    candidates = collect_possible_matches_by_name_and_status(
        last_group, actual_steps, len(actual_steps)
    )
    return match_group_by_output(reference_steps, -1, actual_steps, candidates)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def evaluate_steps(
    reference_steps_groups: list[list[dict]],
    actual_steps: list[dict]
) -> float:
    """Score the actual steps against the last group of reference steps.

    Side effect: each matched reference step gets a ``"matches"`` key holding
    the ``id`` of the actual step it was matched to.

    Returns:
        The summed match scores of the last reference group divided by the
        number of steps in that group, i.e. the mean per-step score, with
        unmatched steps counting as 0. Returns 0.0 when there is no reference
        step to score against, instead of failing on an empty group.
    """
    if not reference_steps_groups or not reference_steps_groups[-1]:
        return 0.0
    total_score_by_group = defaultdict(float)
    matches = get_steps_matches(reference_steps_groups, actual_steps)
    for group_idx, reference_idx, actual_idx, score in matches:
        total_score_by_group[group_idx] += score
        reference_steps_groups[group_idx][reference_idx]["matches"] \
            = actual_steps[actual_idx]["id"]
    last_group_idx = -1  # For now, only the last reference group is scored
    return total_score_by_group[last_group_idx] / len(reference_steps_groups[last_group_idx])
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def get_steps_evaluation_result_dict(reference: dict, target: dict) -> dict:
    """Build the step-related fields of one question's evaluation result.

    The result always holds ``actual_steps`` (the response's ``steps`` list
    itself, not a copy). When the question defines ``reference_steps``, it
    also holds ``steps_score`` computed by ``evaluate_steps``.
    """
    result = {"actual_steps": target["steps"]}
    if "reference_steps" not in reference:
        return result
    result["steps_score"] = evaluate_steps(
        reference["reference_steps"], result["actual_steps"]
    )
    return result
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from typing import Iterable
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def recall_at_k(relevant_docs: Iterable, retrieved_docs: list, k: int = 10) -> float:
    """Compute Recall@k: the share of relevant documents found in the top k.

    Args:
        relevant_docs (Iterable): Ground-truth relevant document IDs;
            duplicates are ignored.
        retrieved_docs (list): Retrieved document IDs ordered by rank; only
            the first ``k`` are considered.
        k (int): The rank cutoff.

    Returns:
        float: ``len(relevant & top_k) / len(relevant)``, or 0.0 when there
        are no relevant documents.
    """
    top_k = retrieved_docs[:k]
    relevant = set(relevant_docs)
    top_k_unique = set(top_k)
    hits = sum(1 for doc_id in top_k_unique if doc_id in relevant)
    if not relevant:
        return 0.0
    return hits / len(relevant)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def average_precision(relevant_docs: Iterable, retrieved_docs: Iterable) -> float:
    """Compute Average Precision (AP) for a single query.

    AP is the mean, over all relevant documents, of the precision at the rank
    where each relevant document is first retrieved. Relevant documents that
    are never retrieved contribute 0. A relevant document retrieved more than
    once counts only at its first rank, which keeps the result in [0, 1].

    Args:
        relevant_docs (Iterable): Ground-truth relevant document IDs.
        retrieved_docs (Iterable): Retrieved document IDs, ordered by rank.

    Returns:
        float: The Average Precision, or 0.0 when there are no relevant
        documents.
    """
    relevant_set = set(relevant_docs)
    if not relevant_set:
        return 0.0

    found = set()
    sum_of_precisions = 0.0
    for rank, doc_id in enumerate(retrieved_docs, start=1):
        if doc_id in relevant_set and doc_id not in found:
            found.add(doc_id)
            # Precision at this rank = distinct relevant docs seen so far / rank.
            sum_of_precisions += len(found) / rank

    return sum_of_precisions / len(relevant_set)
|