retrievalbase 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40):
  1. retrievalbase/__init__.py +0 -0
  2. retrievalbase/connector/__init__.py +69 -0
  3. retrievalbase/connector/minio.py +45 -0
  4. retrievalbase/connector/parquet.py +20 -0
  5. retrievalbase/connector/settings.py +22 -0
  6. retrievalbase/constants.py +1 -0
  7. retrievalbase/dataset/__init__.py +146 -0
  8. retrievalbase/dataset/hf.py +49 -0
  9. retrievalbase/dataset/mixins.py +108 -0
  10. retrievalbase/dataset/polars.py +43 -0
  11. retrievalbase/dataset/preprocess/__init__.py +29 -0
  12. retrievalbase/dataset/preprocess/preprocess.py +96 -0
  13. retrievalbase/dataset/preprocess/token_counter.py +41 -0
  14. retrievalbase/dataset/settings.py +63 -0
  15. retrievalbase/enums.py +11 -0
  16. retrievalbase/evaluation/__init__.py +179 -0
  17. retrievalbase/evaluation/async_batcher.py +79 -0
  18. retrievalbase/evaluation/embedders.py +28 -0
  19. retrievalbase/evaluation/evaluators/__init__.py +37 -0
  20. retrievalbase/evaluation/evaluators/python/__init__.py +149 -0
  21. retrievalbase/evaluation/evaluators/python/evaluators.py +71 -0
  22. retrievalbase/evaluation/evaluators/python/scores.py +118 -0
  23. retrievalbase/evaluation/processors.py +15 -0
  24. retrievalbase/evaluation/rerankers.py +182 -0
  25. retrievalbase/evaluation/retrievers/__init__.py +112 -0
  26. retrievalbase/evaluation/retrievers/dense/__init__.py +56 -0
  27. retrievalbase/evaluation/retrievers/dense/retrievers.py +86 -0
  28. retrievalbase/evaluation/settings.py +204 -0
  29. retrievalbase/evaluation/vector_stores.py +146 -0
  30. retrievalbase/exceptions.py +61 -0
  31. retrievalbase/ingestion/__init__.py +50 -0
  32. retrievalbase/ingestion/settings.py +10 -0
  33. retrievalbase/mixins.py +85 -0
  34. retrievalbase/py.typed +0 -0
  35. retrievalbase/settings.py +33 -0
  36. retrievalbase/types.py +55 -0
  37. retrievalbase/utils.py +107 -0
  38. retrievalbase-1.0.0.dist-info/METADATA +23 -0
  39. retrievalbase-1.0.0.dist-info/RECORD +40 -0
  40. retrievalbase-1.0.0.dist-info/WHEEL +4 -0
