ragbits-evaluate 0.5.0__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragbits/evaluate/agent_simulation/__init__.py +87 -0
- ragbits/evaluate/agent_simulation/context.py +118 -0
- ragbits/evaluate/agent_simulation/conversation.py +333 -0
- ragbits/evaluate/agent_simulation/deepeval_evaluator.py +92 -0
- ragbits/evaluate/agent_simulation/logger.py +165 -0
- ragbits/evaluate/agent_simulation/metrics/__init__.py +19 -0
- ragbits/evaluate/agent_simulation/metrics/builtin.py +221 -0
- ragbits/evaluate/agent_simulation/metrics/collectors.py +142 -0
- ragbits/evaluate/agent_simulation/models.py +37 -0
- ragbits/evaluate/agent_simulation/results.py +200 -0
- ragbits/evaluate/agent_simulation/scenarios.py +129 -0
- ragbits/evaluate/agent_simulation/simulation.py +243 -0
- ragbits/evaluate/cli.py +150 -0
- ragbits/evaluate/config.py +11 -0
- ragbits/evaluate/dataloaders/__init__.py +3 -0
- ragbits/evaluate/dataloaders/base.py +95 -0
- ragbits/evaluate/dataloaders/document_search.py +61 -0
- ragbits/evaluate/dataloaders/exceptions.py +25 -0
- ragbits/evaluate/dataloaders/gaia.py +78 -0
- ragbits/evaluate/dataloaders/hotpot_qa.py +95 -0
- ragbits/evaluate/dataloaders/human_eval.py +70 -0
- ragbits/evaluate/dataloaders/question_answer.py +56 -0
- ragbits/evaluate/dataset_generator/pipeline.py +4 -4
- ragbits/evaluate/dataset_generator/prompts/qa.py +2 -4
- ragbits/evaluate/dataset_generator/tasks/corpus_generation.py +2 -4
- ragbits/evaluate/dataset_generator/tasks/text_generation/base.py +3 -5
- ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py +3 -3
- ragbits/evaluate/evaluator.py +178 -50
- ragbits/evaluate/factories/__init__.py +42 -0
- ragbits/evaluate/metrics/__init__.py +2 -23
- ragbits/evaluate/metrics/base.py +40 -17
- ragbits/evaluate/metrics/document_search.py +40 -23
- ragbits/evaluate/metrics/gaia.py +84 -0
- ragbits/evaluate/metrics/hotpot_qa.py +51 -0
- ragbits/evaluate/metrics/human_eval.py +105 -0
- ragbits/evaluate/metrics/question_answer.py +222 -0
- ragbits/evaluate/optimizer.py +138 -86
- ragbits/evaluate/pipelines/__init__.py +37 -0
- ragbits/evaluate/pipelines/base.py +34 -10
- ragbits/evaluate/pipelines/document_search.py +72 -67
- ragbits/evaluate/pipelines/gaia.py +249 -0
- ragbits/evaluate/pipelines/hotpot_qa.py +342 -0
- ragbits/evaluate/pipelines/human_eval.py +323 -0
- ragbits/evaluate/pipelines/question_answer.py +96 -0
- ragbits/evaluate/utils.py +86 -59
- {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +33 -9
- ragbits_evaluate-1.4.0.dev202602030301.dist-info/RECORD +59 -0
- {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +1 -1
- ragbits/evaluate/callbacks/base.py +0 -22
- ragbits/evaluate/callbacks/neptune.py +0 -26
- ragbits/evaluate/loaders/__init__.py +0 -21
- ragbits/evaluate/loaders/base.py +0 -24
- ragbits/evaluate/loaders/hf.py +0 -25
- ragbits_evaluate-0.5.0.dist-info/RECORD +0 -33
- /ragbits/evaluate/{callbacks/__init__.py → py.typed} +0 -0

ragbits/evaluate/dataset_generator/prompts/qa.py
CHANGED

```diff
@@ -23,7 +23,7 @@ class BasicAnswerGenPrompt(Prompt[BasicAnswerGenInput, str]):
         "If you don't know the answer just say: I don't know."
     )
 
-    user_prompt: str = "Text:\n<|text_start|>\n {{ chunk }} \n<|text_end|>\n\nQuestion:\n
+    user_prompt: str = "Text:\n<|text_start|>\n {{ chunk }} \n<|text_end|>\n\nQuestion:\n {{ question }} \n\nAnswer:"
 
 
 class PassagesGenInput(BaseModel):
@@ -49,9 +49,7 @@ class PassagesGenPrompt(Prompt[PassagesGenInput, str]):
         "FULL SENTENCES"
     )
 
-    user_prompt: str = (
-        "Question:\n {{ question }} \nAnswer:\n {{ basic_answer }} \nChunk:\n " "{{ chunk }}\n\nPassages:"
-    )
+    user_prompt: str = "Question:\n {{ question }} \nAnswer:\n {{ basic_answer }} \nChunk:\n {{ chunk }}\n\nPassages:"
 
 
 class QueryGenInput(BaseModel):
```
ragbits/evaluate/dataset_generator/tasks/corpus_generation.py
CHANGED

```diff
@@ -7,7 +7,7 @@ from distilabel.steps.base import Step
 
 from ragbits.core.llms.base import LLM
 from ragbits.core.prompt import Prompt
-from ragbits.core.utils.config_handling import get_cls_from_config
+from ragbits.core.utils.config_handling import import_by_path
 
 module = sys.modules[__name__]
 
@@ -23,9 +23,7 @@ class CorpusGenerationStep(Step):
     ):
         super().__init__()
         self._llm = llm
-        self._prompt_class = (
-            get_cls_from_config(prompt_class, module) if isinstance(prompt_class, str) else prompt_class
-        )
+        self._prompt_class = import_by_path(prompt_class, module) if isinstance(prompt_class, str) else prompt_class
         self._num_per_topic = num_per_topic
 
     @property
```
ragbits/evaluate/dataset_generator/tasks/text_generation/base.py
CHANGED

```diff
@@ -2,11 +2,11 @@ import sys
 from abc import ABC, abstractmethod
 from typing import Any
 
-from distilabel.
+from distilabel.models import LLM
 from distilabel.steps.tasks import TextGeneration
 
 from ragbits.core.prompt import ChatFormat, Prompt
-from ragbits.core.utils.config_handling import get_cls_from_config
+from ragbits.core.utils.config_handling import import_by_path
 
 module = sys.modules[__name__]
 
@@ -18,9 +18,7 @@ class BaseDistilabelTask(TextGeneration, ABC):
         super().__init__(llm=llm)
         self._inputs = inputs
         self._outputs = outputs
-        self._prompt_class = (
-            get_cls_from_config(prompt_class, module) if isinstance(prompt_class, str) else prompt_class
-        )
+        self._prompt_class = import_by_path(prompt_class, module) if isinstance(prompt_class, str) else prompt_class
 
     @property
     def inputs(self) -> list[str]:
```
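Both generator tasks above now resolve prompt classes through `import_by_path` instead of the removed `get_cls_from_config`. A minimal sketch of that resolution pattern, mirroring the constructors in this diff; the dotted-path string is illustrative, and the exact path syntax accepted by `import_by_path` is an assumption not taken from this diff:

```python
import sys

from ragbits.core.utils.config_handling import import_by_path

module = sys.modules[__name__]

# Illustrative path string; the accepted format is an assumption.
prompt_class = "ragbits.evaluate.dataset_generator.prompts.qa.BasicAnswerGenPrompt"

# Same pattern as CorpusGenerationStep.__init__ and BaseDistilabelTask.__init__ above:
# strings are resolved by path, already-imported classes pass through unchanged.
resolved = import_by_path(prompt_class, module) if isinstance(prompt_class, str) else prompt_class
print(resolved)
```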
ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py
CHANGED

```diff
@@ -1,9 +1,9 @@
 from typing import Any
 
-from distilabel.
+from distilabel.models import LLM
 
-from
-from .
+from ragbits.evaluate.dataset_generator.tasks.text_generation.base import BaseDistilabelTask
+from ragbits.evaluate.dataset_generator.utils import get_closest_substring, get_passages_list
 
 
 class QueryGenTask(BaseDistilabelTask):
```
ragbits/evaluate/evaluator.py
CHANGED
```diff
@@ -1,53 +1,153 @@
+import asyncio
+import random
 import time
-from collections.abc import Iterable
-from dataclasses import
-from typing import
+from collections.abc import Awaitable, Callable, Iterable, Sized
+from dataclasses import dataclass
+from typing import Generic, ParamSpec, TypeVar
 
-from
+from pydantic import BaseModel
+from tqdm import tqdm
 
-from ragbits.
+from ragbits.core.utils.config_handling import ObjectConstructionConfig, WithConstructionConfig
+from ragbits.core.utils.helpers import batched
+from ragbits.evaluate.dataloaders.base import DataLoader
 from ragbits.evaluate.metrics.base import MetricSet
-from ragbits.evaluate.pipelines.base import EvaluationPipeline,
+from ragbits.evaluate.pipelines.base import EvaluationDataT, EvaluationPipeline, EvaluationResultT, EvaluationTargetT
 
+_CallP = ParamSpec("_CallP")
+_CallReturnT = TypeVar("_CallReturnT")
 
-
+
+@dataclass
+class EvaluationTimePerf:
+    """
+    Container for evaluation time performance metrics.
+    """
+
+    total_time_in_seconds: float
+    samples_per_second: float
+    latency_in_seconds: float
+
+
+@dataclass
+class EvaluatorResult(Generic[EvaluationResultT]):
+    """
+    Container for evaluation results.
+    """
+
+    metrics: dict[str, int | float]
+    results: list[EvaluationResultT]
+    errors: list[Exception]
+    time_perf: EvaluationTimePerf
+
+
+class EvaluationConfig(BaseModel):
+    """
+    Schema for the evaluation run config.
+    """
+
+    pipeline: ObjectConstructionConfig
+    dataloader: ObjectConstructionConfig
+    metrics: dict[str, ObjectConstructionConfig]
+
+
+class EvaluatorConfig(BaseModel):
+    """
+    Schema for the evaluator config.
+    """
+
+    evaluation: EvaluationConfig
+    evaluator: dict | None = None
+
+
+class Evaluator(WithConstructionConfig):
     """
     Evaluator class.
     """
 
+    def __init__(
+        self,
+        batch_size: int = 10,
+        num_retries: int = 3,
+        backoff_multiplier: int = 1,
+        backoff_max: int = 60,
+        parallelize_batches: bool = False,
+    ) -> None:
+        """
+        Initialize the Evaluator instance.
+
+        Args:
+            batch_size: batch size for the evaluation pipeline inference.
+            num_retries: The number of retries per evaluation pipeline inference error.
+            backoff_multiplier: The base delay multiplier for exponential backoff (in seconds).
+            backoff_max: The maximum allowed delay (in seconds) between retries.
+            parallelize_batches: Whether to process samples within each batch in parallel (asyncio.gather).
+        """
+        self.batch_size = batch_size
+        self.num_retries = num_retries
+        self.backoff_multiplier = backoff_multiplier
+        self.backoff_max = backoff_max
+        self.parallelize_batches = parallelize_batches
+
+    @classmethod
+    async def run_from_config(cls, config: dict) -> EvaluatorResult:
+        """
+        Run the evaluation based on configuration.
+
+        Args:
+            config: Evaluation config.
+
+        Returns:
+            The evaluation results.
+        """
+        evaluator_config = EvaluatorConfig.model_validate(config)
+        evaluation_config = EvaluationConfig.model_validate(evaluator_config.evaluation)
+        pipeline: EvaluationPipeline = EvaluationPipeline.subclass_from_config(evaluation_config.pipeline)
+        dataloader: DataLoader = DataLoader.subclass_from_config(evaluation_config.dataloader)
+        metricset: MetricSet = MetricSet.from_config(evaluation_config.metrics)
+
+        evaluator = cls.from_config(evaluator_config.evaluator or {})
+        return await evaluator.compute(
+            pipeline=pipeline,
+            dataloader=dataloader,
+            metricset=metricset,
+        )
+
     async def compute(
         self,
-        pipeline: EvaluationPipeline,
-        dataloader: DataLoader,
-
-    ) ->
+        pipeline: EvaluationPipeline[EvaluationTargetT, EvaluationDataT, EvaluationResultT],
+        dataloader: DataLoader[EvaluationDataT],
+        metricset: MetricSet[EvaluationResultT],
+    ) -> EvaluatorResult[EvaluationResultT]:
         """
         Compute the evaluation results for the given pipeline and data.
 
         Args:
             pipeline: The pipeline to be evaluated.
             dataloader: The dataloader to load the data.
-
+            metricset: The metrics to be computed.
 
         Returns:
             The evaluation results.
         """
+        await pipeline.prepare()
+
         dataset = await dataloader.load()
-        results,
-
-        processed_results = self._results_processor(results)
+        results, errors, time_perf = await self._call_pipeline(pipeline, dataset)
+        metrics = await metricset.compute(results)
 
-        return
-
-
-
-
+        return EvaluatorResult(
+            metrics=metrics,
+            results=results,
+            errors=errors,
+            time_perf=time_perf,
+        )
 
     async def _call_pipeline(
         self,
-        pipeline: EvaluationPipeline,
-        dataset: Iterable,
-    ) -> tuple[list[
+        pipeline: EvaluationPipeline[EvaluationTargetT, EvaluationDataT, EvaluationResultT],
+        dataset: Iterable[EvaluationDataT],
+    ) -> tuple[list[EvaluationResultT], list[Exception], EvaluationTimePerf]:
         """
         Call the pipeline with the given data.
 
@@ -59,39 +159,69 @@ class Evaluator:
             The evaluation results and performance metrics.
         """
         start_time = time.perf_counter()
-
+
+        total_samples = len(dataset) if isinstance(dataset, Sized) else None
+        batches = batched(dataset, self.batch_size)
+        outputs: list[Iterable[EvaluationResultT] | Exception] = []
+
+        with tqdm(total=total_samples, desc="Evaluation", unit="sample") as progress_bar:
+            for batch in batches:
+                batch_list = list(batch)
+
+                if self.parallelize_batches:
+                    tasks = [self._call_with_error_handling(pipeline, [sample]) for sample in batch_list]
+                    batch_results = await asyncio.gather(*tasks)
+
+                    for result in batch_results:
+                        outputs.append(result)
+                        progress_bar.update(1)
+                else:
+                    result = await self._call_with_error_handling(pipeline, batch_list)
+                    outputs.append(result)
+                    progress_bar.update(len(batch_list))
+
         end_time = time.perf_counter()
-        return pipe_outputs, self._compute_time_perf(start_time, end_time, len(pipe_outputs))
 
-
-
+        errors = [output for output in outputs if isinstance(output, Exception)]
+        results = [item for output in outputs if not isinstance(output, Exception) for item in output]
+
+        return results, errors, self._compute_time_perf(start_time, end_time, len(results))
+
+    async def _call_with_error_handling(
+        self,
+        executable: Callable[_CallP, Awaitable[_CallReturnT]],
+        *executable_args: _CallP.args,
+        **executable_kwargs: _CallP.kwargs,
+    ) -> _CallReturnT | Exception:
         """
-
+        Call executable with a standarized error handling.
+        If an error occurs, the executable is retried `num_retries` times using randomized exponential backoff.
 
         Args:
-
+            executable: The callable function to execute.
+            executable_args: Positional arguments to pass to the executable.
+            executable_kwargs: Keyword arguments to pass to the executable.
 
         Returns:
-            The
-        """
-        return {"results": [asdict(result) for result in results]}
+            The result of the executable if successful.
 
-
-
+        Raises:
+            Exception: The last encountered exception after all retries are exhausted.
         """
-
+        for i in range(max(0, self.num_retries) + 1):
+            try:
+                return await executable(*executable_args, **executable_kwargs)
+            except Exception as exc:
+                if i == self.num_retries:
+                    return exc
 
-
-
-            results: The evaluation results.
+                delay = random.uniform(0, min(2**i * self.backoff_multiplier, self.backoff_max))  # noqa: S311
+                await asyncio.sleep(delay)
 
-
-            The computed metric.
-        """
-        return {"metrics": metrics.compute(results)}
+        raise RuntimeError("Unreachable code reached")  # mypy quirk
 
     @staticmethod
-    def _compute_time_perf(start_time: float, end_time: float, num_samples: int) ->
+    def _compute_time_perf(start_time: float, end_time: float, num_samples: int) -> EvaluationTimePerf:
         """
         Compute the performance metrics.
 
@@ -107,10 +237,8 @@ class Evaluator:
         throughput = num_samples / latency
         latency_sample = 1.0 / throughput if throughput > 0 else 0.0
 
-        return
-
-
-
-            },
-        }
+        return EvaluationTimePerf(
+            total_time_in_seconds=latency,
+            samples_per_second=throughput,
+            latency_in_seconds=latency_sample,
+        )
```
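The rewritten evaluator also gains a config-driven entrypoint: `EvaluatorConfig` wraps an `EvaluationConfig` (pipeline, dataloader, and metrics given as `ObjectConstructionConfig` entries) plus optional `Evaluator` settings, and `Evaluator.run_from_config` builds and runs everything. A hedged sketch of such a config; the `type`/`config` layout of each entry, the class paths, and the forwarding of the `evaluator` dict to `__init__` are assumptions rather than facts taken from this diff:

```python
import asyncio

from ragbits.evaluate.evaluator import Evaluator

# Hypothetical config; every "type" path and nested "config" below is illustrative.
config = {
    "evaluation": {
        "pipeline": {
            "type": "ragbits.evaluate.pipelines.document_search:DocumentSearchPipeline",
            "config": {},
        },
        "dataloader": {
            "type": "ragbits.evaluate.dataloaders.document_search:DocumentSearchDataLoader",
            "config": {},
        },
        "metrics": {
            "precision_recall_f1": {
                "type": "ragbits.evaluate.metrics.document_search:DocumentSearchPrecisionRecallF1",
                "config": {"matching_strategy": {"type": "RougeChunkMatch", "config": {"threshold": 0.5}}},
            },
        },
    },
    # Optional evaluator settings; missing keys presumably fall back to the __init__ defaults.
    "evaluator": {"batch_size": 10, "num_retries": 3},
}

result = asyncio.run(Evaluator.run_from_config(config))
print(result.metrics)
print(result.time_perf.total_time_in_seconds)
```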
ragbits/evaluate/factories/__init__.py
ADDED

```diff
@@ -0,0 +1,42 @@
+import asyncio
+
+from continuous_eval.metrics.retrieval.matching_strategy import RougeChunkMatch
+from datasets import load_dataset
+
+from ragbits.core.embeddings.dense import LiteLLMEmbedder
+from ragbits.core.sources.hf import HuggingFaceSource
+from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
+from ragbits.document_search import DocumentSearch
+from ragbits.document_search.documents.document import DocumentMeta
+from ragbits.evaluate.dataloaders.document_search import DocumentSearchDataLoader
+from ragbits.evaluate.metrics import MetricSet
+from ragbits.evaluate.metrics.document_search import DocumentSearchPrecisionRecallF1
+
+
+async def _add_example_documents(document_search: DocumentSearch) -> None:
+    dataset = load_dataset(path="deepsense-ai/synthetic-rag-dataset_v1.0", split="train")
+    documents = [DocumentMeta.from_literal(doc) for chunks in dataset["chunks"] for doc in chunks]
+    await document_search.ingest(documents)
+
+
+def basic_document_search_factory() -> DocumentSearch:
+    """
+    Factory for basic example document search instance.
+    """
+    document_search: DocumentSearch = DocumentSearch(vector_store=InMemoryVectorStore(embedder=LiteLLMEmbedder()))
+    asyncio.run(_add_example_documents(document_search))
+    return document_search
+
+
+def synthetic_rag_dataset() -> DocumentSearchDataLoader:
+    """
+    Factory for synthetic RAG dataset.
+    """
+    return DocumentSearchDataLoader(source=HuggingFaceSource(path="deepsense-ai/synthetic-rag-dataset_v1.0"))
+
+
+def precision_recall_f1() -> MetricSet:
+    """
+    Factory of precision recall f1 metric set for retrival evaluation.
+    """
+    return MetricSet(DocumentSearchPrecisionRecallF1(matching_strategy=RougeChunkMatch()))
```
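These factories slot directly into the new `Evaluator.compute` signature. A minimal sketch, assuming a `DocumentSearchPipeline` class in `ragbits.evaluate.pipelines.document_search` that takes the evaluated `DocumentSearch` instance; that pipeline file is listed in this release but its contents are not shown here:

```python
import asyncio

from ragbits.evaluate.evaluator import Evaluator
from ragbits.evaluate.factories import (
    basic_document_search_factory,
    precision_recall_f1,
    synthetic_rag_dataset,
)
from ragbits.evaluate.pipelines.document_search import DocumentSearchPipeline  # assumed class name


async def main() -> None:
    evaluator = Evaluator(batch_size=8, parallelize_batches=True)
    result = await evaluator.compute(
        pipeline=DocumentSearchPipeline(basic_document_search_factory()),  # constructor signature assumed
        dataloader=synthetic_rag_dataset(),
        metricset=precision_recall_f1(),
    )
    print(result.metrics)      # weighted metric values from the MetricSet
    print(len(result.errors))  # pipeline calls that exhausted all retries


asyncio.run(main())
```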
ragbits/evaluate/metrics/__init__.py
CHANGED

```diff
@@ -1,24 +1,3 @@
-import sys
+from ragbits.evaluate.metrics.base import Metric, MetricSet
 
-
-
-from ragbits.core.utils.config_handling import get_cls_from_config
-
-from .base import MetricSet
-
-module = sys.modules[__name__]
-
-
-def metric_set_factory(cfg: ListConfig) -> MetricSet:
-    """
-    A function creating MetricSet instance from the configuration
-    Args:
-        cfg - metric cnfiguration
-    Returns:
-        MetricSet
-    """
-    metrics = []
-    for metric_cfg in cfg:
-        metric_module = get_cls_from_config(metric_cfg.type, module)
-        metrics.append(metric_module(metric_cfg))
-    return MetricSet(*metrics)
+__all__ = ["Metric", "MetricSet"]
```
ragbits/evaluate/metrics/base.py
CHANGED
```diff
@@ -1,31 +1,35 @@
+import asyncio
 from abc import ABC, abstractmethod
-from
+from types import ModuleType
+from typing import ClassVar, Generic
 
-from
+from typing_extensions import Self
 
-from ragbits.
+from ragbits.core.utils.config_handling import WithConstructionConfig
+from ragbits.evaluate import metrics
+from ragbits.evaluate.pipelines.base import EvaluationResultT
 
-ResultT = TypeVar("ResultT", bound=EvaluationResult)
 
-
-class Metric(Generic[ResultT], ABC):
+class Metric(WithConstructionConfig, Generic[EvaluationResultT], ABC):
     """
     Base class for metrics.
     """
 
-
+    default_module: ClassVar[ModuleType | None] = metrics
+    configuration_key: ClassVar[str] = "metric"
+
+    def __init__(self, weight: float = 1.0) -> None:
         """
-
+        Initialize the metric.
 
         Args:
-
+            weight: Metric value weight in the final score, used during optimization.
         """
         super().__init__()
-        self.
-        self.weight: float = getattr(self.config, "weight", 1.0)
+        self.weight = weight
 
     @abstractmethod
-    def compute(self, results: list[
+    async def compute(self, results: list[EvaluationResultT]) -> dict:
         """
         Compute the metric.
 
@@ -37,21 +41,37 @@ class Metric(Generic[ResultT], ABC):
         """
 
 
-class MetricSet(Generic[ResultT]):
+class MetricSet(WithConstructionConfig, Generic[EvaluationResultT]):
     """
     Represents a set of metrics.
     """
 
-
+    configuration_key: ClassVar[str] = "metrics"
+    default_module: ClassVar[ModuleType | None] = metrics
+
+    def __init__(self, *metrics: Metric[EvaluationResultT]) -> None:
         """
-
+        Initialize the metric set.
 
         Args:
             metrics: The metrics.
         """
        self.metrics = metrics
 
-
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Create an instance of `MetricSet` from a configuration dictionary.
+
+        Args:
+            config: A dictionary containing configuration settings for the metric set.
+
+        Returns:
+            An instance of the metric set class initialized with the provided configuration.
+        """
+        return cls(*[Metric.subclass_from_config(metric_config) for metric_config in config.values()])
+
+    async def compute(self, results: list[EvaluationResultT]) -> dict:
         """
         Compute the metrics.
 
@@ -61,6 +81,9 @@ class MetricSet(Generic[ResultT]):
         Returns:
             The computed metrics.
         """
+        metric_results = await asyncio.gather(*[metric.compute(results) for metric in self.metrics])
         return {
-            name: metric.weight * value
+            name: metric.weight * value
+            for metric, result in zip(self.metrics, metric_results, strict=False)
+            for name, value in result.items()
         }
```
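With the reworked base classes, a custom metric only needs an async `compute` that returns a name-to-value dict; `MetricSet` gathers all metrics concurrently and multiplies each value by the metric's `weight`. A small illustrative sketch (the toy metric and results are not part of the package):

```python
import asyncio

from ragbits.evaluate.metrics.base import Metric, MetricSet


class ExactMatch(Metric):
    """Toy metric: fraction of results whose prediction equals the reference."""

    async def compute(self, results: list) -> dict:
        if not results:
            return {"exact_match": 0.0}
        hits = sum(1 for r in results if r["predicted"] == r["reference"])
        return {"exact_match": hits / len(results)}


async def main() -> None:
    metricset = MetricSet(ExactMatch(weight=0.5))
    results = [
        {"predicted": "a", "reference": "a"},
        {"predicted": "b", "reference": "c"},
    ]
    # 0.5 raw score * 0.5 weight = 0.25 in the aggregated output.
    print(await metricset.compute(results))


asyncio.run(main())
```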
ragbits/evaluate/metrics/document_search.py
CHANGED

```diff
@@ -1,10 +1,9 @@
 import importlib
 from abc import ABC
-from typing import Any
 
 from continuous_eval.metrics.retrieval import PrecisionRecallF1, RankedRetrievalMetrics
-from continuous_eval.metrics.retrieval.matching_strategy import RougeChunkMatch
-from
+from continuous_eval.metrics.retrieval.matching_strategy import MatchingStrategy
+from typing_extensions import Self
 
 from ragbits.evaluate.metrics.base import Metric
 from ragbits.evaluate.pipelines.document_search import DocumentSearchResult
@@ -17,30 +16,37 @@ class DocumentSearchMetric(Metric[DocumentSearchResult], ABC):
     """
 
     metric_cls: type[PrecisionRecallF1 | RankedRetrievalMetrics]
-    default_matching_strategy: type[RougeChunkMatch] = RougeChunkMatch
-    default_matching_options: DictConfig = OmegaConf.create({"threshold": 0.5})
 
-    def __init__(self,
+    def __init__(self, matching_strategy: MatchingStrategy, weight: float = 1.0) -> None:
         """
-
+        Initialize the document search metric.
 
         Args:
-
+            matching_strategy: Matching strategys that determine relevance.
+            weight: Metric value weight in the final score, used during optimization.
         """
-        super().__init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
+        super().__init__(weight=weight)
+        self.metric = self.metric_cls(matching_strategy)
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Create an instance of `DocumentSearchMetric` from a configuration dictionary.
+
+        Args:
+            config: A dictionary containing configuration settings for the metric.
+
+        Returns:
+            An instance of the metric class initialized with the provided configuration.
+        """
+        matching_strategy_cls = getattr(
+            importlib.import_module("continuous_eval.metrics.retrieval.matching_strategy"),
+            config["matching_strategy"]["type"],
+        )
+        matching_strategy = matching_strategy_cls(**config["matching_strategy"]["config"])
+        return cls(matching_strategy=matching_strategy, weight=config.get("weight", 1.0))
+
+    async def compute(self, results: list[DocumentSearchResult]) -> dict:
         """
         Compute the metric.
 
@@ -51,7 +57,18 @@ class DocumentSearchMetric(Metric[DocumentSearchResult], ABC):
             The computed metric.
         """
         return self.metric.aggregate(
-            [
+            [
+                self.metric(
+                    [
+                        element.text_representation
+                        for element in result.predicted_elements
+                        if element.text_representation
+                    ],
+                    result.reference_passages,
+                )
+                for result in results
+                if result.reference_passages is not None
+            ]
         )
 
 
```
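Since `from_config` looks the matching strategy up by name in `continuous_eval.metrics.retrieval.matching_strategy` and instantiates it with the nested `config` kwargs, a document-search metric can now be declared as plain data. A sketch of the expected shape; the `threshold` value simply mirrors the old default and is not mandated by this diff:

```python
from ragbits.evaluate.metrics.document_search import DocumentSearchPrecisionRecallF1

metric = DocumentSearchPrecisionRecallF1.from_config(
    {
        # Resolved as getattr(importlib.import_module(...), "RougeChunkMatch")
        # and instantiated as RougeChunkMatch(threshold=0.5).
        "matching_strategy": {"type": "RougeChunkMatch", "config": {"threshold": 0.5}},
        "weight": 1.0,
    }
)
print(metric.metric)  # the wrapped continuous-eval PrecisionRecallF1 instance
```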