ragbits-evaluate 0.5.0__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragbits/evaluate/agent_simulation/__init__.py +87 -0
- ragbits/evaluate/agent_simulation/context.py +118 -0
- ragbits/evaluate/agent_simulation/conversation.py +333 -0
- ragbits/evaluate/agent_simulation/deepeval_evaluator.py +92 -0
- ragbits/evaluate/agent_simulation/logger.py +165 -0
- ragbits/evaluate/agent_simulation/metrics/__init__.py +19 -0
- ragbits/evaluate/agent_simulation/metrics/builtin.py +221 -0
- ragbits/evaluate/agent_simulation/metrics/collectors.py +142 -0
- ragbits/evaluate/agent_simulation/models.py +37 -0
- ragbits/evaluate/agent_simulation/results.py +200 -0
- ragbits/evaluate/agent_simulation/scenarios.py +129 -0
- ragbits/evaluate/agent_simulation/simulation.py +243 -0
- ragbits/evaluate/cli.py +150 -0
- ragbits/evaluate/config.py +11 -0
- ragbits/evaluate/dataloaders/__init__.py +3 -0
- ragbits/evaluate/dataloaders/base.py +95 -0
- ragbits/evaluate/dataloaders/document_search.py +61 -0
- ragbits/evaluate/dataloaders/exceptions.py +25 -0
- ragbits/evaluate/dataloaders/gaia.py +78 -0
- ragbits/evaluate/dataloaders/hotpot_qa.py +95 -0
- ragbits/evaluate/dataloaders/human_eval.py +70 -0
- ragbits/evaluate/dataloaders/question_answer.py +56 -0
- ragbits/evaluate/dataset_generator/pipeline.py +4 -4
- ragbits/evaluate/dataset_generator/prompts/qa.py +2 -4
- ragbits/evaluate/dataset_generator/tasks/corpus_generation.py +2 -4
- ragbits/evaluate/dataset_generator/tasks/text_generation/base.py +3 -5
- ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py +3 -3
- ragbits/evaluate/evaluator.py +178 -50
- ragbits/evaluate/factories/__init__.py +42 -0
- ragbits/evaluate/metrics/__init__.py +2 -23
- ragbits/evaluate/metrics/base.py +40 -17
- ragbits/evaluate/metrics/document_search.py +40 -23
- ragbits/evaluate/metrics/gaia.py +84 -0
- ragbits/evaluate/metrics/hotpot_qa.py +51 -0
- ragbits/evaluate/metrics/human_eval.py +105 -0
- ragbits/evaluate/metrics/question_answer.py +222 -0
- ragbits/evaluate/optimizer.py +138 -86
- ragbits/evaluate/pipelines/__init__.py +37 -0
- ragbits/evaluate/pipelines/base.py +34 -10
- ragbits/evaluate/pipelines/document_search.py +72 -67
- ragbits/evaluate/pipelines/gaia.py +249 -0
- ragbits/evaluate/pipelines/hotpot_qa.py +342 -0
- ragbits/evaluate/pipelines/human_eval.py +323 -0
- ragbits/evaluate/pipelines/question_answer.py +96 -0
- ragbits/evaluate/utils.py +86 -59
- {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +33 -9
- ragbits_evaluate-1.4.0.dev202602030301.dist-info/RECORD +59 -0
- {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +1 -1
- ragbits/evaluate/callbacks/base.py +0 -22
- ragbits/evaluate/callbacks/neptune.py +0 -26
- ragbits/evaluate/loaders/__init__.py +0 -21
- ragbits/evaluate/loaders/base.py +0 -24
- ragbits/evaluate/loaders/hf.py +0 -25
- ragbits_evaluate-0.5.0.dist-info/RECORD +0 -33
- /ragbits/evaluate/{callbacks/__init__.py → py.typed} +0 -0
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from statistics import mean
|
|
2
|
+
|
|
3
|
+
from ragbits.evaluate.metrics.base import Metric
|
|
4
|
+
from ragbits.evaluate.pipelines.gaia import GaiaResult
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GaiaOutcome(Metric[GaiaResult]):
    """
    Task success rate over GAIA tasks.

    Reports the share of results whose `task_success` flag is truthy.
    """

    @staticmethod
    async def compute(results: list[GaiaResult]) -> dict:
        """Compute the GAIA task success rate.

        Args:
            results: Per-task GAIA results.

        Returns:
            Dictionary with gaia_task_success_rate: the fraction of successful tasks
            (0.0 when `results` is empty).
        """
        if not results:
            return {"gaia_task_success_rate": 0.0}
        solved = len([r for r in results if r.task_success])
        return {"gaia_task_success_rate": float(solved / len(results))}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class GaiaTooling(Metric[GaiaResult]):
    """
    Tool utilization and performance metrics:
    - gaia_tool_trigger_rate: fraction of tasks where tools were used
    - gaia_avg_num_tool_calls: average number of tool calls per task
    - gaia_avg_tool_error_count: average number of tool errors per task
    - gaia_tool_frequency_usage.<tool_name>: occurrences of each tool name in the
      results' `tool_names`, averaged over all tasks (one entry per tool seen)
    """

    @staticmethod
    async def compute(results: list[GaiaResult]) -> dict:
        """Compute tool utilization and performance metrics.

        Args:
            results: Per-task GAIA results.

        Returns:
            Dictionary with tool trigger rate, average tool calls, average errors,
            and flattened per-tool frequency entries keyed
            `gaia_tool_frequency_usage.<tool_name>`. The three aggregate values are 0.0
            and no per-tool entries are emitted when `results` is empty.
        """
        tool_triggered_count = sum(1 for r in results if r.tool_triggered)
        tool_trigger_rate = (tool_triggered_count / len(results)) if results else 0.0
        avg_tool_calls = float(mean(r.num_tool_calls for r in results)) if results else 0.0
        avg_tool_errors = float(mean(r.tool_error_count for r in results)) if results else 0.0

        # Per-tool frequency: total occurrences of each name divided by the number of tasks.
        # Falling back to 1 only matters for empty input, where no names are counted anyway.
        total_tasks = len(results) if results else 1
        aggregated_counts: dict[str, int] = {}
        for r in results:
            # `tool_names` may be empty or None for tasks that used no tools.
            for name in r.tool_names or ():
                aggregated_counts[name] = aggregated_counts.get(name, 0) + 1
        averaged_freq: dict[str, float] = {
            f"gaia_tool_frequency_usage.{name}": (count / total_tasks) for name, count in aggregated_counts.items()
        }

        return {
            "gaia_tool_trigger_rate": float(tool_trigger_rate),
            "gaia_avg_num_tool_calls": avg_tool_calls,
            "gaia_avg_tool_error_count": avg_tool_errors,
            **averaged_freq,
        }
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class GaiaEfficiency(Metric[GaiaResult]):
    """
    Efficiency metrics for GAIA tasks:
    - gaia_avg_latency_ms: mean of the results' `total_latency_ms` values
    """

    @staticmethod
    async def compute(results: list[GaiaResult]) -> dict:
        """Compute efficiency metrics.

        Args:
            results: Per-task GAIA results.

        Returns:
            Dictionary with gaia_avg_latency_ms (0.0 when `results` is empty).
        """
        latencies = [r.total_latency_ms for r in results]
        avg_latency = float(mean(latencies)) if latencies else 0.0
        return {"gaia_avg_latency_ms": avg_latency}
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from collections.abc import Iterable
|
|
3
|
+
|
|
4
|
+
from ragbits.evaluate.metrics.base import Metric
|
|
5
|
+
from ragbits.evaluate.pipelines.hotpot_qa import HotpotQAResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class HotpotQAExactMatch(Metric[HotpotQAResult]):
    """Computes exact match (EM) over HotpotQA, per question type and overall."""

    @staticmethod
    async def compute(results: list[HotpotQAResult]) -> dict:
        """
        Compute mean EM per question type and overall.

        Args:
            results: HotpotQA results; each provides `em_value` and `qtype`.

        Returns:
            Dictionary with `hotpotqa_<type>_em` for every question type present (results
            without a type are grouped as `unknown`) and `hotpotqa_overall_em` over all
            results. The overall entry is always present and is 0.0 for empty input.
        """
        buckets: dict[str, list[float]] = defaultdict(list)
        # Seed the overall bucket so the documented overall key is reported even for empty input.
        buckets["overall"] = []
        for r in results:
            em = r.em_value
            # NOTE(review): a result whose qtype is literally "overall" is merged into this bucket.
            buckets[r.qtype or "unknown"].append(em)
            buckets["overall"].append(em)

        def avg(vals: Iterable[float]) -> float:
            lst = list(vals)
            return float(sum(lst) / len(lst)) if lst else 0.0

        return {f"hotpotqa_{t}_em": avg(vals) for t, vals in buckets.items()}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class HotpotQAF1(Metric[HotpotQAResult]):
    """Computes F1 over HotpotQA, per question type and overall."""

    @staticmethod
    async def compute(results: list[HotpotQAResult]) -> dict:
        """
        Compute mean F1 per question type and overall.

        Args:
            results: HotpotQA results; each provides a precomputed `f1_value` and `qtype`.

        Returns:
            Dictionary with `hotpotqa_<type>_f1` for every question type present (results
            without a type are grouped as `unknown`) and `hotpotqa_overall_f1` over all
            results. The overall entry is always present and is 0.0 for empty input.
        """
        buckets: dict[str, list[float]] = defaultdict(list)
        # Seed the overall bucket so the documented overall key is reported even for empty input.
        buckets["overall"] = []
        for r in results:
            f1v = r.f1_value
            # NOTE(review): a result whose qtype is literally "overall" is merged into this bucket.
            buckets[r.qtype or "unknown"].append(f1v)
            buckets["overall"].append(f1v)

        def avg(vals: Iterable[float]) -> float:
            lst = list(vals)
            return float(sum(lst) / len(lst)) if lst else 0.0

        return {f"hotpotqa_{t}_f1": avg(vals) for t, vals in buckets.items()}
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from statistics import mean
|
|
3
|
+
|
|
4
|
+
from ragbits.evaluate.metrics.base import Metric
|
|
5
|
+
from ragbits.evaluate.pipelines.human_eval import HumanEvalResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class HumanEvalPassAtK(Metric[HumanEvalResult]):
    """
    Computes pass@k over HumanEval tasks.

    Measures the fraction of tasks with at least one passing sample out of k attempts,
    using the per-task estimator 1 - C(n - c, k) / C(n, k), where n is the number of
    samples and c the number of passing samples (k is capped at n).
    """

    def __init__(self, k: int = 1, weight: float = 1.0) -> None:
        """
        Initialize the pass@k metric.

        Args:
            k: Number of attempts per task considered by pass@k. Must be at least 1.
            weight: Metric value weight in the final score, used during optimization.

        Raises:
            ValueError: If `k` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")
        super().__init__(weight=weight)
        self.k = k

    async def compute(self, results: list[HumanEvalResult]) -> dict:
        """Compute pass@k averaged over tasks.

        Args:
            results: Per-task HumanEval results; `passed_mask` marks passing samples.

        Returns:
            Dictionary with humaneval_pass@k: mean per-task pass@k estimate
            (0.0 when `results` is empty).
        """
        values: list[float] = []
        for r in results:
            n = len(r.passed_mask)
            c = sum(1 for passed in r.passed_mask if passed)
            if c == 0:
                # Covers tasks with no samples as well (c <= n).
                values.append(0.0)
            elif c == n:
                values.append(1.0)
            else:
                # Here n >= 2 and k >= 1, so C(n, k) > 0; math.comb returns 0 when k > n - c.
                k = min(self.k, n)
                values.append(1.0 - math.comb(n - c, k) / math.comb(n, k))
        return {f"humaneval_pass@{self.k}": float(mean(values)) if values else 0.0}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class HumanEvalQualityPerf(Metric[HumanEvalResult]):
    """
    Code quality and execution performance metrics for HumanEval.

    Sample-level rates are divided by the total number of samples (sum of `passed_mask` lengths):
    - humaneval_compile_rate: fraction of samples that compiled
    - humaneval_syntax_error_rate: fraction of samples whose compilation failed
    - humaneval_assert_fail_rate: fraction of compiled samples whose error starts with "AssertionError"
    - humaneval_runtime_error_rate: fraction of compiled samples with any other error
    - humaneval_timeout_rate: fraction of compiled samples whose error starts with "TimeoutError"
    - humaneval_tasks_solved: fraction of tasks with any passing sample
    - humaneval_avg_exec_time_sec: mean execution duration over compiled samples
    """

    @staticmethod
    async def compute(results: list[HumanEvalResult]) -> dict:
        """Compute code quality and execution performance metrics.

        Args:
            results: Per-task HumanEval results with per-sample masks, errors and durations.

        Returns:
            Dictionary with compile/error/timeout rates, the solved-task rate and the
            average execution time; every value is 0.0 when there is nothing to divide by.
        """
        sample_total = sum(len(res.passed_mask) for res in results)
        tally = {"compiled": 0, "syntax": 0, "assert": 0, "runtime": 0, "timeout": 0}
        exec_times: list[float] = []

        for res in results:
            per_sample = zip(res.compile_ok_mask, res.errors, res.exec_durations_sec, strict=False)
            for compiled_ok, error, duration in per_sample:
                if not compiled_ok:
                    # A failed compilation is reported as a syntax error.
                    tally["syntax"] += 1
                    continue
                tally["compiled"] += 1
                exec_times.append(duration)
                if not error:
                    continue
                if error.startswith("AssertionError"):
                    tally["assert"] += 1
                elif error.startswith("TimeoutError"):
                    tally["timeout"] += 1
                else:
                    tally["runtime"] += 1

        def share(count: int) -> float:
            return float(count / sample_total) if sample_total else 0.0

        solved_tasks = sum(1 for res in results if any(res.passed_mask))
        return {
            "humaneval_compile_rate": share(tally["compiled"]),
            "humaneval_syntax_error_rate": share(tally["syntax"]),
            "humaneval_assert_fail_rate": share(tally["assert"]),
            "humaneval_runtime_error_rate": share(tally["runtime"]),
            "humaneval_timeout_rate": share(tally["timeout"]),
            "humaneval_tasks_solved": float(solved_tasks / len(results)) if results else 0.0,
            "humaneval_avg_exec_time_sec": float(mean(exec_times)) if exec_times else 0.0,
        }
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from asyncio import AbstractEventLoop
|
|
4
|
+
from itertools import chain
|
|
5
|
+
from typing import Generic, TypeVar
|
|
6
|
+
|
|
7
|
+
from typing_extensions import Self
|
|
8
|
+
|
|
9
|
+
from ragbits.agents.types import QuestionAnswerPromptOutputT
|
|
10
|
+
from ragbits.core.llms.base import LLM
|
|
11
|
+
from ragbits.core.utils.helpers import batched
|
|
12
|
+
from ragbits.evaluate.metrics.base import Metric
|
|
13
|
+
from ragbits.evaluate.pipelines.question_answer import QuestionAnswerResult
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
from continuous_eval.llm_factory import LLMInterface
|
|
17
|
+
from continuous_eval.metrics.base import LLMBasedMetric
|
|
18
|
+
from continuous_eval.metrics.generation.text import (
|
|
19
|
+
LLMBasedAnswerCorrectness,
|
|
20
|
+
LLMBasedAnswerRelevance,
|
|
21
|
+
LLMBasedFaithfulness,
|
|
22
|
+
LLMBasedStyleConsistency,
|
|
23
|
+
)
|
|
24
|
+
except ModuleNotFoundError:
|
|
25
|
+
from continuous_eval.llms.base import LLMInterface
|
|
26
|
+
from continuous_eval.metrics import Metric as LLMBasedMetric
|
|
27
|
+
from continuous_eval.metrics.generation.text import (
|
|
28
|
+
AnswerCorrectness as LLMBasedAnswerCorrectness,
|
|
29
|
+
)
|
|
30
|
+
from continuous_eval.metrics.generation.text import (
|
|
31
|
+
AnswerRelevance as LLMBasedAnswerRelevance,
|
|
32
|
+
)
|
|
33
|
+
from continuous_eval.metrics.generation.text import (
|
|
34
|
+
Faithfulness as LLMBasedFaithfulness,
|
|
35
|
+
)
|
|
36
|
+
from continuous_eval.metrics.generation.text import (
|
|
37
|
+
StyleConsistency as LLMBasedStyleConsistency,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# Type of the continuous-eval (Relari) LLM-based metric wrapped by a `QuestionAnswerMetric` subclass.
MetricT = TypeVar("MetricT", bound=LLMBasedMetric)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class _MetricLMM(LLMInterface):
    """
    Adapter exposing a ragbits `LLM` through the continuous-eval (Relari) `LLMInterface`,
    so Relari's LLM-based generation metrics can use it as the judge model.
    """

    def __init__(self, llm: LLM, loop: AbstractEventLoop) -> None:
        self._llm = llm
        # Event loop on which the async `LLM.generate` calls are scheduled.
        self._loop = loop

    def run(self, prompt: dict[str, str], temperature: float = 0, max_tokens: int = 1024) -> str:
        """
        Synchronously generate a completion for a Relari prompt.

        The coroutine is submitted to `self._loop` with `run_coroutine_threadsafe` and this
        call blocks until it finishes, so it must be called from a thread other than the one
        running that loop.

        Args:
            prompt: Mapping with "system_prompt" and "user_prompt" entries.
            temperature: Sampling temperature passed to the LLM options.
            max_tokens: Maximum number of tokens passed to the LLM options.

        Returns:
            The LLM's generated output.
        """
        messages = [
            {"role": role, "content": prompt[key]}
            for role, key in (("system", "system_prompt"), ("user", "user_prompt"))
        ]
        llm_options = self._llm.options_cls(temperature=temperature, max_tokens=max_tokens)
        pending = asyncio.run_coroutine_threadsafe(
            self._llm.generate(messages, options=llm_options),
            self._loop,
        )
        return pending.result()
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class QuestionAnswerMetric(Generic[MetricT], Metric[QuestionAnswerResult], ABC):
    """
    Metric for question answer evaluation based on Relari backend.
    More details can be found [here](https://docs.relari.ai/category/text-generation).
    """

    metric_cls: type[MetricT]

    def __init__(self, llm: LLM, batch_size: int = 15, weight: float = 1.0) -> None:
        """
        Initialize the question answer metric.

        Args:
            llm: Judge LLM instance.
            batch_size: Batch size for metric computation.
            weight: Metric value weight in the final score, used during optimization.
        """
        super().__init__(weight=weight)
        self.llm = llm
        self.batch_size = batch_size

    @classmethod
    def from_config(cls, config: dict) -> Self:
        """
        Create an instance of `QuestionAnswerMetric` from a configuration dictionary.

        The given dictionary is not modified, so the same config can be reused.

        Args:
            config: A dictionary containing configuration settings for the metric. Its "llm"
                entry is an LLM config; "batch_size" defaults to 15 and "weight" to 1.0.

        Returns:
            An instance of the metric class initialized with the provided configuration.
        """
        # Build a new dict instead of assigning into `config`: the previous in-place update
        # replaced the caller's "llm" config with an `LLM` instance, breaking reuse.
        resolved_config = {
            **config,
            "llm": LLM.from_config(config["llm"]),
            "batch_size": config.get("batch_size", 15),
            "weight": config.get("weight", 1.0),
        }
        return super().from_config(resolved_config)

    async def compute(self, results: list[QuestionAnswerResult[QuestionAnswerPromptOutputT]]) -> dict:
        """
        Compute the metric.

        Each result is scored in a worker thread via `asyncio.to_thread`; batches of
        `batch_size` results run concurrently and batches are processed one after another.

        Args:
            results: The evaluation results.

        Returns:
            The computed metric, as aggregated by the underlying Relari metric.
        """
        # The running loop is handed to the adapter so judge LLM calls made from worker
        # threads are scheduled back onto this loop.
        metric = self.metric_cls(_MetricLMM(self.llm, loop=asyncio.get_running_loop()))
        metric_results = chain.from_iterable(
            [
                await asyncio.gather(*[asyncio.to_thread(self._call_metric, metric, result) for result in batch])
                for batch in batched(results, self.batch_size)
            ]
        )
        return metric.aggregate(list(metric_results))

    @staticmethod
    @abstractmethod
    def _call_metric(metric: MetricT, result: QuestionAnswerResult[QuestionAnswerPromptOutputT]) -> dict:
        """
        Call the metric with the proper arguments.
        """
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class QuestionAnswerAnswerCorrectness(QuestionAnswerMetric[LLMBasedAnswerCorrectness]):
    """
    LLM-judged answer correctness against the reference answer.
    More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_correctness).
    """

    metric_cls: type[LLMBasedAnswerCorrectness] = LLMBasedAnswerCorrectness

    @staticmethod
    def _call_metric(
        metric: LLMBasedAnswerCorrectness,
        result: QuestionAnswerResult[QuestionAnswerPromptOutputT],
    ) -> dict:
        """Score one result using its question, predicted answer and reference answer."""
        question = result.question
        content = result.predicted_result.content
        # Predictions are either plain strings or structured outputs exposing `.answer`.
        answer = content if isinstance(content, str) else content.answer
        return metric(
            question=question,
            answer=answer,
            ground_truth_answers=result.reference_answer,
        )
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class QuestionAnswerAnswerFaithfulness(QuestionAnswerMetric[LLMBasedFaithfulness]):
    """
    LLM-judged faithfulness of the answer to the reference context.
    More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_faithfulness).
    """

    metric_cls: type[LLMBasedFaithfulness] = LLMBasedFaithfulness

    @staticmethod
    def _call_metric(
        metric: LLMBasedFaithfulness,
        result: QuestionAnswerResult[QuestionAnswerPromptOutputT],
    ) -> dict:
        """Score one result using its question, predicted answer and reference context."""
        question = result.question
        content = result.predicted_result.content
        # Predictions are either plain strings or structured outputs exposing `.answer`.
        answer = content if isinstance(content, str) else content.answer
        return metric(
            question=question,
            answer=answer,
            retrieved_context=result.reference_context,
        )
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class QuestionAnswerAnswerRelevance(QuestionAnswerMetric[LLMBasedAnswerRelevance]):
    """
    LLM-judged relevance of the answer to the question.
    More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_relevance).
    """

    metric_cls: type[LLMBasedAnswerRelevance] = LLMBasedAnswerRelevance

    @staticmethod
    def _call_metric(
        metric: LLMBasedAnswerRelevance,
        result: QuestionAnswerResult[QuestionAnswerPromptOutputT],
    ) -> dict:
        """Score one result using its question and predicted answer."""
        question = result.question
        content = result.predicted_result.content
        # Predictions are either plain strings or structured outputs exposing `.answer`.
        answer = content if isinstance(content, str) else content.answer
        return metric(question=question, answer=answer)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class QuestionAnswerAnswerConsistency(QuestionAnswerMetric[LLMBasedStyleConsistency]):
    """
    Metric checking the style consistency of the answer with the reference answer, based on LLM.
    More details can be found [here](https://docs.relari.ai/metrics/Generation/LLM-Based/llm_style).
    """

    metric_cls: type[LLMBasedStyleConsistency] = LLMBasedStyleConsistency

    @staticmethod
    def _call_metric(
        metric: LLMBasedStyleConsistency,
        result: QuestionAnswerResult[QuestionAnswerPromptOutputT],
    ) -> dict:
        """
        Score one result's predicted answer against its reference answer.

        Args:
            metric: The Relari style-consistency metric instance.
            result: The evaluation result to score.

        Returns:
            The per-result output of the Relari metric.
        """
        content = result.predicted_result.content
        # Predictions are either plain strings or structured outputs exposing `.answer`.
        answer = content if isinstance(content, str) else content.answer
        return metric(
            answer=answer,
            ground_truth_answers=result.reference_answer,
        )
|