File without changes
@@ -0,0 +1,69 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ import polars as pl
6
+
7
+ from retrievalbase.connector.settings import DatasetConnectorSettings
8
+ from retrievalbase.mixins import FromConfigMixin
9
+ from retrievalbase.types import TCDatasetConnector as TCDatasetConnector
10
+
11
+ if TYPE_CHECKING:
12
+ from retrievalbase.dataset import Dataset, TextDataset
13
+
14
+ _logger = logging.getLogger(__name__)
15
+
16
+
17
+ class DatasetConnector[TCDatasetConnector: DatasetConnectorSettings](
18
+ FromConfigMixin[TCDatasetConnector],
19
+ ABC,
20
+ ):
21
+ def __init__(self, config: TCDatasetConnector):
22
+ super().__init__(config)
23
+
24
+ _logger.info(
25
+ f"Initializing dataset connector | class={self.__class__.__name__} | module={self.__class__.__module__}"
26
+ )
27
+
28
+ @abstractmethod
29
+ def _load(self) -> pl.DataFrame | pl.LazyFrame:
30
+ """
31
+ Load raw data as Polars DataFrame or LazyFrame.
32
+ """
33
+ raise NotImplementedError()
34
+
35
+ @abstractmethod
36
+ def to(self, ds: "Dataset[Any]") -> None:
37
+ raise NotImplementedError()
38
+
39
+ def load(self) -> "Dataset[pl.DataFrame | pl.LazyFrame]":
40
+ from retrievalbase.dataset.polars import PolarsDataset
41
+
42
+ _logger.info(f"Loading dataset | connector={self.__class__.__name__}")
43
+
44
+ df = self._load()
45
+ self._log_polars_info(df)
46
+
47
+ return PolarsDataset.from_polars(df)
48
+
49
+ def load_text(self) -> "TextDataset[pl.DataFrame | pl.LazyFrame]":
50
+ from retrievalbase.dataset.polars import PolarsTextDataset
51
+
52
+ _logger.info(f"Loading text dataset | connector={self.__class__.__name__}")
53
+
54
+ df = self._load()
55
+ self._log_polars_info(df)
56
+
57
+ return PolarsTextDataset.from_polars(df)
58
+
59
+ # ------------------------------------------------------------------
60
+ # Helpers
61
+ # ------------------------------------------------------------------
62
+
63
+ def _log_polars_info(self, df: pl.DataFrame | pl.LazyFrame) -> None:
64
+ """
65
+ Log dataset structure without forcing materialization.
66
+ """
67
+ if isinstance(df, pl.DataFrame):
68
+ schema = df.schema
69
+ _logger.info(f"Loaded DataFrame | columns={len(schema)} | schema={list(schema.items())}")
@@ -0,0 +1,45 @@
1
+ import io
2
+ from typing import TYPE_CHECKING, Any
3
+
4
+ import polars as pl
5
+ from minio import Minio
6
+
7
+ from retrievalbase.connector import DatasetConnector
8
+ from retrievalbase.connector.settings import MinioDatasetConnectorSettings
9
+
10
+ if TYPE_CHECKING:
11
+ from retrievalbase.dataset import Dataset
12
+
13
+
14
class MinioDatasetConnector(DatasetConnector[MinioDatasetConnectorSettings]):
    """Connector that reads and writes a single Parquet object in a MinIO bucket."""

    def __init__(self, config: MinioDatasetConnectorSettings):
        super().__init__(config)
        endpoint = self.config.endpoint
        # The client receives the endpoint without its scheme; whether to use
        # TLS is derived from the scheme separately.
        host = endpoint.replace("http://", "").replace("https://", "")
        use_tls = endpoint.startswith("https://")
        self.client = Minio(
            host,
            access_key=self.config.access_key.get_secret_value(),
            secret_key=self.config.secret_key.get_secret_value(),
            secure=use_tls,
        )

    def _load(self) -> pl.DataFrame | pl.LazyFrame:
        """Download the configured object fully into memory and parse it as Parquet."""
        obj = self.client.get_object(self.config.bucket, self.config.key)
        try:
            payload = obj.read()
        finally:
            # Always close the response and release its connection, even if the read fails.
            obj.close()
            obj.release_conn()
        return pl.read_parquet(io.BytesIO(payload))

    def to(self, ds: "Dataset[Any]") -> None:
        """Serialize ``ds`` to Parquet in memory and upload it to the configured bucket/key."""
        stream = io.BytesIO()
        ds.polars.write_parquet(stream)
        # Rewind so the upload starts at the first byte of the serialized data.
        stream.seek(0)
        size = stream.getbuffer().nbytes
        self.client.put_object(
            bucket_name=self.config.bucket,
            object_name=self.config.key,
            data=stream,
            length=size,
            content_type="application/octet-stream",
        )
@@ -0,0 +1,20 @@
1
+ from typing import TYPE_CHECKING, Any
2
+
3
+ import polars as pl
4
+
5
+ from retrievalbase.connector import DatasetConnector
6
+ from retrievalbase.connector.settings import ParquetDatasetConnectorSettings
7
+
8
+ if TYPE_CHECKING:
9
+ from retrievalbase.dataset import Dataset
10
+
11
+
12
class ParquetDatasetConnector(DatasetConnector[ParquetDatasetConnectorSettings]):
    """Connector that reads and writes a Parquet file at ``config.path``."""

    def __init__(self, config: ParquetDatasetConnectorSettings):
        super().__init__(config)

    def _load(self) -> pl.DataFrame | pl.LazyFrame:
        """Scan the file lazily when ``config.lazy`` is set, otherwise read it eagerly."""
        source = self.config.path
        if self.config.lazy:
            return pl.scan_parquet(source)
        return pl.read_parquet(source)

    def to(self, ds: "Dataset[Any]") -> None:
        """Write the dataset's eager Polars frame to ``config.path``."""
        frame = ds.polars
        frame.write_parquet(self.config.path)
