ragbits-evaluate 0.5.0__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (55)
  1. ragbits/evaluate/agent_simulation/__init__.py +87 -0
  2. ragbits/evaluate/agent_simulation/context.py +118 -0
  3. ragbits/evaluate/agent_simulation/conversation.py +333 -0
  4. ragbits/evaluate/agent_simulation/deepeval_evaluator.py +92 -0
  5. ragbits/evaluate/agent_simulation/logger.py +165 -0
  6. ragbits/evaluate/agent_simulation/metrics/__init__.py +19 -0
  7. ragbits/evaluate/agent_simulation/metrics/builtin.py +221 -0
  8. ragbits/evaluate/agent_simulation/metrics/collectors.py +142 -0
  9. ragbits/evaluate/agent_simulation/models.py +37 -0
  10. ragbits/evaluate/agent_simulation/results.py +200 -0
  11. ragbits/evaluate/agent_simulation/scenarios.py +129 -0
  12. ragbits/evaluate/agent_simulation/simulation.py +243 -0
  13. ragbits/evaluate/cli.py +150 -0
  14. ragbits/evaluate/config.py +11 -0
  15. ragbits/evaluate/dataloaders/__init__.py +3 -0
  16. ragbits/evaluate/dataloaders/base.py +95 -0
  17. ragbits/evaluate/dataloaders/document_search.py +61 -0
  18. ragbits/evaluate/dataloaders/exceptions.py +25 -0
  19. ragbits/evaluate/dataloaders/gaia.py +78 -0
  20. ragbits/evaluate/dataloaders/hotpot_qa.py +95 -0
  21. ragbits/evaluate/dataloaders/human_eval.py +70 -0
  22. ragbits/evaluate/dataloaders/question_answer.py +56 -0
  23. ragbits/evaluate/dataset_generator/pipeline.py +4 -4
  24. ragbits/evaluate/dataset_generator/prompts/qa.py +2 -4
  25. ragbits/evaluate/dataset_generator/tasks/corpus_generation.py +2 -4
  26. ragbits/evaluate/dataset_generator/tasks/text_generation/base.py +3 -5
  27. ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py +3 -3
  28. ragbits/evaluate/evaluator.py +178 -50
  29. ragbits/evaluate/factories/__init__.py +42 -0
  30. ragbits/evaluate/metrics/__init__.py +2 -23
  31. ragbits/evaluate/metrics/base.py +40 -17
  32. ragbits/evaluate/metrics/document_search.py +40 -23
  33. ragbits/evaluate/metrics/gaia.py +84 -0
  34. ragbits/evaluate/metrics/hotpot_qa.py +51 -0
  35. ragbits/evaluate/metrics/human_eval.py +105 -0
  36. ragbits/evaluate/metrics/question_answer.py +222 -0
  37. ragbits/evaluate/optimizer.py +138 -86
  38. ragbits/evaluate/pipelines/__init__.py +37 -0
  39. ragbits/evaluate/pipelines/base.py +34 -10
  40. ragbits/evaluate/pipelines/document_search.py +72 -67
  41. ragbits/evaluate/pipelines/gaia.py +249 -0
  42. ragbits/evaluate/pipelines/hotpot_qa.py +342 -0
  43. ragbits/evaluate/pipelines/human_eval.py +323 -0
  44. ragbits/evaluate/pipelines/question_answer.py +96 -0
  45. ragbits/evaluate/utils.py +86 -59
  46. {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +33 -9
  47. ragbits_evaluate-1.4.0.dev202602030301.dist-info/RECORD +59 -0
  48. {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +1 -1
  49. ragbits/evaluate/callbacks/base.py +0 -22
  50. ragbits/evaluate/callbacks/neptune.py +0 -26
  51. ragbits/evaluate/loaders/__init__.py +0 -21
  52. ragbits/evaluate/loaders/base.py +0 -24
  53. ragbits/evaluate/loaders/hf.py +0 -25
  54. ragbits_evaluate-0.5.0.dist-info/RECORD +0 -33
  55. /ragbits/evaluate/{callbacks/__init__.py → py.typed} +0 -0
ragbits/evaluate/cli.py
@@ -0,0 +1,150 @@
+import asyncio
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Annotated
+
+import typer
+from pydantic import BaseModel
+
+from ragbits.cli._utils import get_instance_or_exit
+from ragbits.cli.state import print_output
+from ragbits.core.utils.config_handling import WithConstructionConfig
+from ragbits.evaluate.config import eval_config
+from ragbits.evaluate.dataloaders import DataLoader
+from ragbits.evaluate.evaluator import Evaluator
+from ragbits.evaluate.metrics.base import MetricSet
+from ragbits.evaluate.pipelines import get_evaluation_pipeline_for_target
+from ragbits.evaluate.pipelines.base import EvaluationPipeline
+
+eval_app = typer.Typer(no_args_is_help=True)
+
+
+def register(app: typer.Typer) -> None:
+    """
+    Register the CLI commands for the package.
+
+    Args:
+        app: The Typer object to register the commands with.
+    """
+    app.add_typer(eval_app, name="evaluate", help="Commands for interacting with ragbits evaluate module")
+
+
+@dataclass
+class _CLIState:
+    dataloader: DataLoader | None = None
+    pipeline: EvaluationPipeline | None = None
+    metrics: MetricSet | None = None
+
+
+class EvaluationResult(BaseModel):
+    """A container for evaluation results"""
+
+    metrics: dict
+
+
+state: _CLIState = _CLIState()
+
+
+@eval_app.callback()
+def common_args(
+    dataloader_factory_path: Annotated[
+        str | None,
+        typer.Option(
+            help="A path to evaluation data loader factory in format python.path:function_name",
+            exists=True,
+            resolve_path=True,
+        ),
+    ] = None,
+    dataloader_yaml_path: Annotated[
+        Path | None,
+        typer.Option(
+            help="A path to evaluation data loader configuration",
+            exists=True,
+            resolve_path=True,
+        ),
+    ] = None,
+    target_factory_path: Annotated[
+        str | None,
+        typer.Option(
+            help="A path to a factory of the evaluation target class in format: python.path:function_name",
+            exists=True,
+            resolve_path=True,
+        ),
+    ] = None,
+    target_yaml_path: Annotated[
+        Path | None,
+        typer.Option(
+            help="A path to a YAML configuration file of the evaluation target class",
+            exists=True,
+            resolve_path=True,
+        ),
+    ] = None,
+    metrics_factory_path: Annotated[
+        str | None,
+        typer.Option(
+            help="A path to metrics factory in format python.path:function_name",
+            exists=True,
+            resolve_path=True,
+        ),
+    ] = None,
+    metrics_yaml_path: Annotated[
+        Path | None,
+        typer.Option(
+            help="A path to metrics configuration",
+            exists=True,
+            resolve_path=True,
+        ),
+    ] = None,
+) -> None:
+    """
+    Common arguments for the evaluate commands.
+    """
+    evaluation_target = get_instance_or_exit(
+        cls=WithConstructionConfig,
+        factory_path=target_factory_path,
+        yaml_path=target_yaml_path,
+        config_override=eval_config,
+    )
+    state.pipeline = get_evaluation_pipeline_for_target(evaluation_target)
+    # TODO: validate if given dataloader is suitable for evaluation pipeline
+    state.dataloader = get_instance_or_exit(
+        cls=DataLoader,
+        factory_path=dataloader_factory_path,
+        yaml_path=dataloader_yaml_path,
+        config_override=eval_config,
+    )
+    # TODO: validate if given metric set is suitable for evaluation pipeline
+    state.metrics = get_instance_or_exit(
+        cls=MetricSet,
+        factory_path=metrics_factory_path,
+        yaml_path=metrics_yaml_path,
+        config_override=eval_config,
+    )
+
+
+@eval_app.command()
+def run() -> None:
+    """
+    Evaluate the pipeline.
+    """
+
+    async def run() -> None:
+        if state.dataloader is None:
+            raise ValueError("Evaluation dataloader not initialized")
+        if state.pipeline is None:
+            raise ValueError("Evaluation pipeline not initialized")
+        if state.metrics is None:
+            raise ValueError("Evaluation metrics not initialized")
+
+        evaluator = Evaluator()
+        metric_results = await evaluator.compute(
+            pipeline=state.pipeline,
+            dataloader=state.dataloader,
+            metricset=state.metrics,
+        )
+        evaluation_results = EvaluationResult(
+            metrics={"metrics": metric_results.metrics, "time_perf": metric_results.time_perf}
+        )
+        print_output(evaluation_results)
+
+    asyncio.run(run())
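
The new `run` command above is a thin wrapper around `Evaluator.compute`. A minimal programmatic sketch of the same flow, assuming your project already exposes the factories that the CLI's --target-factory-path, --dataloader-factory-path and --metrics-factory-path options would point to (`my_target_factory` and friends are hypothetical names):

import asyncio

from ragbits.evaluate.evaluator import Evaluator
from ragbits.evaluate.pipelines import get_evaluation_pipeline_for_target

from my_project.factories import my_dataloader_factory, my_metrics_factory, my_target_factory  # hypothetical


async def main() -> None:
    # Resolve the evaluation pipeline for the target, as the CLI callback does.
    pipeline = get_evaluation_pipeline_for_target(my_target_factory())
    results = await Evaluator().compute(
        pipeline=pipeline,
        dataloader=my_dataloader_factory(),
        metricset=my_metrics_factory(),
    )
    # Mirrors the CLI output: per-metric scores plus timing information.
    print({"metrics": results.metrics, "time_perf": results.time_perf})


asyncio.run(main())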
ragbits/evaluate/config.py
@@ -0,0 +1,11 @@
+from ragbits.core.config import CoreConfig
+from ragbits.core.utils._pyproject import get_config_instance
+
+
+class EvaluateConfig(CoreConfig):
+    """
+    Configuration for the ragbits-evaluate package, loaded from downstream projects' pyproject.toml files.
+    """
+
+
+eval_config = get_config_instance(EvaluateConfig, subproject="evaluate")
ragbits/evaluate/dataloaders/__init__.py
@@ -0,0 +1,3 @@
+from ragbits.evaluate.dataloaders.base import DataLoader
+
+__all__ = ["DataLoader"]
ragbits/evaluate/dataloaders/base.py
@@ -0,0 +1,95 @@
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from types import ModuleType
+from typing import ClassVar, Generic
+
+from datasets import load_dataset
+from pydantic import BaseModel
+from typing_extensions import Self
+
+from ragbits.core.sources.base import Source
+from ragbits.core.utils.config_handling import ObjectConstructionConfig, WithConstructionConfig
+from ragbits.evaluate import dataloaders
+from ragbits.evaluate.dataloaders.exceptions import DataLoaderIncorrectFormatDataError
+from ragbits.evaluate.pipelines.base import EvaluationDataT
+
+
+class DataLoaderConfig(BaseModel):
+    """
+    Schema for the data loader config.
+    """
+
+    source: ObjectConstructionConfig
+
+
+class DataLoader(WithConstructionConfig, Generic[EvaluationDataT], ABC):
+    """
+    Evaluation data loader.
+    """
+
+    default_module: ClassVar[ModuleType | None] = dataloaders
+    configuration_key: ClassVar[str] = "dataloader"
+
+    def __init__(self, source: Source, *, split: str = "data", required_keys: set[str] | None = None) -> None:
+        """
+        Initialize the data loader.
+
+        Args:
+            source: The source to load the evaluation data from.
+            split: The split to load the data from. Split is fixed for data loaders to "data",
+                but you can slice it using the [Hugging Face API](https://huggingface.co/docs/datasets/v1.11.0/splits.html#slicing-api).
+            required_keys: The required columns for the evaluation data.
+        """
+        self.source = source
+        self.split = split
+        self.required_keys = required_keys or set()
+
+    @classmethod
+    def from_config(cls, config: dict) -> Self:
+        """
+        Create an instance of `DataLoader` from a configuration dictionary.
+
+        Args:
+            config: A dictionary containing configuration settings for the data loader.
+
+        Returns:
+            An instance of the data loader class initialized with the provided configuration.
+        """
+        dataloader_config = DataLoaderConfig.model_validate(config)
+        config["source"] = Source.subclass_from_config(dataloader_config.source)
+        return super().from_config(config)
+
+    async def load(self) -> Iterable[EvaluationDataT]:
+        """
+        Load the data.
+
+        Returns:
+            The loaded evaluation data.
+
+        Raises:
+            DataLoaderIncorrectFormatDataError: If evaluation dataset is incorrectly formatted.
+        """
+        data_path = await self.source.fetch()
+        dataset = load_dataset(
+            path=str(data_path.parent),
+            data_files={"data": str(data_path.name)},
+            split=self.split,
+        )
+        if not self.required_keys.issubset(dataset.features):
+            raise DataLoaderIncorrectFormatDataError(
+                required_features=list(self.required_keys),
+                data_path=data_path,
+            )
+        return await self.map(dataset.to_list())
+
+    @abstractmethod
+    async def map(self, dataset: Iterable[dict]) -> Iterable[EvaluationDataT]:
+        """
+        Map the dataset to the evaluation data.
+
+        Args:
+            dataset: The dataset to map.
+
+        Returns:
+            The evaluation data.
+        """
ragbits/evaluate/dataloaders/document_search.py
@@ -0,0 +1,61 @@
+from collections.abc import Iterable
+
+from ragbits.core.sources.base import Source
+from ragbits.evaluate.dataloaders.base import DataLoader
+from ragbits.evaluate.pipelines.document_search import DocumentSearchData
+
+
+class DocumentSearchDataLoader(DataLoader[DocumentSearchData]):
+    """
+    Document search evaluation data loader.
+
+    The source used for this data loader should point to a file that can be loaded by [Hugging Face](https://huggingface.co/docs/datasets/loading#local-and-remote-files).
+    """
+
+    def __init__(
+        self,
+        source: Source,
+        *,
+        split: str = "data",
+        question_key: str = "question",
+        document_ids_key: str = "document_ids",
+        passages_key: str = "passages",
+        page_numbers_key: str = "page_numbers",
+    ) -> None:
+        """
+        Initialize the document search data loader.
+
+        Args:
+            source: The source to load the data from.
+            split: The split to load the data from. Split is fixed for data loaders to "data",
+                but you can slice it using the [Hugging Face API](https://huggingface.co/docs/datasets/v1.11.0/splits.html#slicing-api).
+            question_key: The dataset column name that contains the question.
+            document_ids_key: The dataset column name that contains the document ids. Document ids are optional.
+            passages_key: The dataset column name that contains the passages. Passages are optional.
+            page_numbers_key: The dataset column name that contains the page numbers. Page numbers are optional.
+        """
+        super().__init__(source=source, split=split, required_keys={question_key})
+        self.question_key = question_key
+        self.document_ids_key = document_ids_key
+        self.passages_key = passages_key
+        self.page_numbers_key = page_numbers_key
+
+    async def map(self, dataset: Iterable[dict]) -> Iterable[DocumentSearchData]:
+        """
+        Map the dataset to the document search data schema.
+
+        Args:
+            dataset: The dataset to map.
+
+        Returns:
+            The document search data.
+        """
+        return [
+            DocumentSearchData(
+                question=data.get(self.question_key, ""),
+                reference_document_ids=data.get(self.document_ids_key),
+                reference_passages=data.get(self.passages_key),
+                reference_page_numbers=data.get(self.page_numbers_key),
+            )
+            for data in dataset
+        ]
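
A minimal usage sketch; the LocalFileSource import is an assumption about which Source implementation you pair it with, and the file path is illustrative:

import asyncio

from ragbits.core.sources.local import LocalFileSource  # assumed Source implementation
from ragbits.evaluate.dataloaders.document_search import DocumentSearchDataLoader


async def main() -> None:
    loader = DocumentSearchDataLoader(source=LocalFileSource(path="eval/questions.jsonl"))
    for row in await loader.load():
        print(row.question, row.reference_passages)


asyncio.run(main())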
ragbits/evaluate/dataloaders/exceptions.py
@@ -0,0 +1,25 @@
+from pathlib import Path
+
+
+class DataLoaderError(Exception):
+    """
+    Class for all exceptions raised by the data loader.
+    """
+
+    def __init__(self, message: str, data_path: Path) -> None:
+        super().__init__(message)
+        self.message = message
+        self.data_path = data_path
+
+
+class DataLoaderIncorrectFormatDataError(DataLoaderError):
+    """
+    Raised when the data are incorrectly formatted.
+    """
+
+    def __init__(self, required_features: list[str], data_path: Path) -> None:
+        super().__init__(
+            message=f"Dataset {data_path} is incorrectly formatted. Required features: {required_features}",
+            data_path=data_path,
+        )
+        self.required_features = required_features
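
A minimal sketch of handling the new error type around `DataLoader.load`; `loader` stands for any constructed data loader:

from ragbits.evaluate.dataloaders import DataLoader
from ragbits.evaluate.dataloaders.exceptions import DataLoaderIncorrectFormatDataError


async def load_or_report(loader: DataLoader) -> None:
    try:
        await loader.load()
    except DataLoaderIncorrectFormatDataError as exc:
        # The exception carries both the offending file and the columns it lacks.
        print(f"{exc.data_path} is missing required columns: {exc.required_features}")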
ragbits/evaluate/dataloaders/gaia.py
@@ -0,0 +1,78 @@
+from collections.abc import Iterable
+
+from ragbits.core.sources.base import Source
+from ragbits.evaluate.dataloaders.base import DataLoader
+from ragbits.evaluate.pipelines.gaia import GaiaData
+
+
+class GaiaDataLoader(DataLoader[GaiaData]):
+    """
+    GAIA benchmark evaluation data loader.
+
+    The source should point to a local/remote JSON or JSONL file exported from the
+    Hugging Face dataset `gaia-benchmark/GAIA`. Rows are expected to contain at least:
+    - "task_id" (str)
+    - "Question" (str)
+    - "Level" (int)
+    - "Final answer" (str)
+    """
+
+    def __init__(
+        self,
+        source: Source,
+        *,
+        split: str = "data",
+        task_id_key: str = "task_id",
+        question_key: str = "Question",
+        level_key: str = "Level",
+        final_answer_key: str = "Final answer",
+        file_name_key: str = "file_name",
+        skip_file_attachments: bool = False,
+    ) -> None:
+        """
+        Initialize the GAIA data loader.
+
+        Args:
+            source: The source to load the data from.
+            split: The split to load the data from (file name generated by the source helper).
+            task_id_key: Column name for GAIA task identifier.
+            question_key: Column name for the natural language question.
+            level_key: Column name for numeric difficulty level (1, 2, 3).
+            final_answer_key: Column name for the final ground-truth answer.
+            file_name_key: Column name with optional associated file name (may be empty).
+            skip_file_attachments: If True, skip rows that have a non-empty file attachment.
+        """
+        required = {task_id_key, question_key, level_key, final_answer_key}
+        super().__init__(source=source, split=split, required_keys=required)
+        self.task_id_key = task_id_key
+        self.question_key = question_key
+        self.level_key = level_key
+        self.final_answer_key = final_answer_key
+        self.file_name_key = file_name_key
+        self.skip_file_attachments = skip_file_attachments
+
+    async def map(self, dataset: Iterable[dict]) -> Iterable[GaiaData]:
+        """
+        Map the dataset to the GAIA evaluation data schema.
+
+        Args:
+            dataset: The dataset to map.
+
+        Returns:
+            The GAIA evaluation data rows.
+        """
+        return [
+            GaiaData(
+                task_id=str(row.get(self.task_id_key, "")),
+                question=str(row.get(self.question_key, "")),
+                level=int(row.get(self.level_key, 1)),
+                reference_answer=str(row.get(self.final_answer_key, "")),
+                file_name=(row.get(self.file_name_key) or None),
+            )
+            for row in dataset
+            if (
+                not self.skip_file_attachments
+                or not row.get(self.file_name_key)
+                or str(row.get(self.file_name_key)).strip() == ""
+            )
+        ]
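
A minimal usage sketch for the GAIA loader; LocalFileSource is an assumed Source implementation and the path is illustrative:

import asyncio

from ragbits.core.sources.local import LocalFileSource  # assumed Source implementation
from ragbits.evaluate.dataloaders.gaia import GaiaDataLoader


async def main() -> None:
    loader = GaiaDataLoader(
        source=LocalFileSource(path="gaia/validation.jsonl"),
        skip_file_attachments=True,  # keep only rows without file attachments
    )
    for row in await loader.load():
        print(row.task_id, row.level, row.question[:60])


asyncio.run(main())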
ragbits/evaluate/dataloaders/hotpot_qa.py
@@ -0,0 +1,95 @@
+from collections.abc import Iterable
+from typing import Any
+
+from ragbits.core.sources.base import Source
+from ragbits.evaluate.dataloaders.base import DataLoader
+from ragbits.evaluate.pipelines.hotpot_qa import HotpotQAData
+
+
+class HotpotQADataLoader(DataLoader[HotpotQAData]):
+    """
+    HotpotQA evaluation data loader.
+
+    The source should point to a local/remote JSON file exported from Hugging Face, where each example includes at
+    least the following keys:
+    - "id" (str)
+    - "question" (str)
+    - "answer" (str)
+    - "type" ("bridge" | "comparison")
+    - "level" ("easy" | "medium" | "hard")
+    - "context" (object with keys: "title": list[str], "sentences": list[list[str]])
+    """
+
+    def __init__(
+        self,
+        source: Source,
+        *,
+        split: str = "data",
+        id_key: str = "id",
+        question_key: str = "question",
+        answer_key: str = "answer",
+        type_key: str = "type",
+        level_key: str = "level",
+        context_key: str = "context",
+        # filter
+        level_filter: str | None = None,  # one of: easy|medium|hard
+    ) -> None:
+        """
+        Initialize the HotpotQA data loader.
+
+        Args:
+            source: The source to load the data from.
+            split: The split to load the data from.
+            id_key: Column with unique id.
+            question_key: Column with question text.
+            answer_key: Column with ground truth answer.
+            type_key: Column with question type ("bridge" | "comparison").
+            level_key: Column with difficulty ("easy" | "medium" | "hard").
+            context_key: Column with context object containing titles and sentences.
+            level_filter: If provided, return only examples with this level.
+        """
+        required = {id_key, question_key, answer_key, type_key, level_key, context_key}
+        super().__init__(source=source, split=split, required_keys=required)
+        self.id_key = id_key
+        self.question_key = question_key
+        self.answer_key = answer_key
+        self.type_key = type_key
+        self.level_key = level_key
+        self.context_key = context_key
+        self.level_filter = level_filter
+
+    async def map(self, dataset: Iterable[dict]) -> Iterable[HotpotQAData]:
+        """
+        Map the dataset to the HotpotQA evaluation data schema.
+
+        Args:
+            dataset: The dataset to map.
+
+        Returns:
+            The HotpotQA evaluation data rows.
+        """
+
+        def to_context_rows(context: dict[str, Any]) -> list[str]:
+            titles = context.get("title", []) or []
+            sentences = context.get("sentences", []) or []
+            rows: list[str] = []
+            for title, sent_list in zip(titles, sentences, strict=False):
+                doc_text = "\n".join(sent_list) if isinstance(sent_list, list) else str(sent_list)
+                rows.append(f"{title}\n{doc_text}")
+            if not rows and isinstance(sentences, list):
+                flat = "\n".join([" ".join(s) if isinstance(s, list) else str(s) for s in sentences])
+                rows = [flat]
+            return rows
+
+        return [
+            HotpotQAData(
+                id=row.get(self.id_key, ""),
+                question=row.get(self.question_key, ""),
+                reference_answer=str(row.get(self.answer_key, "")),
+                qtype=str(row.get(self.type_key, "")),
+                level=(row.get(self.level_key) or "").lower(),
+                reference_context=to_context_rows(row.get(self.context_key, {}) or {}),
+            )
+            for row in dataset
+            if not self.level_filter or (row.get(self.level_key, "").lower() == self.level_filter)
+        ]
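
To make the `to_context_rows` helper concrete, here is an illustrative (not real) HotpotQA-style row and the `reference_context` it would be mapped to:

# Illustrative input row, in the column layout the loader expects.
row = {
    "id": "example-id",
    "question": "Were director A and director B of the same nationality?",
    "answer": "yes",
    "type": "comparison",
    "level": "hard",
    "context": {
        "title": ["Director A", "Director B"],
        "sentences": [["Director A is an American filmmaker."], ["Director B was an American filmmaker."]],
    },
}
# After HotpotQADataLoader.map(), reference_context would hold one string per supporting document,
# with the title on the first line and its sentences below:
#   ["Director A\nDirector A is an American filmmaker.",
#    "Director B\nDirector B was an American filmmaker."]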
ragbits/evaluate/dataloaders/human_eval.py
@@ -0,0 +1,70 @@
+from collections.abc import Iterable
+
+from ragbits.core.sources.base import Source
+from ragbits.evaluate.dataloaders.base import DataLoader
+from ragbits.evaluate.pipelines.human_eval import HumanEvalData
+
+
+class HumanEvalDataLoader(DataLoader[HumanEvalData]):
+    """
+    HumanEval evaluation data loader.
+
+    The source should point to a local/remote JSONL file in HumanEval format, where each line is a JSON object
+    with at least the following keys:
+    - "task_id" (str)
+    - "prompt" (str)
+    - "entry_point" (str)
+    - "test" (str)
+    """
+
+    def __init__(
+        self,
+        source: Source,
+        *,
+        split: str = "data",
+        task_id_key: str = "task_id",
+        prompt_key: str = "prompt",
+        entry_point_key: str = "entry_point",
+        test_key: str = "test",
+        canonical_solution_key: str | None = "canonical_solution",
+    ) -> None:
+        """
+        Initialize the HumanEval data loader.
+
+        Args:
+            source: The source to load the data from.
+            split: The split to load the data from.
+            task_id_key: Dataset column with the HumanEval task identifier.
+            prompt_key: Dataset column with the Python prompt (function signature and docstring).
+            entry_point_key: Dataset column with the function name to evaluate.
+            test_key: Dataset column with the Python test harness defining `check(candidate)`.
+            canonical_solution_key: Optional dataset column with the reference solution (not used for scoring).
+        """
+        required = {task_id_key, prompt_key, entry_point_key, test_key}
+        super().__init__(source=source, split=split, required_keys=required)
+        self.task_id_key = task_id_key
+        self.prompt_key = prompt_key
+        self.entry_point_key = entry_point_key
+        self.test_key = test_key
+        self.canonical_solution_key = canonical_solution_key
+
+    async def map(self, dataset: Iterable[dict]) -> Iterable[HumanEvalData]:
+        """
+        Map the dataset to the HumanEval evaluation data schema.
+
+        Args:
+            dataset: The dataset to map.
+
+        Returns:
+            The HumanEval evaluation data rows.
+        """
+        return [
+            HumanEvalData(
+                task_id=row.get(self.task_id_key, ""),
+                prompt=row.get(self.prompt_key, ""),
+                entry_point=row.get(self.entry_point_key, ""),
+                test=row.get(self.test_key, ""),
+                canonical_solution=(row.get(self.canonical_solution_key) if self.canonical_solution_key else None),
+            )
+            for row in dataset
+        ]
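
A minimal usage sketch; LocalFileSource is an assumed Source implementation and the path points at a HumanEval-format JSONL file:

import asyncio

from ragbits.core.sources.local import LocalFileSource  # assumed Source implementation
from ragbits.evaluate.dataloaders.human_eval import HumanEvalDataLoader


async def main() -> None:
    loader = HumanEvalDataLoader(source=LocalFileSource(path="human_eval/HumanEval.jsonl"))
    for row in await loader.load():
        print(row.task_id, row.entry_point)


asyncio.run(main())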
ragbits/evaluate/dataloaders/question_answer.py
@@ -0,0 +1,56 @@
+from collections.abc import Iterable
+
+from ragbits.core.sources.base import Source
+from ragbits.evaluate.dataloaders.base import DataLoader
+from ragbits.evaluate.pipelines.question_answer import QuestionAnswerData
+
+
+class QuestionAnswerDataLoader(DataLoader[QuestionAnswerData]):
+    """
+    Question answer evaluation data loader.
+
+    The source used for this data loader should point to a file that can be loaded by [Hugging Face](https://huggingface.co/docs/datasets/loading#local-and-remote-files).
+    """
+
+    def __init__(
+        self,
+        source: Source,
+        *,
+        split: str = "data",
+        question_key: str = "question",
+        answer_key: str = "answer",
+        context_key: str = "context",
+    ) -> None:
+        """
+        Initialize the question answer data loader.
+
+        Args:
+            source: The source to load the data from.
+            split: The split to load the data from.
+            question_key: The dataset column name that contains the question.
+            answer_key: The dataset column name that contains the answer.
+            context_key: The dataset column name that contains the context. Context is optional.
+        """
+        super().__init__(source=source, split=split, required_keys={question_key, answer_key})
+        self.question_key = question_key
+        self.answer_key = answer_key
+        self.context_key = context_key
+
+    async def map(self, dataset: Iterable[dict]) -> Iterable[QuestionAnswerData]:
+        """
+        Map the dataset to the question answer data schema.
+
+        Args:
+            dataset: The dataset to map.
+
+        Returns:
+            The question answer data.
+        """
+        return [
+            QuestionAnswerData(
+                question=data.get(self.question_key, ""),
+                reference_answer=data.get(self.answer_key, ""),
+                reference_context=data.get(self.context_key),
+            )
+            for data in dataset
+        ]
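
A minimal usage sketch showing the remappable column names, assuming a dataset whose columns are "query" and "gold_answer"; LocalFileSource is an assumed Source implementation:

import asyncio

from ragbits.core.sources.local import LocalFileSource  # assumed Source implementation
from ragbits.evaluate.dataloaders.question_answer import QuestionAnswerDataLoader


async def main() -> None:
    loader = QuestionAnswerDataLoader(
        source=LocalFileSource(path="qa/eval.jsonl"),
        question_key="query",
        answer_key="gold_answer",
    )
    for row in await loader.load():
        print(row.question, "->", row.reference_answer)


asyncio.run(main())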
ragbits/evaluate/dataset_generator/pipeline.py
@@ -7,7 +7,7 @@ from distilabel.steps.base import Step
 from omegaconf import DictConfig, OmegaConf
 from pydantic import BaseModel

-from ragbits.core.utils.config_handling import get_cls_from_config
+from ragbits.core.utils.config_handling import import_by_path

 module = sys.modules[__name__]

@@ -120,14 +120,14 @@ class DatasetGenerationPipeline:
         tasks = []
         for task_config in self.config.tasks:
             llm_config = task_config.llm
-            llm = get_cls_from_config(llm_config.provider_type, module)(**llm_config.kwargs)
+            llm = import_by_path(llm_config.provider_type, module)(**llm_config.kwargs)
             task_kwargs: dict[Any, Any] = {"llm": llm}
             task_kwargs.update(task_config.kwargs or {})  # type: ignore
-            task = get_cls_from_config(task_config.type, module)(**task_kwargs)
+            task = import_by_path(task_config.type, module)(**task_kwargs)
             tasks.append(task)
             filter_types = getattr(task_config, "filters", None) or []
             for filter_type in filter_types:
-                filter = get_cls_from_config(filter_type, module)(tasks[-1])
+                filter = import_by_path(filter_type, module)(tasks[-1])
                 tasks.append(filter)
         return tasks
