ragbits-evaluate 0.5.0__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. ragbits/evaluate/agent_simulation/__init__.py +87 -0
  2. ragbits/evaluate/agent_simulation/context.py +118 -0
  3. ragbits/evaluate/agent_simulation/conversation.py +333 -0
  4. ragbits/evaluate/agent_simulation/deepeval_evaluator.py +92 -0
  5. ragbits/evaluate/agent_simulation/logger.py +165 -0
  6. ragbits/evaluate/agent_simulation/metrics/__init__.py +19 -0
  7. ragbits/evaluate/agent_simulation/metrics/builtin.py +221 -0
  8. ragbits/evaluate/agent_simulation/metrics/collectors.py +142 -0
  9. ragbits/evaluate/agent_simulation/models.py +37 -0
  10. ragbits/evaluate/agent_simulation/results.py +200 -0
  11. ragbits/evaluate/agent_simulation/scenarios.py +129 -0
  12. ragbits/evaluate/agent_simulation/simulation.py +243 -0
  13. ragbits/evaluate/cli.py +150 -0
  14. ragbits/evaluate/config.py +11 -0
  15. ragbits/evaluate/dataloaders/__init__.py +3 -0
  16. ragbits/evaluate/dataloaders/base.py +95 -0
  17. ragbits/evaluate/dataloaders/document_search.py +61 -0
  18. ragbits/evaluate/dataloaders/exceptions.py +25 -0
  19. ragbits/evaluate/dataloaders/gaia.py +78 -0
  20. ragbits/evaluate/dataloaders/hotpot_qa.py +95 -0
  21. ragbits/evaluate/dataloaders/human_eval.py +70 -0
  22. ragbits/evaluate/dataloaders/question_answer.py +56 -0
  23. ragbits/evaluate/dataset_generator/pipeline.py +4 -4
  24. ragbits/evaluate/dataset_generator/prompts/qa.py +2 -4
  25. ragbits/evaluate/dataset_generator/tasks/corpus_generation.py +2 -4
  26. ragbits/evaluate/dataset_generator/tasks/text_generation/base.py +3 -5
  27. ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py +3 -3
  28. ragbits/evaluate/evaluator.py +178 -50
  29. ragbits/evaluate/factories/__init__.py +42 -0
  30. ragbits/evaluate/metrics/__init__.py +2 -23
  31. ragbits/evaluate/metrics/base.py +40 -17
  32. ragbits/evaluate/metrics/document_search.py +40 -23
  33. ragbits/evaluate/metrics/gaia.py +84 -0
  34. ragbits/evaluate/metrics/hotpot_qa.py +51 -0
  35. ragbits/evaluate/metrics/human_eval.py +105 -0
  36. ragbits/evaluate/metrics/question_answer.py +222 -0
  37. ragbits/evaluate/optimizer.py +138 -86
  38. ragbits/evaluate/pipelines/__init__.py +37 -0
  39. ragbits/evaluate/pipelines/base.py +34 -10
  40. ragbits/evaluate/pipelines/document_search.py +72 -67
  41. ragbits/evaluate/pipelines/gaia.py +249 -0
  42. ragbits/evaluate/pipelines/hotpot_qa.py +342 -0
  43. ragbits/evaluate/pipelines/human_eval.py +323 -0
  44. ragbits/evaluate/pipelines/question_answer.py +96 -0
  45. ragbits/evaluate/utils.py +86 -59
  46. {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +33 -9
  47. ragbits_evaluate-1.4.0.dev202602030301.dist-info/RECORD +59 -0
  48. {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +1 -1
  49. ragbits/evaluate/callbacks/base.py +0 -22
  50. ragbits/evaluate/callbacks/neptune.py +0 -26
  51. ragbits/evaluate/loaders/__init__.py +0 -21
  52. ragbits/evaluate/loaders/base.py +0 -24
  53. ragbits/evaluate/loaders/hf.py +0 -25
  54. ragbits_evaluate-0.5.0.dist-info/RECORD +0 -33
  55. /ragbits/evaluate/{callbacks/__init__.py → py.typed} +0 -0
ragbits/evaluate/optimizer.py
@@ -1,132 +1,184 @@
 import asyncio
 import warnings
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any

 import optuna
-from omegaconf import DictConfig, ListConfig
+from optuna import Trial
+from pydantic import BaseModel

-from .callbacks.base import CallbackConfigurator
-from .evaluator import Evaluator
-from .loaders.base import DataLoader
-from .metrics.base import MetricSet
-from .pipelines.base import EvaluationPipeline
+from ragbits.core.utils.config_handling import WithConstructionConfig, import_by_path
+from ragbits.evaluate.dataloaders.base import DataLoader
+from ragbits.evaluate.evaluator import Evaluator, EvaluatorConfig
+from ragbits.evaluate.metrics.base import MetricSet
+from ragbits.evaluate.pipelines.base import EvaluationPipeline
+from ragbits.evaluate.utils import setup_optuna_neptune_callback


-class Optimizer:
+class OptimizerConfig(BaseModel):
     """
-    Class for optimization
+    Schema for the optimizer config.
     """

-    INFINITY = 1e16
+    evaluator: EvaluatorConfig
+    optimizer: dict | None = None
+    neptune_callback: bool = False

-    def __init__(self, cfg: DictConfig):
-        self.config = cfg
+
+class Optimizer(WithConstructionConfig):
+    """
+    Optimizer class.
+    """
+
+    def __init__(self, direction: str = "maximize", n_trials: int = 10, max_retries_for_trial: int = 1) -> None:
+        """
+        Initialize the pipeline optimizer.
+
+        Args:
+            direction: Direction of optimization.
+            n_trials: The number of trials for each process.
+            max_retries_for_trial: The number of retires for single process.
+        """
+        self.direction = direction
+        self.n_trials = n_trials
+        self.max_retries_for_trial = max_retries_for_trial
         # workaround for optuna not allowing different choices for different trials
         # TODO check how optuna handles parallelism. discuss if we want to have parallel studies
-        self._choices_cache: dict[str, list[Any]] = {}
+        self._choices_cache: dict[str, list] = {}
+
+    @classmethod
+    def run_from_config(cls, config: dict) -> list[tuple[dict, float, dict[str, float]]]:
+        """
+        Run the optimization process configured with a config object.
+
+        Args:
+            config: Optimizer config.
+
+        Returns:
+            List of tested configs with associated scores and metrics.
+        """
+        optimizer_config = OptimizerConfig.model_validate(config)
+        evaluator_config = EvaluatorConfig.model_validate(optimizer_config.evaluator)
+
+        dataloader: DataLoader = DataLoader.subclass_from_config(evaluator_config.evaluation.dataloader)
+        metricset: MetricSet = MetricSet.from_config(evaluator_config.evaluation.metrics)
+
+        pipeline_class = import_by_path(evaluator_config.evaluation.pipeline.type)
+        pipeline_config = dict(evaluator_config.evaluation.pipeline.config)
+        callbacks = [setup_optuna_neptune_callback()] if optimizer_config.neptune_callback else []
+
+        optimizer = cls.from_config(optimizer_config.optimizer or {})
+        return optimizer.optimize(
+            pipeline_class=pipeline_class,
+            pipeline_config=pipeline_config,
+            metricset=metricset,
+            dataloader=dataloader,
+            callbacks=callbacks,
+        )

     def optimize(
         self,
         pipeline_class: type[EvaluationPipeline],
-        config_with_params: DictConfig,
+        pipeline_config: dict,
         dataloader: DataLoader,
-        metrics: MetricSet,
-        callback_configurators: list[CallbackConfigurator] | None = None,
-    ) -> list[tuple[DictConfig, float, dict[str, float]]]:
+        metricset: MetricSet,
+        callbacks: list[Callable] | None = None,
+    ) -> list[tuple[dict, float, dict[str, float]]]:
         """
-        A method for running the optimization process for given parameters
+        Run the optimization process for given parameters.
+
         Args:
-            pipeline_class - a type of pipeline to be optimized
-            config_with_params - a configuration defining the optimization process
-            dataloader - a dataloader
-            metrics - object representing the metrics to be optimized
-            log_to_neptune - indicator whether the results should be logged to neptune
+            pipeline_class: Pipeline to be optimized.
+            pipeline_config: Configuration defining the optimization process.
+            dataloader: Data loader.
+            metricset: Metrics to be optimized.
+            callbacks: Experiment callbacks.
+
         Returns:
-            list of tuples with configs and their scores
+            List of tested configs with associated scores and metrics.
         """
-        # TODO check details on how to parametrize optuna
-        optimization_kwargs = {"n_trials": self.config.n_trials}
-        if callback_configurators:
-            optimization_kwargs["callbacks"] = [configurator.get_callback() for configurator in callback_configurators]

-        def objective(trial: optuna.Trial) -> float:
+        def objective(trial: Trial) -> float:
             return self._objective(
                 trial=trial,
                 pipeline_class=pipeline_class,
-                config_with_params=config_with_params,
+                pipeline_config=pipeline_config,
                 dataloader=dataloader,
-                metrics=metrics,
+                metricset=metricset,
             )

-        study = optuna.create_study(direction=self.config.direction)
-
-        study.optimize(objective, **optimization_kwargs)
-        configs_with_scores = [
-            (trial.user_attrs["cfg"], trial.user_attrs["score"], trial.user_attrs["all_metrics"])
-            for trial in study.get_trials()
-        ]
-
-        def sorting_key(results: tuple[DictConfig, float, dict[str, float]]) -> float:
-            if self.config.direction == "maximize":
-                return -results[1]
-            else:
-                return results[1]
-
-        return sorted(configs_with_scores, key=sorting_key)
+        study = optuna.create_study(direction=self.direction)
+        study.optimize(
+            func=objective,
+            n_trials=self.n_trials,
+            callbacks=callbacks,
+        )
+        return sorted(
+            [
+                (
+                    trial.user_attrs["config"],
+                    trial.user_attrs["score"],
+                    trial.user_attrs["metrics"],
+                )
+                for trial in study.get_trials()
+            ],
+            key=lambda x: -x[1] if self.direction == "maximize" else x[1],
+        )

     def _objective(
         self,
+        trial: Trial,
         pipeline_class: type[EvaluationPipeline],
-        trial: optuna.Trial,
-        config_with_params: DictConfig,
+        pipeline_config: dict,
         dataloader: DataLoader,
-        metrics: MetricSet,
+        metricset: MetricSet,
     ) -> float:
-        max_retries = getattr(self.config, "max_retries_for_trial", 1)
+        """
+        Run a single experiment.
+        """
+        evaluator = Evaluator()
+        event_loop = asyncio.get_event_loop()
+
+        score = 1e16 if self.direction == "maximize" else -1e16
+        metrics_values = None
         config_for_trial = None
-        for attempt_idx in range(max_retries):
+
+        for attempt in range(1, self.max_retries_for_trial + 1):
             try:
-                config_for_trial = deepcopy(config_with_params)
+                config_for_trial = deepcopy(pipeline_config)
                 self._set_values_for_optimized_params(cfg=config_for_trial, trial=trial, ancestors=[])
-                pipeline = pipeline_class(config_for_trial)
-                metrics_values = self._score(pipeline=pipeline, dataloader=dataloader, metrics=metrics)
-                score = sum(metrics_values.values())
-                break
-            except Exception as e:
-                if attempt_idx < max_retries - 1:
-                    warnings.warn(
-                        message=f"Execution of the trial failed: {e}. A retry will be initiated.", category=UserWarning
-                    )
-                else:
-                    score = self.INFINITY
-                    if self.config.direction == "maximize":
-                        score *= -1
-                    metrics_values = {}
-                    warnings.warn(
-                        message=f"Execution of the trial failed: {e}. Setting the score to {score}",
-                        category=UserWarning,
+                pipeline = pipeline_class.from_config(config_for_trial)
+
+                results = event_loop.run_until_complete(
+                    evaluator.compute(
+                        pipeline=pipeline,
+                        dataloader=dataloader,
+                        metricset=metricset,
                     )
+                )
+                score = sum(results.metrics.values())
+                metrics_values = results.metrics
+                break
+            except Exception as exc:
+                message = (
+                    f"Execution of the trial failed: {exc}. A retry will be initiated"
+                    if attempt < self.max_retries_for_trial
+                    else f"Execution of the trial failed: {exc}. Setting the score to {score}"
+                )
+                warnings.warn(message=message, category=UserWarning)
+
         trial.set_user_attr("score", score)
-        trial.set_user_attr("cfg", config_for_trial)
-        trial.set_user_attr("all_metrics", metrics_values)
-        return score
+        trial.set_user_attr("metrics", metrics_values)
+        trial.set_user_attr("config", config_for_trial)

-    @staticmethod
-    def _score(pipeline: EvaluationPipeline, dataloader: DataLoader, metrics: MetricSet) -> dict[str, float]:
-        evaluator = Evaluator()
-        event_loop = asyncio.get_event_loop()
-        results = event_loop.run_until_complete(
-            evaluator.compute(pipeline=pipeline, dataloader=dataloader, metrics=metrics)
-        )
-        return results["metrics"]
+        return score

-    def _set_values_for_optimized_params(self, cfg: DictConfig, trial: optuna.Trial, ancestors: list[str]) -> None:  # noqa: PLR0912
+    def _set_values_for_optimized_params(self, cfg: dict, trial: Trial, ancestors: list[str]) -> None:  # noqa: PLR0912
         """
-        Recursive method for sampling parameter values for optuna.Trial
+        Recursive method for sampling parameter values for optuna trial.
         """
         for key, value in cfg.items():
-            if isinstance(value, DictConfig):
+            if isinstance(value, dict):
                 if value.get("optimize"):
                     param_id = f"{'.'.join(ancestors)}.{key}"  # type: ignore
                     choices = value.get("choices")
@@ -147,12 +199,12 @@ class Optimizer:
                         raise ValueError("Either choices or range must be specified")
                     choice_idx = trial.suggest_categorical(name=param_id, choices=choices_index)  # type: ignore
                     choice = choices[choice_idx]
-                    if isinstance(choice, DictConfig):
+                    if isinstance(choice, dict):
                         self._set_values_for_optimized_params(choice, trial, ancestors + [key, str(choice_idx)])  # type: ignore
                     cfg[key] = choice
                 else:
                     self._set_values_for_optimized_params(value, trial, ancestors + [key])  # type: ignore
-            elif isinstance(value, ListConfig):
+            elif isinstance(value, list):
                 for param in value:
-                    if isinstance(param, DictConfig):
+                    if isinstance(param, dict):
                         self._set_values_for_optimized_params(param, trial, ancestors + [key])  # type: ignore
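
For orientation, a minimal sketch (Python) of how the reworked optimizer might be driven from a plain dict via Optimizer.run_from_config. The nesting follows the OptimizerConfig/EvaluatorConfig shapes visible in this hunk; the dataloader, metrics, and pipeline entries below are placeholders, not shipped defaults.

from ragbits.evaluate.optimizer import Optimizer

# Config shape assumed from the hunk above; "..." leaves are placeholders.
config = {
    "optimizer": {"direction": "maximize", "n_trials": 5, "max_retries_for_trial": 2},
    "neptune_callback": False,
    "evaluator": {
        "evaluation": {
            "dataloader": {"type": "...", "config": {}},  # consumed by DataLoader.subclass_from_config
            "metrics": {},                                # consumed by MetricSet.from_config
            "pipeline": {
                "type": "...",                            # dotted path resolved by import_by_path
                "config": {
                    # leaves marked "optimize" are sampled per optuna trial
                    "top_k": {"optimize": True, "choices": [3, 5, 10]},
                },
            },
        },
    },
}

# Returns (pipeline_config, score, metrics) tuples, best trial first for "maximize".
ranked = Optimizer.run_from_config(config)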
ragbits/evaluate/pipelines/__init__.py
@@ -0,0 +1,37 @@
+from ragbits.core.utils.config_handling import WithConstructionConfig
+from ragbits.document_search import DocumentSearch
+from ragbits.evaluate.pipelines.base import EvaluationData, EvaluationPipeline, EvaluationResult
+from ragbits.evaluate.pipelines.document_search import DocumentSearchPipeline
+from ragbits.evaluate.pipelines.gaia import GaiaPipeline
+from ragbits.evaluate.pipelines.hotpot_qa import HotpotQAPipeline
+from ragbits.evaluate.pipelines.human_eval import HumanEvalPipeline
+
+__all__ = [
+    "DocumentSearchPipeline",
+    "EvaluationData",
+    "EvaluationPipeline",
+    "EvaluationResult",
+    "GaiaPipeline",
+    "HotpotQAPipeline",
+    "HumanEvalPipeline",
+]
+
+_target_to_evaluation_pipeline: dict[type[WithConstructionConfig], type[EvaluationPipeline]] = {
+    DocumentSearch: DocumentSearchPipeline,
+}
+
+
+def get_evaluation_pipeline_for_target(evaluation_target: WithConstructionConfig) -> EvaluationPipeline:
+    """
+    A function instantiating evaluation pipeline for given WithConstructionConfig object
+    Args:
+        evaluation_target: WithConstructionConfig object to be evaluated
+    Returns:
+        instance of evaluation pipeline
+    Raises:
+        ValueError for classes with no registered evaluation pipeline
+    """
+    for supported_type, evaluation_pipeline_type in _target_to_evaluation_pipeline.items():
+        if isinstance(evaluation_target, supported_type):
+            return evaluation_pipeline_type(evaluation_target=evaluation_target)
+    raise ValueError(f"Evaluation pipeline not implemented for {evaluation_target.__class__}")
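
A short usage sketch for the new helper; the DocumentSearch instance is assumed to be built elsewhere (e.g. via DocumentSearch.from_config).

from ragbits.document_search import DocumentSearch
from ragbits.evaluate.pipelines import get_evaluation_pipeline_for_target


def pipeline_for(document_search: DocumentSearch):
    # DocumentSearch is the only registered target in this release, so this
    # returns a DocumentSearchPipeline; unregistered targets raise ValueError.
    return get_evaluation_pipeline_for_target(document_search)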
ragbits/evaluate/pipelines/base.py
@@ -1,8 +1,23 @@
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Any
+from types import ModuleType
+from typing import ClassVar, Generic, TypeVar

-from omegaconf import DictConfig
+from pydantic import BaseModel
+
+from ragbits.core.utils.config_handling import WithConstructionConfig
+from ragbits.evaluate import pipelines
+
+EvaluationDataT = TypeVar("EvaluationDataT", bound="EvaluationData")
+EvaluationResultT = TypeVar("EvaluationResultT", bound="EvaluationResult")
+EvaluationTargetT = TypeVar("EvaluationTargetT", bound=WithConstructionConfig)
+
+
+class EvaluationData(BaseModel, ABC):
+    """
+    Represents the data for a single evaluation.
+    """


 @dataclass
@@ -12,25 +27,34 @@ class EvaluationResult(ABC):
     """


-class EvaluationPipeline(ABC):
+class EvaluationPipeline(WithConstructionConfig, Generic[EvaluationTargetT, EvaluationDataT, EvaluationResultT], ABC):
     """
-    Collection evaluation pipeline.
+    Evaluation pipeline.
     """

-    def __init__(self, config: DictConfig | None = None) -> None:
+    default_module: ClassVar[ModuleType | None] = pipelines
+    configuration_key: ClassVar[str] = "pipeline"
+
+    def __init__(self, evaluation_target: EvaluationTargetT) -> None:
         """
-        Initializes the evaluation pipeline.
+        Initialize the evaluation pipeline.

         Args:
-            config: The evaluation pipeline configuration.
+            evaluation_target: Evaluation target instance.
         """
         super().__init__()
-        self.config = config or DictConfig({})
+        self.evaluation_target = evaluation_target
+
+    async def prepare(self) -> None:
+        """
+        Prepare pipeline for evaluation. Optional step.
+        """
+        pass

     @abstractmethod
-    async def __call__(self, data: dict[str, Any]) -> EvaluationResult:
+    async def __call__(self, data: Iterable[EvaluationDataT]) -> Iterable[EvaluationResultT]:
         """
-        Runs the evaluation pipeline.
+        Run the evaluation pipeline.

         Args:
ragbits/evaluate/pipelines/document_search.py
@@ -1,16 +1,25 @@
 import asyncio
-import uuid
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
-from functools import cached_property
+from uuid import uuid4

-from omegaconf import DictConfig
-from tqdm.asyncio import tqdm
+from typing_extensions import Self

+from ragbits.core.sources.hf import HuggingFaceSource
 from ragbits.document_search import DocumentSearch
-from ragbits.document_search.documents.document import DocumentMeta
-from ragbits.document_search.documents.element import TextElement
-from ragbits.document_search.documents.sources import HuggingFaceSource
-from ragbits.evaluate.pipelines.base import EvaluationPipeline, EvaluationResult
+from ragbits.document_search.documents.element import Element
+from ragbits.evaluate.pipelines.base import EvaluationData, EvaluationPipeline, EvaluationResult
+
+
+class DocumentSearchData(EvaluationData):
+    """
+    Represents the evaluation data for document search.
+    """
+
+    question: str
+    reference_document_ids: list[str | int] | None = None
+    reference_passages: list[str] | None = None
+    reference_page_numbers: list[int] | None = None


 @dataclass
@@ -20,82 +29,78 @@ class DocumentSearchResult(EvaluationResult):
     """

     question: str
-    reference_passages: list[str]
-    predicted_passages: list[str]
+    predicted_elements: Sequence[Element]
+    reference_document_ids: list[str | int] | None = None
+    reference_passages: list[str] | None = None
+    reference_page_numbers: list[int] | None = None


-class DocumentSearchPipeline(EvaluationPipeline):
+class DocumentSearchPipeline(EvaluationPipeline[DocumentSearch, DocumentSearchData, DocumentSearchResult]):
     """
     Document search evaluation pipeline.
     """

-    @cached_property
-    def document_search(self) -> "DocumentSearch":
+    def __init__(self, evaluation_target: DocumentSearch, source: dict | None = None) -> None:
         """
-        Returns the document search instance.
+        Initialize the document search evaluation pipeline.

-        Returns:
-            The document search instance.
+        Args:
+            evaluation_target: Document Search instance.
+            source: Source data config for ingest.
         """
-        return DocumentSearch.from_config(self.config)  # type: ignore
+        super().__init__(evaluation_target=evaluation_target)
+        self.source = source or {}

-    async def __call__(self, data: dict) -> DocumentSearchResult:
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
         """
-        Runs the document search evaluation pipeline.
+        Create an instance of `DocumentSearchPipeline` from a configuration dictionary.

         Args:
-            data: The evaluation data.
+            config: A dictionary containing configuration settings for the pipeline.

         Returns:
-            The evaluation result.
+            An instance of the pipeline class initialized with the provided configuration.
         """
-        elements = await self.document_search.search(data["question"])
-        predicted_passages = [element.content for element in elements if isinstance(element, TextElement)]
-        return DocumentSearchResult(
-            question=data["question"],
-            reference_passages=data["passages"],
-            predicted_passages=predicted_passages,
-        )
-
-
-class DocumentSearchWithIngestionPipeline(DocumentSearchPipeline):
-    """
-    A class for joint doument ingestion and search
-    """
-
-    def __init__(self, config: DictConfig | None = None) -> None:
-        super().__init__(config)
-        self.config.vector_store.config.index_name = str(uuid.uuid4())
-        self._ingested = False
-        self._lock = asyncio.Lock()
-
-    async def __call__(self, data: dict) -> DocumentSearchResult:
+        # At this point, we assume that if the source is set, the pipeline is run in experimental mode
+        # and create random indexes for testing
+        # TODO: optimize this for cases with duplicated document search configs between runs
+        if config.get("source"):
+            config["vector_store"]["config"]["index_name"] = str(uuid4())
+        evaluation_target: DocumentSearch = DocumentSearch.from_config(config)
+        return cls(evaluation_target=evaluation_target, source=config.get("source"))
+
+    async def prepare(self) -> None:
         """
-        Queries a vector store with given data
-        Ingests the corpus to the store if has not been done
+        Ingest corpus data for evaluation.
+        """
+        if self.source:
+            # For now we only support HF sources for pre-evaluation ingest
+            # TODO: Make it generic to any data source
+            sources = await HuggingFaceSource.list_sources(
+                path=self.source["config"]["path"],
+                split=self.source["config"]["split"],
+            )
+            await self.evaluation_target.ingest(sources)
+
+    async def __call__(self, data: Iterable[DocumentSearchData]) -> Iterable[DocumentSearchResult]:
+        """
+        Run the document search evaluation pipeline.
+
         Args:
-            data: dict - query
+            data: The evaluation data batch.
+
         Returns:
-            DocumentSearchResult - query result
+            The evaluation result batch.
         """
-        async with self._lock:
-            if not self._ingested:
-                await self._ingest_documents()
-                self._ingested = True
-        return await super().__call__(data)
-
-    async def _ingest_documents(self) -> None:
-        documents = await tqdm.gather(
-            *[
-                DocumentMeta.from_source(
-                    HuggingFaceSource(
-                        path=self.config.answer_data_source.path,
-                        split=self.config.answer_data_source.split,
-                        row=i,
-                    )
-                )
-                for i in range(self.config.answer_data_source.num_docs)
-            ],
-            desc="Download",
-        )
-        await self.document_search.ingest(documents)
+        results = await asyncio.gather(*[self.evaluation_target.search(row.question) for row in data])
+        return [
+            DocumentSearchResult(
+                question=row.question,
+                predicted_elements=elements,
+                reference_document_ids=row.reference_document_ids,
+                reference_passages=row.reference_passages,
+                reference_page_numbers=row.reference_page_numbers,
+            )
+            for row, elements in zip(data, results, strict=False)
+        ]
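
Finally, a small end-to-end sketch of the rewritten pipeline; the DocumentSearch instance and the question text are assumptions for illustration. Without a `source` config, prepare() is a no-op; with one, it ingests the configured Hugging Face split before searching.

from ragbits.document_search import DocumentSearch
from ragbits.evaluate.pipelines.document_search import DocumentSearchData, DocumentSearchPipeline


async def evaluate_search(document_search: DocumentSearch) -> None:
    pipeline = DocumentSearchPipeline(evaluation_target=document_search)
    await pipeline.prepare()  # no-op here: no `source` config, so nothing is ingested
    rows = [DocumentSearchData(question="What is ragbits?", reference_passages=["ragbits is ..."])]
    for result in await pipeline(rows):
        print(result.question, len(result.predicted_elements))


# Run with: asyncio.run(evaluate_search(document_search)), where `document_search`
# is an existing instance, e.g. DocumentSearch.from_config(...).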