@@ -0,0 +1,22 @@
1
+ from pydantic import SecretStr
2
+ from pydantic_settings import SettingsConfigDict
3
+
4
+ from retrievalbase.settings import FromConfigMixinSettings
5
+
6
+
7
class DatasetConnectorSettings(FromConfigMixinSettings):
    """Common base type for dataset connector settings; declares no fields of its own."""
9
+
10
+
11
class ParquetDatasetConnectorSettings(DatasetConnectorSettings):
    """Settings for a Parquet-backed dataset connector."""

    # Location of the Parquet data.
    path: str
    # Lazy/eager switch for loading.
    lazy: bool
14
+
15
+
16
class MinioDatasetConnectorSettings(DatasetConnectorSettings):
    """Settings for a MinIO-backed dataset connector.

    Environment variables are looked up with the ``MINIO_`` prefix and unknown
    inputs are ignored (see ``model_config``).
    """

    model_config = SettingsConfigDict(env_prefix="MINIO_", extra="ignore")

    # Server address as a string.
    endpoint: str
    # Bucket name.
    bucket: str
    # Object key inside the bucket.
    key: str
    # Credentials are held as SecretStr.
    access_key: SecretStr
    secret_key: SecretStr
@@ -0,0 +1 @@
1
+ CONFIG_PATH = "/config/config.yaml"  # absolute path string; presumably the default config file location -- confirm against callers
@@ -0,0 +1,146 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import Iterable, Iterator
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ Literal,
7
+ Self,
8
+ cast,
9
+ overload,
10
+ )
11
+
12
+ import polars as pl
13
+ from polars._typing import ColumnNameOrSelector, IntoExpr, IntoExprColumn
14
+
15
+ from retrievalbase.dataset.mixins import TextDatasetMixin
16
+ from retrievalbase.types import TDataset as TDataset
17
+ from retrievalbase.utils import _get_minio_connector, _get_parquet_connector
18
+
19
+ if TYPE_CHECKING:
20
+ pass
21
+
22
+
23
+ class Dataset[TDataset](ABC):
24
+ def __init__(self, service: TDataset):
25
+ super().__init__()
26
+ self.service = service
27
+ self._polars: pl.DataFrame | None = None
28
+ self._lazy_polars: pl.LazyFrame | None = None
29
+
30
+ @classmethod
31
+ @abstractmethod
32
+ def from_polars(cls, df: pl.DataFrame | pl.LazyFrame) -> Self:
33
+ raise NotImplementedError
34
+
35
+ @abstractmethod
36
+ def to_polars(self) -> pl.DataFrame:
37
+ raise NotImplementedError
38
+
39
+ @abstractmethod
40
+ def to_lazy_polars(self) -> pl.LazyFrame:
41
+ raise NotImplementedError
42
+
43
+ @classmethod
44
+ def from_parquet(
45
+ cls,
46
+ path: str,
47
+ *,
48
+ lazy: bool = True,
49
+ ) -> "Dataset[pl.DataFrame | pl.LazyFrame]":
50
+ return _get_parquet_connector(path, lazy=lazy).load()
51
+
52
+ def to_parquet(self, path: str, *, lazy: bool = True) -> None:
53
+ return _get_parquet_connector(path, lazy=lazy).to(self)
54
+
55
+ @classmethod
56
+ def from_minio(
57
+ cls,
58
+ bucket: str,
59
+ key: str,
60
+ endpoint: str,
61
+ access_key: str,
62
+ secret_key: str,
63
+ ) -> "Dataset[pl.DataFrame | pl.LazyFrame]":
64
+ return _get_minio_connector(bucket, key, endpoint, access_key, secret_key).load()
65
+
66
+ def to_minio(self, bucket: str, key: str, endpoint: str, access_key: str, secret_key: str) -> None:
67
+ return _get_minio_connector(bucket, key, endpoint, access_key, secret_key).to(self)
68
+
69
+ @property
70
+ def polars(self) -> pl.DataFrame:
71
+ if self._polars is None:
72
+ self._polars = self.to_polars()
73
+ return self._polars
74
+
75
+ @property
76
+ def lazy_polars(self) -> pl.LazyFrame:
77
+ if self._lazy_polars is None:
78
+ self._lazy_polars = self.to_lazy_polars()
79
+ return self._lazy_polars
80
+
81
+ @property
82
+ def polars_shape(self) -> tuple[int, int]:
83
+ return self.polars.shape
84
+
85
+ def with_columns(
86
+ self,
87
+ *exprs: IntoExpr | Iterable[IntoExpr],
88
+ **named_exprs: IntoExpr,
89
+ ) -> Self:
90
+ df = self.polars.with_columns(*exprs, **named_exprs)
91
+ return self.__class__.from_polars(df)
92
+
93
+ def filter(
94
+ self,
95
+ *predicates: (IntoExprColumn | Iterable[IntoExprColumn] | bool | list[bool]),
96
+ **constraints: Any,
97
+ ) -> Self:
98
+ return self.__class__.from_polars(self.polars.filter(*predicates, **constraints))
99
+
100
+ def drop(
101
+ self,
102
+ *columns: ColumnNameOrSelector | Iterable[ColumnNameOrSelector],
103
+ strict: bool = True,
104
+ ) -> Self:
105
+ return self.__class__.from_polars(self.polars.drop(*columns, strict=strict))
106
+
107
+ def __len__(self) -> int:
108
+ return self.polars_shape[0]
109
+
110
+ @overload
111
+ def iter_rows(self, *, named: Literal[False] = ..., buffer_size: int = ...) -> Iterator[tuple[Any, ...]]: ...
112
+
113
+ @overload
114
+ def iter_rows(self, *, named: Literal[True], buffer_size: int = ...) -> Iterator[dict[str, Any]]: ...
115
+
116
+ def iter_rows(
117
+ self, *, named: Literal[False, True] = False, buffer_size: int = 512
118
+ ) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
119
+ df_stream: pl.DataFrame = cast(pl.DataFrame, self.lazy_polars.collect())
120
+ return df_stream.iter_rows(named=named, buffer_size=buffer_size) # type: ignore[no-any-return]
121
+
122
+
123
+ class TextDataset[TDataset](Dataset[TDataset], TextDatasetMixin[TDataset]):
124
+ def __init__(self, service: TDataset):
125
+ super().__init__(service)
126
+ self._validate_schema()
127
+
128
+ @classmethod
129
+ def from_parquet(
130
+ cls,
131
+ path: str,
132
+ *,
133
+ lazy: bool = True,
134
+ ) -> "TextDataset[pl.DataFrame | pl.LazyFrame]":
135
+ return _get_parquet_connector(path, lazy=lazy).load_text()
136
+
137
+ @classmethod
138
+ def from_minio(
139
+ cls,
140
+ bucket: str,
141
+ key: str,
142
+ endpoint: str,
143
+ access_key: str,
144
+ secret_key: str,
145
+ ) -> "TextDataset[pl.DataFrame | pl.LazyFrame]":
146
+ return _get_minio_connector(bucket, key, endpoint, access_key, secret_key).load_text()
@@ -0,0 +1,49 @@
1
+ from abc import ABC
2
+ from typing import Any
3
+
4
+ import polars as pl
5
+ from datasets import Dataset as HFDataset
6
+
7
+ from retrievalbase.connector import DatasetConnector
8
+ from retrievalbase.dataset import TextDataset
9
+ from retrievalbase.dataset.settings import HuggingFaceDatasetAdaptaterSettings
10
+ from retrievalbase.exceptions import DatasetSchemaError
11
+ from retrievalbase.mixins import FromConfigMixin
12
+ from retrievalbase.utils import build_schema, extract_schema_columns, load_class
13
+
14
+
15
+ class HuggingFaceDatasetAdaptater[TCHFDatasetAdaptater: HuggingFaceDatasetAdaptaterSettings[Any]](
16
+ FromConfigMixin[TCHFDatasetAdaptater],
17
+ ABC,
18
+ ):
19
+ def __init__(self, config: TCHFDatasetAdaptater) -> None:
20
+ super().__init__(config)
21
+ self._dataset: TextDataset[pl.DataFrame | pl.LazyFrame] | None = None
22
+
23
+ @property
24
+ def dataset(self) -> TextDataset:
25
+ if self._dataset is None:
26
+ connector: DatasetConnector[Any] = load_class(self.config.connector.module_path).from_config(
27
+ self.config.connector
28
+ )
29
+
30
+ self._dataset = connector.load_text()
31
+
32
+ return self._dataset
33
+
34
+ def to_hf(self) -> HFDataset:
35
+ df = self.dataset.polars
36
+ # schema mapping
37
+ schema = self.config.columns_mapping
38
+ # validation (top-level)
39
+ available_columns = df.columns
40
+ flattened_columns = extract_schema_columns(df.schema)
41
+ missing_roots = {path.split(".")[0] for path in schema.values() if path.split(".")[0] not in available_columns}
42
+ if missing_roots:
43
+ raise DatasetSchemaError(
44
+ missing_columns=missing_roots,
45
+ available_columns=available_columns,
46
+ flattened_columns=flattened_columns,
47
+ )
48
+ df = build_schema(df, schema)
49
+ return HFDataset.from_dict(df.to_dict(as_series=False))
@@ -0,0 +1,108 @@
1
+ import json
2
+ from collections.abc import Iterator
3
+ from pathlib import Path
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Any,
7
+ Protocol,
8
+ runtime_checkable,
9
+ )
10
+
11
+ import polars as pl
12
+ from langchain_core.documents import Document
13
+
14
+ from retrievalbase.types import TDataset as TDataset
15
+
16
+ if TYPE_CHECKING:
17
+ from retrievalbase.dataset import TextDataset
18
+
19
+
20
@runtime_checkable
class SupportsPolarsDataset(Protocol[TDataset]):
    """Structural interface for objects that expose and accept Polars frames."""

    @property
    def polars(self) -> pl.DataFrame:
        """Eager Polars frame for this object."""
        ...

    def to_polars(self) -> pl.DataFrame:
        """Convert this object to an eager Polars frame."""
        ...

    @classmethod
    def from_polars(cls, df: pl.DataFrame | pl.LazyFrame) -> "TextDataset[TDataset]":
        """Build an instance from a Polars frame; implementers must override."""
        raise NotImplementedError
