ragbits-evaluate 0.5.0__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (55)
  1. ragbits/evaluate/agent_simulation/__init__.py +87 -0
  2. ragbits/evaluate/agent_simulation/context.py +118 -0
  3. ragbits/evaluate/agent_simulation/conversation.py +333 -0
  4. ragbits/evaluate/agent_simulation/deepeval_evaluator.py +92 -0
  5. ragbits/evaluate/agent_simulation/logger.py +165 -0
  6. ragbits/evaluate/agent_simulation/metrics/__init__.py +19 -0
  7. ragbits/evaluate/agent_simulation/metrics/builtin.py +221 -0
  8. ragbits/evaluate/agent_simulation/metrics/collectors.py +142 -0
  9. ragbits/evaluate/agent_simulation/models.py +37 -0
  10. ragbits/evaluate/agent_simulation/results.py +200 -0
  11. ragbits/evaluate/agent_simulation/scenarios.py +129 -0
  12. ragbits/evaluate/agent_simulation/simulation.py +243 -0
  13. ragbits/evaluate/cli.py +150 -0
  14. ragbits/evaluate/config.py +11 -0
  15. ragbits/evaluate/dataloaders/__init__.py +3 -0
  16. ragbits/evaluate/dataloaders/base.py +95 -0
  17. ragbits/evaluate/dataloaders/document_search.py +61 -0
  18. ragbits/evaluate/dataloaders/exceptions.py +25 -0
  19. ragbits/evaluate/dataloaders/gaia.py +78 -0
  20. ragbits/evaluate/dataloaders/hotpot_qa.py +95 -0
  21. ragbits/evaluate/dataloaders/human_eval.py +70 -0
  22. ragbits/evaluate/dataloaders/question_answer.py +56 -0
  23. ragbits/evaluate/dataset_generator/pipeline.py +4 -4
  24. ragbits/evaluate/dataset_generator/prompts/qa.py +2 -4
  25. ragbits/evaluate/dataset_generator/tasks/corpus_generation.py +2 -4
  26. ragbits/evaluate/dataset_generator/tasks/text_generation/base.py +3 -5
  27. ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py +3 -3
  28. ragbits/evaluate/evaluator.py +178 -50
  29. ragbits/evaluate/factories/__init__.py +42 -0
  30. ragbits/evaluate/metrics/__init__.py +2 -23
  31. ragbits/evaluate/metrics/base.py +40 -17
  32. ragbits/evaluate/metrics/document_search.py +40 -23
  33. ragbits/evaluate/metrics/gaia.py +84 -0
  34. ragbits/evaluate/metrics/hotpot_qa.py +51 -0
  35. ragbits/evaluate/metrics/human_eval.py +105 -0
  36. ragbits/evaluate/metrics/question_answer.py +222 -0
  37. ragbits/evaluate/optimizer.py +138 -86
  38. ragbits/evaluate/pipelines/__init__.py +37 -0
  39. ragbits/evaluate/pipelines/base.py +34 -10
  40. ragbits/evaluate/pipelines/document_search.py +72 -67
  41. ragbits/evaluate/pipelines/gaia.py +249 -0
  42. ragbits/evaluate/pipelines/hotpot_qa.py +342 -0
  43. ragbits/evaluate/pipelines/human_eval.py +323 -0
  44. ragbits/evaluate/pipelines/question_answer.py +96 -0
  45. ragbits/evaluate/utils.py +86 -59
  46. {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +33 -9
  47. ragbits_evaluate-1.4.0.dev202602030301.dist-info/RECORD +59 -0
  48. {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +1 -1
  49. ragbits/evaluate/callbacks/base.py +0 -22
  50. ragbits/evaluate/callbacks/neptune.py +0 -26
  51. ragbits/evaluate/loaders/__init__.py +0 -21
  52. ragbits/evaluate/loaders/base.py +0 -24
  53. ragbits/evaluate/loaders/hf.py +0 -25
  54. ragbits_evaluate-0.5.0.dist-info/RECORD +0 -33
  55. /ragbits/evaluate/{callbacks/__init__.py → py.typed} +0 -0
ragbits/evaluate/dataset_generator/prompts/qa.py
@@ -23,7 +23,7 @@ class BasicAnswerGenPrompt(Prompt[BasicAnswerGenInput, str]):
  "If you don't know the answer just say: I don't know."
  )

- user_prompt: str = "Text:\n<|text_start|>\n {{ chunk }} \n<|text_end|>\n\nQuestion:\n " "{{ question }} \n\nAnswer:"
+ user_prompt: str = "Text:\n<|text_start|>\n {{ chunk }} \n<|text_end|>\n\nQuestion:\n {{ question }} \n\nAnswer:"


  class PassagesGenInput(BaseModel):
@@ -49,9 +49,7 @@ class PassagesGenPrompt(Prompt[PassagesGenInput, str]):
  "FULL SENTENCES"
  )

- user_prompt: str = (
- "Question:\n {{ question }} \nAnswer:\n {{ basic_answer }} \nChunk:\n " "{{ chunk }}\n\nPassages:"
- )
+ user_prompt: str = "Question:\n {{ question }} \nAnswer:\n {{ basic_answer }} \nChunk:\n {{ chunk }}\n\nPassages:"


  class QueryGenInput(BaseModel):
ragbits/evaluate/dataset_generator/tasks/corpus_generation.py
@@ -7,7 +7,7 @@ from distilabel.steps.base import Step

  from ragbits.core.llms.base import LLM
  from ragbits.core.prompt import Prompt
- from ragbits.core.utils.config_handling import get_cls_from_config
+ from ragbits.core.utils.config_handling import import_by_path

  module = sys.modules[__name__]

@@ -23,9 +23,7 @@ class CorpusGenerationStep(Step):
  ):
  super().__init__()
  self._llm = llm
- self._prompt_class = (
- get_cls_from_config(prompt_class, module) if isinstance(prompt_class, str) else prompt_class
- )
+ self._prompt_class = import_by_path(prompt_class, module) if isinstance(prompt_class, str) else prompt_class
  self._num_per_topic = num_per_topic

  @property
ragbits/evaluate/dataset_generator/tasks/text_generation/base.py
@@ -2,11 +2,11 @@ import sys
  from abc import ABC, abstractmethod
  from typing import Any

- from distilabel.llms.base import LLM
+ from distilabel.models import LLM
  from distilabel.steps.tasks import TextGeneration

  from ragbits.core.prompt import ChatFormat, Prompt
- from ragbits.core.utils.config_handling import get_cls_from_config
+ from ragbits.core.utils.config_handling import import_by_path

  module = sys.modules[__name__]

@@ -18,9 +18,7 @@ class BaseDistilabelTask(TextGeneration, ABC):
  super().__init__(llm=llm)
  self._inputs = inputs
  self._outputs = outputs
- self._prompt_class = (
- get_cls_from_config(prompt_class, module) if isinstance(prompt_class, str) else prompt_class
- )
+ self._prompt_class = import_by_path(prompt_class, module) if isinstance(prompt_class, str) else prompt_class

  @property
  def inputs(self) -> list[str]:
ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py
@@ -1,9 +1,9 @@
  from typing import Any

- from distilabel.llms.base import LLM
+ from distilabel.models import LLM

- from ...utils import get_closest_substring, get_passages_list
- from .base import BaseDistilabelTask
+ from ragbits.evaluate.dataset_generator.tasks.text_generation.base import BaseDistilabelTask
+ from ragbits.evaluate.dataset_generator.utils import get_closest_substring, get_passages_list


  class QueryGenTask(BaseDistilabelTask):
ragbits/evaluate/evaluator.py
@@ -1,53 +1,153 @@
+ import asyncio
+ import random
  import time
- from collections.abc import Iterable
- from dataclasses import asdict
- from typing import Any
+ from collections.abc import Awaitable, Callable, Iterable, Sized
+ from dataclasses import dataclass
+ from typing import Generic, ParamSpec, TypeVar

- from tqdm.asyncio import tqdm
+ from pydantic import BaseModel
+ from tqdm import tqdm

- from ragbits.evaluate.loaders.base import DataLoader
+ from ragbits.core.utils.config_handling import ObjectConstructionConfig, WithConstructionConfig
+ from ragbits.core.utils.helpers import batched
+ from ragbits.evaluate.dataloaders.base import DataLoader
  from ragbits.evaluate.metrics.base import MetricSet
- from ragbits.evaluate.pipelines.base import EvaluationPipeline, EvaluationResult
+ from ragbits.evaluate.pipelines.base import EvaluationDataT, EvaluationPipeline, EvaluationResultT, EvaluationTargetT

+ _CallP = ParamSpec("_CallP")
+ _CallReturnT = TypeVar("_CallReturnT")

- class Evaluator:
+
+ @dataclass
+ class EvaluationTimePerf:
+ """
+ Container for evaluation time performance metrics.
+ """
+
+ total_time_in_seconds: float
+ samples_per_second: float
+ latency_in_seconds: float
+
+
+ @dataclass
+ class EvaluatorResult(Generic[EvaluationResultT]):
+ """
+ Container for evaluation results.
+ """
+
+ metrics: dict[str, int | float]
+ results: list[EvaluationResultT]
+ errors: list[Exception]
+ time_perf: EvaluationTimePerf
+
+
+ class EvaluationConfig(BaseModel):
+ """
+ Schema for the evaluation run config.
+ """
+
+ pipeline: ObjectConstructionConfig
+ dataloader: ObjectConstructionConfig
+ metrics: dict[str, ObjectConstructionConfig]
+
+
+ class EvaluatorConfig(BaseModel):
+ """
+ Schema for the evaluator config.
+ """
+
+ evaluation: EvaluationConfig
+ evaluator: dict | None = None
+
+
+ class Evaluator(WithConstructionConfig):
  """
  Evaluator class.
  """

+ def __init__(
+ self,
+ batch_size: int = 10,
+ num_retries: int = 3,
+ backoff_multiplier: int = 1,
+ backoff_max: int = 60,
+ parallelize_batches: bool = False,
+ ) -> None:
+ """
+ Initialize the Evaluator instance.
+
+ Args:
+ batch_size: batch size for the evaluation pipeline inference.
+ num_retries: The number of retries per evaluation pipeline inference error.
+ backoff_multiplier: The base delay multiplier for exponential backoff (in seconds).
+ backoff_max: The maximum allowed delay (in seconds) between retries.
+ parallelize_batches: Whether to process samples within each batch in parallel (asyncio.gather).
+ """
+ self.batch_size = batch_size
+ self.num_retries = num_retries
+ self.backoff_multiplier = backoff_multiplier
+ self.backoff_max = backoff_max
+ self.parallelize_batches = parallelize_batches
+
+ @classmethod
+ async def run_from_config(cls, config: dict) -> EvaluatorResult:
+ """
+ Run the evaluation based on configuration.
+
+ Args:
+ config: Evaluation config.
+
+ Returns:
+ The evaluation results.
+ """
+ evaluator_config = EvaluatorConfig.model_validate(config)
+ evaluation_config = EvaluationConfig.model_validate(evaluator_config.evaluation)
+ pipeline: EvaluationPipeline = EvaluationPipeline.subclass_from_config(evaluation_config.pipeline)
+ dataloader: DataLoader = DataLoader.subclass_from_config(evaluation_config.dataloader)
+ metricset: MetricSet = MetricSet.from_config(evaluation_config.metrics)
+
+ evaluator = cls.from_config(evaluator_config.evaluator or {})
+ return await evaluator.compute(
+ pipeline=pipeline,
+ dataloader=dataloader,
+ metricset=metricset,
+ )
+
  async def compute(
  self,
- pipeline: EvaluationPipeline,
- dataloader: DataLoader,
- metrics: MetricSet,
- ) -> dict[str, Any]:
+ pipeline: EvaluationPipeline[EvaluationTargetT, EvaluationDataT, EvaluationResultT],
+ dataloader: DataLoader[EvaluationDataT],
+ metricset: MetricSet[EvaluationResultT],
+ ) -> EvaluatorResult[EvaluationResultT]:
  """
  Compute the evaluation results for the given pipeline and data.

  Args:
  pipeline: The pipeline to be evaluated.
  dataloader: The dataloader to load the data.
- metrics: The metrics to be computed.
+ metricset: The metrics to be computed.

  Returns:
  The evaluation results.
  """
+ await pipeline.prepare()
+
  dataset = await dataloader.load()
- results, perf_results = await self._call_pipeline(pipeline, dataset)
- computed_metrics = self._compute_metrics(metrics, results)
- processed_results = self._results_processor(results)
+ results, errors, time_perf = await self._call_pipeline(pipeline, dataset)
+ metrics = await metricset.compute(results)

- return {
- **perf_results,
- **computed_metrics,
- **processed_results,
- }
+ return EvaluatorResult(
+ metrics=metrics,
+ results=results,
+ errors=errors,
+ time_perf=time_perf,
+ )

  async def _call_pipeline(
  self,
- pipeline: EvaluationPipeline,
- dataset: Iterable,
- ) -> tuple[list[EvaluationResult], dict[str, Any]]:
+ pipeline: EvaluationPipeline[EvaluationTargetT, EvaluationDataT, EvaluationResultT],
+ dataset: Iterable[EvaluationDataT],
+ ) -> tuple[list[EvaluationResultT], list[Exception], EvaluationTimePerf]:
  """
  Call the pipeline with the given data.

@@ -59,39 +159,69 @@ class Evaluator:
  The evaluation results and performance metrics.
  """
  start_time = time.perf_counter()
- pipe_outputs = await tqdm.gather(*[pipeline(data) for data in dataset], desc="Evaluation")
+
+ total_samples = len(dataset) if isinstance(dataset, Sized) else None
+ batches = batched(dataset, self.batch_size)
+ outputs: list[Iterable[EvaluationResultT] | Exception] = []
+
+ with tqdm(total=total_samples, desc="Evaluation", unit="sample") as progress_bar:
+ for batch in batches:
+ batch_list = list(batch)
+
+ if self.parallelize_batches:
+ tasks = [self._call_with_error_handling(pipeline, [sample]) for sample in batch_list]
+ batch_results = await asyncio.gather(*tasks)
+
+ for result in batch_results:
+ outputs.append(result)
+ progress_bar.update(1)
+ else:
+ result = await self._call_with_error_handling(pipeline, batch_list)
+ outputs.append(result)
+ progress_bar.update(len(batch_list))
+
  end_time = time.perf_counter()
- return pipe_outputs, self._compute_time_perf(start_time, end_time, len(pipe_outputs))

- @staticmethod
- def _results_processor(results: list[EvaluationResult]) -> dict[str, Any]:
+ errors = [output for output in outputs if isinstance(output, Exception)]
+ results = [item for output in outputs if not isinstance(output, Exception) for item in output]
+
+ return results, errors, self._compute_time_perf(start_time, end_time, len(results))
+
+ async def _call_with_error_handling(
+ self,
+ executable: Callable[_CallP, Awaitable[_CallReturnT]],
+ *executable_args: _CallP.args,
+ **executable_kwargs: _CallP.kwargs,
+ ) -> _CallReturnT | Exception:
  """
- Process the results.
+ Call executable with a standarized error handling.
+ If an error occurs, the executable is retried `num_retries` times using randomized exponential backoff.

  Args:
- results: The evaluation results.
+ executable: The callable function to execute.
+ executable_args: Positional arguments to pass to the executable.
+ executable_kwargs: Keyword arguments to pass to the executable.

  Returns:
- The processed results.
- """
- return {"results": [asdict(result) for result in results]}
+ The result of the executable if successful.

- @staticmethod
- def _compute_metrics(metrics: MetricSet, results: list[EvaluationResult]) -> dict[str, Any]:
+ Raises:
+ Exception: The last encountered exception after all retries are exhausted.
  """
- Compute a metric using the given inputs.
+ for i in range(max(0, self.num_retries) + 1):
+ try:
+ return await executable(*executable_args, **executable_kwargs)
+ except Exception as exc:
+ if i == self.num_retries:
+ return exc

- Args:
- metrics: The metrics to be computed.
- results: The evaluation results.
+ delay = random.uniform(0, min(2**i * self.backoff_multiplier, self.backoff_max)) # noqa: S311
+ await asyncio.sleep(delay)

- Returns:
- The computed metric.
- """
- return {"metrics": metrics.compute(results)}
+ raise RuntimeError("Unreachable code reached") # mypy quirk

  @staticmethod
- def _compute_time_perf(start_time: float, end_time: float, num_samples: int) -> dict[str, Any]:
+ def _compute_time_perf(start_time: float, end_time: float, num_samples: int) -> EvaluationTimePerf:
  """
  Compute the performance metrics.

@@ -107,10 +237,8 @@ class Evaluator:
  throughput = num_samples / latency
  latency_sample = 1.0 / throughput if throughput > 0 else 0.0

- return {
- "time_perf": {
- "total_time_in_seconds": latency,
- "samples_per_second": throughput,
- "latency_in_seconds": latency_sample,
- },
- }
+ return EvaluationTimePerf(
+ total_time_in_seconds=latency,
+ samples_per_second=throughput,
+ latency_in_seconds=latency_sample,
+ )
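
A hedged sketch of the config shape Evaluator.run_from_config accepts, inferred from the EvaluatorConfig and EvaluationConfig schemas above; the "type"/"config" keys inside each entry and the class paths below are assumptions about ObjectConstructionConfig, not values taken from this diff:

    import asyncio

    from ragbits.evaluate.evaluator import Evaluator

    config = {
        "evaluation": {
            # each entry is validated into an ObjectConstructionConfig; paths are placeholders
            "pipeline": {"type": "my_project.pipelines:MyEvaluationPipeline", "config": {}},
            "dataloader": {"type": "my_project.dataloaders:MyDataLoader", "config": {}},
            "metrics": {
                "main": {"type": "my_project.metrics:MyMetric", "config": {"weight": 1.0}},
            },
        },
        # optional overrides forwarded to Evaluator.from_config
        "evaluator": {"batch_size": 10, "num_retries": 3},
    }

    result = asyncio.run(Evaluator.run_from_config(config))
    print(result.metrics, len(result.errors))
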
ragbits/evaluate/factories/__init__.py
@@ -0,0 +1,42 @@
+ import asyncio
+
+ from continuous_eval.metrics.retrieval.matching_strategy import RougeChunkMatch
+ from datasets import load_dataset
+
+ from ragbits.core.embeddings.dense import LiteLLMEmbedder
+ from ragbits.core.sources.hf import HuggingFaceSource
+ from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
+ from ragbits.document_search import DocumentSearch
+ from ragbits.document_search.documents.document import DocumentMeta
+ from ragbits.evaluate.dataloaders.document_search import DocumentSearchDataLoader
+ from ragbits.evaluate.metrics import MetricSet
+ from ragbits.evaluate.metrics.document_search import DocumentSearchPrecisionRecallF1
+
+
+ async def _add_example_documents(document_search: DocumentSearch) -> None:
+ dataset = load_dataset(path="deepsense-ai/synthetic-rag-dataset_v1.0", split="train")
+ documents = [DocumentMeta.from_literal(doc) for chunks in dataset["chunks"] for doc in chunks]
+ await document_search.ingest(documents)
+
+
+ def basic_document_search_factory() -> DocumentSearch:
+ """
+ Factory for basic example document search instance.
+ """
+ document_search: DocumentSearch = DocumentSearch(vector_store=InMemoryVectorStore(embedder=LiteLLMEmbedder()))
+ asyncio.run(_add_example_documents(document_search))
+ return document_search
+
+
+ def synthetic_rag_dataset() -> DocumentSearchDataLoader:
+ """
+ Factory for synthetic RAG dataset.
+ """
+ return DocumentSearchDataLoader(source=HuggingFaceSource(path="deepsense-ai/synthetic-rag-dataset_v1.0"))
+
+
+ def precision_recall_f1() -> MetricSet:
+ """
+ Factory of precision recall f1 metric set for retrival evaluation.
+ """
+ return MetricSet(DocumentSearchPrecisionRecallF1(matching_strategy=RougeChunkMatch()))
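
The factories above slot directly into the reworked Evaluator. A minimal sketch, assuming a document-search evaluation pipeline instance is already available (the pipeline variable below is a placeholder; that class is not shown in this diff):

    import asyncio

    from ragbits.evaluate.evaluator import Evaluator
    from ragbits.evaluate.factories import precision_recall_f1, synthetic_rag_dataset

    async def main() -> None:
        evaluator = Evaluator(batch_size=10, num_retries=3)
        result = await evaluator.compute(
            pipeline=document_search_pipeline,   # placeholder: an EvaluationPipeline over DocumentSearch
            dataloader=synthetic_rag_dataset(),  # dataloader factory defined above
            metricset=precision_recall_f1(),     # metric set factory defined above
        )
        print(result.metrics)
        print(f"{len(result.errors)} errors, {result.time_perf.samples_per_second:.2f} samples/s")

    asyncio.run(main())
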
ragbits/evaluate/metrics/__init__.py
@@ -1,24 +1,3 @@
- import sys
+ from ragbits.evaluate.metrics.base import Metric, MetricSet

- from omegaconf import ListConfig
-
- from ragbits.core.utils.config_handling import get_cls_from_config
-
- from .base import MetricSet
-
- module = sys.modules[__name__]
-
-
- def metric_set_factory(cfg: ListConfig) -> MetricSet:
- """
- A function creating MetricSet instance from the configuration
- Args:
- cfg - metric cnfiguration
- Returns:
- MetricSet
- """
- metrics = []
- for metric_cfg in cfg:
- metric_module = get_cls_from_config(metric_cfg.type, module)
- metrics.append(metric_module(metric_cfg))
- return MetricSet(*metrics)
+ __all__ = ["Metric", "MetricSet"]
ragbits/evaluate/metrics/base.py
@@ -1,31 +1,35 @@
+ import asyncio
  from abc import ABC, abstractmethod
- from typing import Any, Generic, TypeVar
+ from types import ModuleType
+ from typing import ClassVar, Generic

- from omegaconf import DictConfig
+ from typing_extensions import Self

- from ragbits.evaluate.pipelines.base import EvaluationResult
+ from ragbits.core.utils.config_handling import WithConstructionConfig
+ from ragbits.evaluate import metrics
+ from ragbits.evaluate.pipelines.base import EvaluationResultT

- ResultT = TypeVar("ResultT", bound=EvaluationResult)

-
- class Metric(Generic[ResultT], ABC):
+ class Metric(WithConstructionConfig, Generic[EvaluationResultT], ABC):
  """
  Base class for metrics.
  """

- def __init__(self, config: DictConfig | None = None) -> None:
+ default_module: ClassVar[ModuleType | None] = metrics
+ configuration_key: ClassVar[str] = "metric"
+
+ def __init__(self, weight: float = 1.0) -> None:
  """
- Initializes the metric.
+ Initialize the metric.

  Args:
- config: The metric configuration.
+ weight: Metric value weight in the final score, used during optimization.
  """
  super().__init__()
- self.config = config
- self.weight: float = getattr(self.config, "weight", 1.0)
+ self.weight = weight

  @abstractmethod
- def compute(self, results: list[ResultT]) -> dict[str, Any]:
+ async def compute(self, results: list[EvaluationResultT]) -> dict:
  """
  Compute the metric.

@@ -37,21 +41,37 @@ class Metric(Generic[ResultT], ABC):
  """


- class MetricSet(Generic[ResultT]):
+ class MetricSet(WithConstructionConfig, Generic[EvaluationResultT]):
  """
  Represents a set of metrics.
  """

- def __init__(self, *metrics: Metric[ResultT]) -> None:
+ configuration_key: ClassVar[str] = "metrics"
+ default_module: ClassVar[ModuleType | None] = metrics
+
+ def __init__(self, *metrics: Metric[EvaluationResultT]) -> None:
  """
- Initializes the metric set.
+ Initialize the metric set.

  Args:
  metrics: The metrics.
  """
  self.metrics = metrics

- def compute(self, results: list[ResultT]) -> dict[str, Any]:
+ @classmethod
+ def from_config(cls, config: dict) -> Self:
+ """
+ Create an instance of `MetricSet` from a configuration dictionary.
+
+ Args:
+ config: A dictionary containing configuration settings for the metric set.
+
+ Returns:
+ An instance of the metric set class initialized with the provided configuration.
+ """
+ return cls(*[Metric.subclass_from_config(metric_config) for metric_config in config.values()])
+
+ async def compute(self, results: list[EvaluationResultT]) -> dict:
  """
  Compute the metrics.

@@ -61,6 +81,9 @@ class MetricSet(Generic[ResultT]):
  Returns:
  The computed metrics.
  """
+ metric_results = await asyncio.gather(*[metric.compute(results) for metric in self.metrics])
  return {
- name: metric.weight * value for metric in self.metrics for name, value in metric.compute(results).items()
+ name: metric.weight * value
+ for metric, result in zip(self.metrics, metric_results, strict=False)
+ for name, value in result.items()
  }
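
With Metric.compute now async and weights applied per returned key when MetricSet.compute aggregates, a custom metric reduces to a small subclass. A hedged sketch; the ExactMatch class and the .predicted/.reference attributes it reads are illustrative, not part of the package:

    from ragbits.evaluate.metrics.base import Metric, MetricSet

    class ExactMatch(Metric):
        # illustrative metric: assumes each result exposes .predicted and .reference
        async def compute(self, results: list) -> dict:
            if not results:
                return {"exact_match": 0.0}
            hits = sum(1 for result in results if result.predicted == result.reference)
            return {"exact_match": hits / len(results)}

    # the 0.5 weight scales the exact_match value in the aggregated metrics dict
    metricset = MetricSet(ExactMatch(weight=0.5))
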
ragbits/evaluate/metrics/document_search.py
@@ -1,10 +1,9 @@
  import importlib
  from abc import ABC
- from typing import Any

  from continuous_eval.metrics.retrieval import PrecisionRecallF1, RankedRetrievalMetrics
- from continuous_eval.metrics.retrieval.matching_strategy import RougeChunkMatch
- from omegaconf import DictConfig, OmegaConf
+ from continuous_eval.metrics.retrieval.matching_strategy import MatchingStrategy
+ from typing_extensions import Self

  from ragbits.evaluate.metrics.base import Metric
  from ragbits.evaluate.pipelines.document_search import DocumentSearchResult
@@ -17,30 +16,37 @@ class DocumentSearchMetric(Metric[DocumentSearchResult], ABC):
  """

  metric_cls: type[PrecisionRecallF1 | RankedRetrievalMetrics]
- default_matching_strategy: type[RougeChunkMatch] = RougeChunkMatch
- default_matching_options: DictConfig = OmegaConf.create({"threshold": 0.5})

- def __init__(self, config: DictConfig | None = None) -> None:
+ def __init__(self, matching_strategy: MatchingStrategy, weight: float = 1.0) -> None:
  """
- Initializes the metric.
+ Initialize the document search metric.

  Args:
- config: The metric configuration.
+ matching_strategy: Matching strategys that determine relevance.
+ weight: Metric value weight in the final score, used during optimization.
  """
- super().__init__(config)
- if not self.config:
- matching_strategy = self.default_matching_strategy
- options = self.default_matching_options
-
- else:
- matching_strategy = getattr(
- importlib.import_module("continuous_eval.metrics.retrieval.matching_strategy"),
- self.config.matching_strategy,
- )
- options = self.config.options
- self.metric = self.metric_cls(matching_strategy(**options))
-
- def compute(self, results: list[DocumentSearchResult]) -> dict[str, Any]:
+ super().__init__(weight=weight)
+ self.metric = self.metric_cls(matching_strategy)
+
+ @classmethod
+ def from_config(cls, config: dict) -> Self:
+ """
+ Create an instance of `DocumentSearchMetric` from a configuration dictionary.
+
+ Args:
+ config: A dictionary containing configuration settings for the metric.
+
+ Returns:
+ An instance of the metric class initialized with the provided configuration.
+ """
+ matching_strategy_cls = getattr(
+ importlib.import_module("continuous_eval.metrics.retrieval.matching_strategy"),
+ config["matching_strategy"]["type"],
+ )
+ matching_strategy = matching_strategy_cls(**config["matching_strategy"]["config"])
+ return cls(matching_strategy=matching_strategy, weight=config.get("weight", 1.0))
+
+ async def compute(self, results: list[DocumentSearchResult]) -> dict:
  """
  Compute the metric.

@@ -51,7 +57,18 @@ class DocumentSearchMetric(Metric[DocumentSearchResult], ABC):
  The computed metric.
  """
  return self.metric.aggregate(
- [self.metric(result.predicted_passages, result.reference_passages) for result in results]
+ [
+ self.metric(
+ [
+ element.text_representation
+ for element in result.predicted_elements
+ if element.text_representation
+ ],
+ result.reference_passages,
+ )
+ for result in results
+ if result.reference_passages is not None
+ ]
  )

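
The from_config added above reads matching_strategy.type, matching_strategy.config and an optional weight, so a document search metric can be built from a plain dictionary; a minimal sketch using the concrete subclass imported by the factories module:

    from ragbits.evaluate.metrics.document_search import DocumentSearchPrecisionRecallF1

    # keys mirror exactly what from_config reads above
    metric = DocumentSearchPrecisionRecallF1.from_config(
        {
            "matching_strategy": {"type": "RougeChunkMatch", "config": {"threshold": 0.5}},
            "weight": 1.0,
        }
    )
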