30
+
31
+
32
+ class TextDatasetMixin[TDataset](SupportsPolarsDataset[TDataset]):
33
+ REQUIRED_COLUMNS = {"page_content", "metadata"}
34
+
35
+ def _validate_schema(self) -> None:
36
+ df = self.polars
37
+ missing = self.REQUIRED_COLUMNS - set(df.columns)
38
+ if missing:
39
+ raise ValueError(
40
+ f"{self.__class__.__name__} requires columns {self.REQUIRED_COLUMNS}, but missing {missing}"
41
+ )
42
+
43
+ def iter_documents(self) -> Iterator[Document]:
44
+ df = self.to_polars()
45
+ for row in df.iter_rows(named=True):
46
+ yield Document(
47
+ page_content=row["page_content"],
48
+ metadata=row["metadata"],
49
+ )
50
+
51
+ @classmethod
52
+ def from_records(
53
+ cls,
54
+ records: list[tuple[str, dict[str, Any]]],
55
+ ) -> "TextDataset[TDataset]":
56
+ if not records:
57
+ raise ValueError("records must be non-empty")
58
+ df = pl.DataFrame(
59
+ {
60
+ "page_content": [text for text, _ in records],
61
+ "metadata": [meta for _, meta in records],
62
+ }
63
+ )
64
+ return cls.from_polars(df)
65
+
66
+ def dump_documents(
67
+ self,
68
+ out_dir: str,
69
+ prefix: str = "doc",
70
+ ext: str = ".md",
71
+ encoding: str = "utf-8",
72
+ ) -> None:
73
+ out_dir_path = Path(out_dir)
74
+ out_dir_path.mkdir(parents=True, exist_ok=True)
75
+
76
+ for i, row in enumerate(self.polars.iter_rows(named=True)):
77
+ path = out_dir_path / f"{prefix}_{i:05d}{ext}"
78
+ page_content = row.get("page_content", "")
79
+ metadata = row.get("metadata", {})
80
+ try:
81
+ metadata_str = json.dumps(metadata, indent=2, ensure_ascii=False)
82
+ except TypeError:
83
+ metadata_str = str(metadata)
84
+ content = f"###PAGE_CONTENT###\n{page_content}\n\n###METADATA###\n{metadata_str}\n"
85
+ path.write_text(content, encoding=encoding)
86
+
87
+ def to_langchain_documents(
88
+ self,
89
+ ) -> list[Document]:
90
+ docs: list[Document] = []
91
+ for row in self.polars.iter_rows(named=True):
92
+ docs.append(
93
+ Document(
94
+ page_content=row["page_content"],
95
+ metadata=row["metadata"],
96
+ )
97
+ )
98
+ return docs
99
+
100
+ @classmethod
101
+ def from_documents(cls, docs: list[Document]) -> "TextDataset[TDataset]":
102
+ df = pl.DataFrame(
103
+ {
104
+ "page_content": [d.page_content for d in docs],
105
+ "metadata": [d.metadata or {} for d in docs],
106
+ }
107
+ )
108
+ return cls.from_polars(df)
@@ -0,0 +1,43 @@
1
+ from typing import Self
2
+
3
+ import polars as pl
4
+
5
+ from retrievalbase.dataset import Dataset, TextDataset
6
+
7
+
8
class PolarsDataset(Dataset[pl.DataFrame | pl.LazyFrame]):
    """``Dataset`` whose backing service is itself a Polars DataFrame or LazyFrame."""

    def __init__(self, service: pl.DataFrame | pl.LazyFrame):
        super().__init__(service)

    @classmethod
    def from_polars(cls, df: pl.DataFrame | pl.LazyFrame) -> Self:
        """Wrap ``df`` directly."""
        return cls(df)

    def to_polars(self) -> pl.DataFrame:
        """Return the frame as-is when eager, otherwise collect it."""
        frame = self.service
        if not isinstance(frame, pl.LazyFrame):
            return frame
        return frame.collect(background=False)  # ty: ignore[invalid-return-type]

    def to_lazy_polars(self) -> pl.LazyFrame:
        """Return the frame as-is when lazy, otherwise its lazy view."""
        frame = self.service
        if not isinstance(frame, pl.LazyFrame):
            return frame.lazy()
        return frame
26
+
27
class PolarsTextDataset(TextDataset[pl.DataFrame | pl.LazyFrame]):
    """``TextDataset`` whose backing service is itself a Polars DataFrame or LazyFrame."""

    def __init__(self, service: pl.DataFrame | pl.LazyFrame):
        super().__init__(service)

    @classmethod
    def from_polars(cls, df: pl.DataFrame | pl.LazyFrame) -> Self:
        """Wrap ``df`` directly."""
        return cls(df)

    def to_polars(self) -> pl.DataFrame:
        """Return the frame as-is when eager, otherwise collect it."""
        frame = self.service
        if not isinstance(frame, pl.LazyFrame):
            return frame
        return frame.collect(background=False)  # ty: ignore[invalid-return-type]

    def to_lazy_polars(self) -> pl.LazyFrame:
        """Return the frame as-is when lazy, otherwise its lazy view."""
        frame = self.service
        if not isinstance(frame, pl.LazyFrame):
            return frame.lazy()
        return frame
@@ -0,0 +1,29 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any
3
+
4
+ from retrievalbase.dataset.polars import TextDataset
5
+ from retrievalbase.dataset.settings import TextPreprocessorSettings, TokenCounterSettings
6
+ from retrievalbase.mixins import FromConfigMixin
7
+ from retrievalbase.types import TCTextPreprocessor as TCTextPreprocessor
8
+ from retrievalbase.types import TCTokenCounter as TCTokenCounter
9
+
10
+
11
+ class TextPreprocessor[TCTextPreprocessor: TextPreprocessorSettings](FromConfigMixin[TCTextPreprocessor], ABC):
12
+ def __init__(self, config: TCTextPreprocessor) -> None:
13
+ self.config = config
14
+
15
+ @abstractmethod
16
+ def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
17
+ raise NotImplementedError
18
+
19
+
20
+ class TokenCounter[TCTokenCounter: TokenCounterSettings](FromConfigMixin[TCTokenCounter], ABC):
21
+ def __init__(self, config: TCTokenCounter) -> None:
22
+ self.config = config
23
+
24
+ @abstractmethod
25
+ def count(self, text: str) -> int:
26
+ raise NotImplementedError()
27
+
28
+ def count_batch(self, texts: list[str]) -> list[int]:
29
+ return [self.count(t) for t in texts]
@@ -0,0 +1,96 @@
1
+ from typing import Any
2
+
3
+ import polars as pl
4
+
5
+ from retrievalbase.dataset import TextDataset
6
+ from retrievalbase.dataset.preprocess import TextPreprocessor, TokenCounter
7
+ from retrievalbase.dataset.settings import (
8
+ MaxTokenFilterSettings,
9
+ MinTokenFilterSettings,
10
+ PreprocessPipelineSettings,
11
+ QuantileTokenFilterSettings,
12
+ SigmaBandTokenFilterSettings,
13
+ )
14
+ from retrievalbase.utils import load_class
15
+
16
+
17
class MinTokenFilter(TextPreprocessor[MinTokenFilterSettings[Any]]):
    """Keep only rows whose ``page_content`` has at least ``config.min_tokens`` tokens."""

    def __init__(self, config: MinTokenFilterSettings[Any]):
        super().__init__(config)
        # The counter class is resolved from its configured module path.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Return a new dataset of the same class with too-short rows removed.

        A temporary ``n_tokens`` column is added, used for filtering and
        dropped again, so the output schema matches the input.
        """
        # Pin the output dtype: without ``return_dtype`` polars infers it from
        # the returned values, which is unreliable for empty or all-null input.
        n_tokens = pl.col("page_content").map_elements(self.token_counter.count, return_dtype=pl.Int64)
        return ds.__class__.from_polars(
            ds.polars.with_columns(n_tokens.alias("n_tokens"))
            .filter(pl.col("n_tokens") >= self.config.min_tokens)
            .drop("n_tokens")
        )
30
+
31
+
32
class MaxTokenFilter(TextPreprocessor[MaxTokenFilterSettings[Any]]):
    """Keep only rows whose ``page_content`` has at most ``config.max_tokens`` tokens."""

    def __init__(self, config: MaxTokenFilterSettings[Any]):
        super().__init__(config)
        # The counter class is resolved from its configured module path.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Return a new dataset of the same class with too-long rows removed.

        A temporary ``n_tokens`` column is added, used for filtering and
        dropped again, so the output schema matches the input.
        """
        # Pin the output dtype: without ``return_dtype`` polars infers it from
        # the returned values, which is unreliable for empty or all-null input.
        n_tokens = pl.col("page_content").map_elements(self.token_counter.count, return_dtype=pl.Int64)
        return ds.__class__.from_polars(
            ds.polars.with_columns(n_tokens.alias("n_tokens"))
            .filter(pl.col("n_tokens") <= self.config.max_tokens)
            .drop("n_tokens")
        )
45
+
46
+
47
class QuantileTokenFilter(TextPreprocessor[QuantileTokenFilterSettings[Any]]):
    """Keep rows whose token count is at or below the ``config.q`` quantile of all counts."""

    def __init__(self, config: QuantileTokenFilterSettings[Any]):
        super().__init__(config)
        # The counter class is resolved from its configured module path.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Return a new dataset of the same class without rows above the quantile cutoff.

        A temporary ``n_tokens`` column is added, used for filtering and
        dropped again, so the output schema matches the input.
        """
        # Pin the output dtype: without ``return_dtype`` polars infers it from
        # the returned values, which is unreliable for empty or all-null input
        # and would leave ``quantile`` operating on a non-numeric column.
        n_tokens = pl.col("page_content").map_elements(self.token_counter.count, return_dtype=pl.Int64)
        df = ds.polars.with_columns(n_tokens.alias("n_tokens"))
        cutoff = df.select(pl.col("n_tokens").quantile(self.config.q)).item()
        return ds.__class__.from_polars(df.filter(pl.col("n_tokens") <= cutoff).drop("n_tokens"))
58
+
59
+
60
class SigmaBandTokenFilter(TextPreprocessor[SigmaBandTokenFilterSettings[Any]]):
    """Keep rows whose token count lies within ``config.z`` standard deviations of the mean."""

    def __init__(self, config: SigmaBandTokenFilterSettings[Any]):
        super().__init__(config)
        # The counter class is resolved from its configured module path.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Return a new dataset of the same class restricted to the band ``mean ± z * std``.

        A temporary ``n_tokens`` column is added, used for filtering and
        dropped again, so the output schema matches the input.
        """
        # Pin the output dtype: without ``return_dtype`` polars infers it from
        # the returned values, which is unreliable for empty or all-null input.
        n_tokens = pl.col("page_content").map_elements(self.token_counter.count, return_dtype=pl.Int64)
        df = ds.polars.with_columns(n_tokens.alias("n_tokens"))
        mu, sigma = df.select(
            pl.col("n_tokens").mean().alias("mu"),
            pl.col("n_tokens").std().alias("sigma"),
        ).row(0)

        if mu is None:
            # No non-null counts at all (e.g. an empty dataset): there is no
            # band to compute, and no row can fall inside it.
            return ds.__class__.from_polars(df.filter(pl.col("n_tokens").is_not_null()).drop("n_tokens"))

        # The sample standard deviation is null when there is a single count;
        # treat the spread as zero so the band collapses to the mean instead
        # of failing on ``z * None``.
        spread = self.config.z * (sigma if sigma is not None else 0.0)
        return ds.__class__.from_polars(
            df.filter(pl.col("n_tokens").is_between(mu - spread, mu + spread)).drop("n_tokens")
        )
78
+
79
+
80
class PreprocessPipeline(TextPreprocessor[PreprocessPipelineSettings]):
    """Preprocessor that runs the steps listed in ``config.steps`` one after another."""

    def __init__(self, config: PreprocessPipelineSettings):
        super().__init__(config)
        self.steps = self._load_steps()

    def _load_steps(self) -> list[TextPreprocessor[Any]]:
        """Instantiate every configured step from the class named by its ``module_path``."""
        loaded: list[TextPreprocessor[Any]] = [
            load_class(step_config.module_path)(step_config) for step_config in self.config.steps
        ]
        return loaded

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Feed ``ds`` through each step in order and return the final dataset."""
        result = ds
        for stage in self.steps:
            result = stage.apply(result)
        return result
@@ -0,0 +1,41 @@
1
+ from retrievalbase.dataset.preprocess import TokenCounter
2
+ from retrievalbase.dataset.settings import (
3
+ HeuristicTokenCounterSettings,
4
+ HuggingFaceTokenCounterSettings,
5
+ )
6
+
7
+
8
class HeuristicTokenCounter(TokenCounter[HeuristicTokenCounterSettings]):
    """Token counter that estimates from character length alone."""

    def __init__(self, config: HeuristicTokenCounterSettings) -> None:
        super().__init__(config)

    def count(self, text: str) -> int:
        """Return ``len(text) / config.chars_per_token`` truncated to an int, never less than 1."""
        estimate = int(len(text) / self.config.chars_per_token)
        return estimate if estimate > 1 else 1
14
+
15
+
16
class HuggingFaceTokenCounter(TokenCounter[HuggingFaceTokenCounterSettings]):
    """Token counter backed by a ``transformers`` tokenizer loaded from ``config.name``."""

    def __init__(self, config: HuggingFaceTokenCounterSettings) -> None:
        super().__init__(config)
        # transformers is optional, so it is imported only when this counter is built.
        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            message = (
                "HuggingFaceTokenCounter requires transformers.\nInstall with: pip install retrievalbase[transformers]"
            )
            raise ImportError(message) from e

        # NOTE(review): trust_remote_code=True allows the tokenizer repository to
        # execute its own Python code at load time -- confirm this is acceptable
        # for every value ``config.name`` may take.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.name,
            revision=self.config.revision,
            trust_remote_code=True,
        )

    def count(self, text: str) -> int:
        """Tokenize ``text`` without truncation and return the number of input ids."""
        encoded = self.tokenizer(  # ty: ignore[call-non-callable]
            text,
            add_special_tokens=self.config.add_special_tokens,
            truncation=False,
            return_attention_mask=False,
            return_token_type_ids=False,
        )
        return len(encoded["input_ids"])