retrievalbase 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- retrievalbase/__init__.py +0 -0
- retrievalbase/connector/__init__.py +69 -0
- retrievalbase/connector/minio.py +45 -0
- retrievalbase/connector/parquet.py +20 -0
- retrievalbase/connector/settings.py +22 -0
- retrievalbase/constants.py +1 -0
- retrievalbase/dataset/__init__.py +146 -0
- retrievalbase/dataset/hf.py +49 -0
- retrievalbase/dataset/mixins.py +108 -0
- retrievalbase/dataset/polars.py +43 -0
- retrievalbase/dataset/preprocess/__init__.py +29 -0
- retrievalbase/dataset/preprocess/preprocess.py +96 -0
- retrievalbase/dataset/preprocess/token_counter.py +41 -0
- retrievalbase/dataset/settings.py +63 -0
- retrievalbase/enums.py +11 -0
- retrievalbase/evaluation/__init__.py +179 -0
- retrievalbase/evaluation/async_batcher.py +79 -0
- retrievalbase/evaluation/embedders.py +28 -0
- retrievalbase/evaluation/evaluators/__init__.py +37 -0
- retrievalbase/evaluation/evaluators/python/__init__.py +149 -0
- retrievalbase/evaluation/evaluators/python/evaluators.py +71 -0
- retrievalbase/evaluation/evaluators/python/scores.py +118 -0
- retrievalbase/evaluation/processors.py +15 -0
- retrievalbase/evaluation/rerankers.py +182 -0
- retrievalbase/evaluation/retrievers/__init__.py +112 -0
- retrievalbase/evaluation/retrievers/dense/__init__.py +56 -0
- retrievalbase/evaluation/retrievers/dense/retrievers.py +86 -0
- retrievalbase/evaluation/settings.py +204 -0
- retrievalbase/evaluation/vector_stores.py +146 -0
- retrievalbase/exceptions.py +61 -0
- retrievalbase/ingestion/__init__.py +50 -0
- retrievalbase/ingestion/settings.py +10 -0
- retrievalbase/mixins.py +85 -0
- retrievalbase/py.typed +0 -0
- retrievalbase/settings.py +33 -0
- retrievalbase/types.py +55 -0
- retrievalbase/utils.py +107 -0
- retrievalbase-1.0.0.dist-info/METADATA +23 -0
- retrievalbase-1.0.0.dist-info/RECORD +40 -0
- retrievalbase-1.0.0.dist-info/WHEEL +4 -0
|
File without changes
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
|
+
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
from retrievalbase.connector.settings import DatasetConnectorSettings
|
|
8
|
+
from retrievalbase.mixins import FromConfigMixin
|
|
9
|
+
from retrievalbase.types import TCDatasetConnector as TCDatasetConnector
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from retrievalbase.dataset import Dataset, TextDataset
|
|
13
|
+
|
|
14
|
+
# Module-level logger used by DatasetConnector (and its subclasses) for load/initialization messages.
_logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class DatasetConnector[TCDatasetConnector: DatasetConnectorSettings](
|
|
18
|
+
FromConfigMixin[TCDatasetConnector],
|
|
19
|
+
ABC,
|
|
20
|
+
):
|
|
21
|
+
def __init__(self, config: TCDatasetConnector):
|
|
22
|
+
super().__init__(config)
|
|
23
|
+
|
|
24
|
+
_logger.info(
|
|
25
|
+
f"Initializing dataset connector | class={self.__class__.__name__} | module={self.__class__.__module__}"
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
@abstractmethod
|
|
29
|
+
def _load(self) -> pl.DataFrame | pl.LazyFrame:
|
|
30
|
+
"""
|
|
31
|
+
Load raw data as Polars DataFrame or LazyFrame.
|
|
32
|
+
"""
|
|
33
|
+
raise NotImplementedError()
|
|
34
|
+
|
|
35
|
+
@abstractmethod
|
|
36
|
+
def to(self, ds: "Dataset[Any]") -> None:
|
|
37
|
+
raise NotImplementedError()
|
|
38
|
+
|
|
39
|
+
def load(self) -> "Dataset[pl.DataFrame | pl.LazyFrame]":
|
|
40
|
+
from retrievalbase.dataset.polars import PolarsDataset
|
|
41
|
+
|
|
42
|
+
_logger.info(f"Loading dataset | connector={self.__class__.__name__}")
|
|
43
|
+
|
|
44
|
+
df = self._load()
|
|
45
|
+
self._log_polars_info(df)
|
|
46
|
+
|
|
47
|
+
return PolarsDataset.from_polars(df)
|
|
48
|
+
|
|
49
|
+
def load_text(self) -> "TextDataset[pl.DataFrame | pl.LazyFrame]":
|
|
50
|
+
from retrievalbase.dataset.polars import PolarsTextDataset
|
|
51
|
+
|
|
52
|
+
_logger.info(f"Loading text dataset | connector={self.__class__.__name__}")
|
|
53
|
+
|
|
54
|
+
df = self._load()
|
|
55
|
+
self._log_polars_info(df)
|
|
56
|
+
|
|
57
|
+
return PolarsTextDataset.from_polars(df)
|
|
58
|
+
|
|
59
|
+
# ------------------------------------------------------------------
|
|
60
|
+
# Helpers
|
|
61
|
+
# ------------------------------------------------------------------
|
|
62
|
+
|
|
63
|
+
def _log_polars_info(self, df: pl.DataFrame | pl.LazyFrame) -> None:
|
|
64
|
+
"""
|
|
65
|
+
Log dataset structure without forcing materialization.
|
|
66
|
+
"""
|
|
67
|
+
if isinstance(df, pl.DataFrame):
|
|
68
|
+
schema = df.schema
|
|
69
|
+
_logger.info(f"Loaded DataFrame | columns={len(schema)} | schema={list(schema.items())}")
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import io
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
|
|
4
|
+
import polars as pl
|
|
5
|
+
from minio import Minio
|
|
6
|
+
|
|
7
|
+
from retrievalbase.connector import DatasetConnector
|
|
8
|
+
from retrievalbase.connector.settings import MinioDatasetConnectorSettings
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from retrievalbase.dataset import Dataset
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MinioDatasetConnector(DatasetConnector[MinioDatasetConnectorSettings]):
    """Connector that stores a dataset as a single Parquet object ``config.key`` in bucket ``config.bucket``."""

    def __init__(self, config: MinioDatasetConnectorSettings):
        super().__init__(config)
        host, secure = self._split_endpoint(self.config.endpoint)
        self.client = Minio(
            host,
            access_key=self.config.access_key.get_secret_value(),
            secret_key=self.config.secret_key.get_secret_value(),
            secure=secure,
        )

    @staticmethod
    def _split_endpoint(endpoint: str) -> tuple[str, bool]:
        """Return ``(host, secure)`` for a configured endpoint.

        A leading ``https://`` or ``http://`` scheme is removed, matched case-insensitively.
        Only a *leading* scheme is stripped, unlike a substring replace. ``https`` enables
        TLS. An endpoint without a scheme is used as-is, with TLS disabled.
        """
        for scheme, secure in (("https://", True), ("http://", False)):
            if endpoint[: len(scheme)].lower() == scheme:
                return endpoint[len(scheme) :], secure
        return endpoint, False

    def _load(self) -> pl.DataFrame | pl.LazyFrame:
        """Download the whole object into memory and parse it eagerly as Parquet."""
        response = self.client.get_object(self.config.bucket, self.config.key)
        try:
            buffer = io.BytesIO(response.read())
        finally:
            # Close the response and release its connection even if reading fails.
            response.close()
            response.release_conn()
        return pl.read_parquet(buffer)

    def to(self, ds: "Dataset[Any]") -> None:
        """Serialize ``ds`` (eager frame) to Parquet in memory and upload it to ``bucket``/``key``."""
        buffer = io.BytesIO()
        ds.polars.write_parquet(buffer)
        buffer.seek(0)
        self.client.put_object(
            bucket_name=self.config.bucket,
            object_name=self.config.key,
            data=buffer,
            length=buffer.getbuffer().nbytes,
            content_type="application/octet-stream",
        )
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
import polars as pl
|
|
4
|
+
|
|
5
|
+
from retrievalbase.connector import DatasetConnector
|
|
6
|
+
from retrievalbase.connector.settings import ParquetDatasetConnectorSettings
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from retrievalbase.dataset import Dataset
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ParquetDatasetConnector(DatasetConnector[ParquetDatasetConnectorSettings]):
    """Connector that reads and writes a dataset as Parquet at ``config.path``."""

    def __init__(self, config: ParquetDatasetConnectorSettings):
        super().__init__(config)

    def _load(self) -> pl.DataFrame | pl.LazyFrame:
        """Scan lazily when ``config.lazy`` is true, otherwise read the file eagerly."""
        if not self.config.lazy:
            return pl.read_parquet(self.config.path)
        return pl.scan_parquet(self.config.path)

    def to(self, ds: "Dataset[Any]") -> None:
        """Materialize ``ds`` and write it to ``config.path``."""
        frame = ds.polars
        frame.write_parquet(self.config.path)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from pydantic import SecretStr
|
|
2
|
+
from pydantic_settings import SettingsConfigDict
|
|
3
|
+
|
|
4
|
+
from retrievalbase.settings import FromConfigMixinSettings
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class DatasetConnectorSettings(FromConfigMixinSettings):
    """Common base settings for every dataset connector; declares no fields of its own."""

    pass
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ParquetDatasetConnectorSettings(DatasetConnectorSettings):
    """Settings for the Parquet dataset connector."""

    # Location read by pl.scan_parquet / pl.read_parquet and written by write_parquet.
    path: str
    # True -> load with pl.scan_parquet (LazyFrame); False -> pl.read_parquet (DataFrame).
    lazy: bool
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MinioDatasetConnectorSettings(DatasetConnectorSettings):
    """Settings for the MinIO dataset connector."""

    # Server address; the connector strips an http(s):// scheme and enables TLS for https.
    endpoint: str
    bucket: str
    # Object name of the Parquet file inside ``bucket``.
    key: str
    access_key: SecretStr
    secret_key: SecretStr

    # Presumably (if FromConfigMixinSettings is a pydantic-settings BaseSettings) fields can
    # also come from MINIO_-prefixed environment variables; unknown extra inputs are ignored.
    model_config = SettingsConfigDict(env_prefix="MINIO_", extra="ignore")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Absolute path of the YAML configuration file.
# NOTE(review): consumers are not visible in this module — presumably the default config location.
CONFIG_PATH = "/config/config.yaml"
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from collections.abc import Iterable, Iterator
|
|
3
|
+
from typing import (
|
|
4
|
+
TYPE_CHECKING,
|
|
5
|
+
Any,
|
|
6
|
+
Literal,
|
|
7
|
+
Self,
|
|
8
|
+
cast,
|
|
9
|
+
overload,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
import polars as pl
|
|
13
|
+
from polars._typing import ColumnNameOrSelector, IntoExpr, IntoExprColumn
|
|
14
|
+
|
|
15
|
+
from retrievalbase.dataset.mixins import TextDatasetMixin
|
|
16
|
+
from retrievalbase.types import TDataset as TDataset
|
|
17
|
+
from retrievalbase.utils import _get_minio_connector, _get_parquet_connector
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
pass
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Dataset[TDataset](ABC):
|
|
24
|
+
def __init__(self, service: TDataset):
|
|
25
|
+
super().__init__()
|
|
26
|
+
self.service = service
|
|
27
|
+
self._polars: pl.DataFrame | None = None
|
|
28
|
+
self._lazy_polars: pl.LazyFrame | None = None
|
|
29
|
+
|
|
30
|
+
@classmethod
|
|
31
|
+
@abstractmethod
|
|
32
|
+
def from_polars(cls, df: pl.DataFrame | pl.LazyFrame) -> Self:
|
|
33
|
+
raise NotImplementedError
|
|
34
|
+
|
|
35
|
+
@abstractmethod
|
|
36
|
+
def to_polars(self) -> pl.DataFrame:
|
|
37
|
+
raise NotImplementedError
|
|
38
|
+
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def to_lazy_polars(self) -> pl.LazyFrame:
|
|
41
|
+
raise NotImplementedError
|
|
42
|
+
|
|
43
|
+
@classmethod
|
|
44
|
+
def from_parquet(
|
|
45
|
+
cls,
|
|
46
|
+
path: str,
|
|
47
|
+
*,
|
|
48
|
+
lazy: bool = True,
|
|
49
|
+
) -> "Dataset[pl.DataFrame | pl.LazyFrame]":
|
|
50
|
+
return _get_parquet_connector(path, lazy=lazy).load()
|
|
51
|
+
|
|
52
|
+
def to_parquet(self, path: str, *, lazy: bool = True) -> None:
|
|
53
|
+
return _get_parquet_connector(path, lazy=lazy).to(self)
|
|
54
|
+
|
|
55
|
+
@classmethod
|
|
56
|
+
def from_minio(
|
|
57
|
+
cls,
|
|
58
|
+
bucket: str,
|
|
59
|
+
key: str,
|
|
60
|
+
endpoint: str,
|
|
61
|
+
access_key: str,
|
|
62
|
+
secret_key: str,
|
|
63
|
+
) -> "Dataset[pl.DataFrame | pl.LazyFrame]":
|
|
64
|
+
return _get_minio_connector(bucket, key, endpoint, access_key, secret_key).load()
|
|
65
|
+
|
|
66
|
+
def to_minio(self, bucket: str, key: str, endpoint: str, access_key: str, secret_key: str) -> None:
|
|
67
|
+
return _get_minio_connector(bucket, key, endpoint, access_key, secret_key).to(self)
|
|
68
|
+
|
|
69
|
+
@property
|
|
70
|
+
def polars(self) -> pl.DataFrame:
|
|
71
|
+
if self._polars is None:
|
|
72
|
+
self._polars = self.to_polars()
|
|
73
|
+
return self._polars
|
|
74
|
+
|
|
75
|
+
@property
|
|
76
|
+
def lazy_polars(self) -> pl.LazyFrame:
|
|
77
|
+
if self._lazy_polars is None:
|
|
78
|
+
self._lazy_polars = self.to_lazy_polars()
|
|
79
|
+
return self._lazy_polars
|
|
80
|
+
|
|
81
|
+
@property
|
|
82
|
+
def polars_shape(self) -> tuple[int, int]:
|
|
83
|
+
return self.polars.shape
|
|
84
|
+
|
|
85
|
+
def with_columns(
|
|
86
|
+
self,
|
|
87
|
+
*exprs: IntoExpr | Iterable[IntoExpr],
|
|
88
|
+
**named_exprs: IntoExpr,
|
|
89
|
+
) -> Self:
|
|
90
|
+
df = self.polars.with_columns(*exprs, **named_exprs)
|
|
91
|
+
return self.__class__.from_polars(df)
|
|
92
|
+
|
|
93
|
+
def filter(
|
|
94
|
+
self,
|
|
95
|
+
*predicates: (IntoExprColumn | Iterable[IntoExprColumn] | bool | list[bool]),
|
|
96
|
+
**constraints: Any,
|
|
97
|
+
) -> Self:
|
|
98
|
+
return self.__class__.from_polars(self.polars.filter(*predicates, **constraints))
|
|
99
|
+
|
|
100
|
+
def drop(
|
|
101
|
+
self,
|
|
102
|
+
*columns: ColumnNameOrSelector | Iterable[ColumnNameOrSelector],
|
|
103
|
+
strict: bool = True,
|
|
104
|
+
) -> Self:
|
|
105
|
+
return self.__class__.from_polars(self.polars.drop(*columns, strict=strict))
|
|
106
|
+
|
|
107
|
+
def __len__(self) -> int:
|
|
108
|
+
return self.polars_shape[0]
|
|
109
|
+
|
|
110
|
+
@overload
|
|
111
|
+
def iter_rows(self, *, named: Literal[False] = ..., buffer_size: int = ...) -> Iterator[tuple[Any, ...]]: ...
|
|
112
|
+
|
|
113
|
+
@overload
|
|
114
|
+
def iter_rows(self, *, named: Literal[True], buffer_size: int = ...) -> Iterator[dict[str, Any]]: ...
|
|
115
|
+
|
|
116
|
+
def iter_rows(
|
|
117
|
+
self, *, named: Literal[False, True] = False, buffer_size: int = 512
|
|
118
|
+
) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
|
|
119
|
+
df_stream: pl.DataFrame = cast(pl.DataFrame, self.lazy_polars.collect())
|
|
120
|
+
return df_stream.iter_rows(named=named, buffer_size=buffer_size) # type: ignore[no-any-return]
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class TextDataset[TDataset](Dataset[TDataset], TextDatasetMixin[TDataset]):
|
|
124
|
+
def __init__(self, service: TDataset):
|
|
125
|
+
super().__init__(service)
|
|
126
|
+
self._validate_schema()
|
|
127
|
+
|
|
128
|
+
@classmethod
|
|
129
|
+
def from_parquet(
|
|
130
|
+
cls,
|
|
131
|
+
path: str,
|
|
132
|
+
*,
|
|
133
|
+
lazy: bool = True,
|
|
134
|
+
) -> "TextDataset[pl.DataFrame | pl.LazyFrame]":
|
|
135
|
+
return _get_parquet_connector(path, lazy=lazy).load_text()
|
|
136
|
+
|
|
137
|
+
@classmethod
|
|
138
|
+
def from_minio(
|
|
139
|
+
cls,
|
|
140
|
+
bucket: str,
|
|
141
|
+
key: str,
|
|
142
|
+
endpoint: str,
|
|
143
|
+
access_key: str,
|
|
144
|
+
secret_key: str,
|
|
145
|
+
) -> "TextDataset[pl.DataFrame | pl.LazyFrame]":
|
|
146
|
+
return _get_minio_connector(bucket, key, endpoint, access_key, secret_key).load_text()
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import polars as pl
|
|
5
|
+
from datasets import Dataset as HFDataset
|
|
6
|
+
|
|
7
|
+
from retrievalbase.connector import DatasetConnector
|
|
8
|
+
from retrievalbase.dataset import TextDataset
|
|
9
|
+
from retrievalbase.dataset.settings import HuggingFaceDatasetAdaptaterSettings
|
|
10
|
+
from retrievalbase.exceptions import DatasetSchemaError
|
|
11
|
+
from retrievalbase.mixins import FromConfigMixin
|
|
12
|
+
from retrievalbase.utils import build_schema, extract_schema_columns, load_class
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class HuggingFaceDatasetAdaptater[TCHFDatasetAdaptater: HuggingFaceDatasetAdaptaterSettings[Any]](
|
|
16
|
+
FromConfigMixin[TCHFDatasetAdaptater],
|
|
17
|
+
ABC,
|
|
18
|
+
):
|
|
19
|
+
def __init__(self, config: TCHFDatasetAdaptater) -> None:
|
|
20
|
+
super().__init__(config)
|
|
21
|
+
self._dataset: TextDataset[pl.DataFrame | pl.LazyFrame] | None = None
|
|
22
|
+
|
|
23
|
+
@property
|
|
24
|
+
def dataset(self) -> TextDataset:
|
|
25
|
+
if self._dataset is None:
|
|
26
|
+
connector: DatasetConnector[Any] = load_class(self.config.connector.module_path).from_config(
|
|
27
|
+
self.config.connector
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
self._dataset = connector.load_text()
|
|
31
|
+
|
|
32
|
+
return self._dataset
|
|
33
|
+
|
|
34
|
+
def to_hf(self) -> HFDataset:
|
|
35
|
+
df = self.dataset.polars
|
|
36
|
+
# schema mapping
|
|
37
|
+
schema = self.config.columns_mapping
|
|
38
|
+
# validation (top-level)
|
|
39
|
+
available_columns = df.columns
|
|
40
|
+
flattened_columns = extract_schema_columns(df.schema)
|
|
41
|
+
missing_roots = {path.split(".")[0] for path in schema.values() if path.split(".")[0] not in available_columns}
|
|
42
|
+
if missing_roots:
|
|
43
|
+
raise DatasetSchemaError(
|
|
44
|
+
missing_columns=missing_roots,
|
|
45
|
+
available_columns=available_columns,
|
|
46
|
+
flattened_columns=flattened_columns,
|
|
47
|
+
)
|
|
48
|
+
df = build_schema(df, schema)
|
|
49
|
+
return HFDataset.from_dict(df.to_dict(as_series=False))
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections.abc import Iterator
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import (
|
|
5
|
+
TYPE_CHECKING,
|
|
6
|
+
Any,
|
|
7
|
+
Protocol,
|
|
8
|
+
runtime_checkable,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
import polars as pl
|
|
12
|
+
from langchain_core.documents import Document
|
|
13
|
+
|
|
14
|
+
from retrievalbase.types import TDataset as TDataset
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from retrievalbase.dataset import TextDataset
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@runtime_checkable
class SupportsPolarsDataset(Protocol[TDataset]):
    """Structural interface that ``TextDatasetMixin`` expects from the class it is mixed into.

    The host must expose an eager frame (``polars``), a conversion method (``to_polars``)
    and an alternate constructor (``from_polars``).
    """

    @property
    def polars(self) -> pl.DataFrame: ...

    def to_polars(self) -> pl.DataFrame: ...

    @classmethod
    def from_polars(cls, df: pl.DataFrame | pl.LazyFrame) -> "TextDataset[TDataset]":
        # Concrete body: raises if a host class does not provide its own implementation.
        raise NotImplementedError
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TextDatasetMixin[TDataset](SupportsPolarsDataset[TDataset]):
|
|
33
|
+
REQUIRED_COLUMNS = {"page_content", "metadata"}
|
|
34
|
+
|
|
35
|
+
def _validate_schema(self) -> None:
|
|
36
|
+
df = self.polars
|
|
37
|
+
missing = self.REQUIRED_COLUMNS - set(df.columns)
|
|
38
|
+
if missing:
|
|
39
|
+
raise ValueError(
|
|
40
|
+
f"{self.__class__.__name__} requires columns {self.REQUIRED_COLUMNS}, but missing {missing}"
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
def iter_documents(self) -> Iterator[Document]:
|
|
44
|
+
df = self.to_polars()
|
|
45
|
+
for row in df.iter_rows(named=True):
|
|
46
|
+
yield Document(
|
|
47
|
+
page_content=row["page_content"],
|
|
48
|
+
metadata=row["metadata"],
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
@classmethod
|
|
52
|
+
def from_records(
|
|
53
|
+
cls,
|
|
54
|
+
records: list[tuple[str, dict[str, Any]]],
|
|
55
|
+
) -> "TextDataset[TDataset]":
|
|
56
|
+
if not records:
|
|
57
|
+
raise ValueError("records must be non-empty")
|
|
58
|
+
df = pl.DataFrame(
|
|
59
|
+
{
|
|
60
|
+
"page_content": [text for text, _ in records],
|
|
61
|
+
"metadata": [meta for _, meta in records],
|
|
62
|
+
}
|
|
63
|
+
)
|
|
64
|
+
return cls.from_polars(df)
|
|
65
|
+
|
|
66
|
+
def dump_documents(
|
|
67
|
+
self,
|
|
68
|
+
out_dir: str,
|
|
69
|
+
prefix: str = "doc",
|
|
70
|
+
ext: str = ".md",
|
|
71
|
+
encoding: str = "utf-8",
|
|
72
|
+
) -> None:
|
|
73
|
+
out_dir_path = Path(out_dir)
|
|
74
|
+
out_dir_path.mkdir(parents=True, exist_ok=True)
|
|
75
|
+
|
|
76
|
+
for i, row in enumerate(self.polars.iter_rows(named=True)):
|
|
77
|
+
path = out_dir_path / f"{prefix}_{i:05d}{ext}"
|
|
78
|
+
page_content = row.get("page_content", "")
|
|
79
|
+
metadata = row.get("metadata", {})
|
|
80
|
+
try:
|
|
81
|
+
metadata_str = json.dumps(metadata, indent=2, ensure_ascii=False)
|
|
82
|
+
except TypeError:
|
|
83
|
+
metadata_str = str(metadata)
|
|
84
|
+
content = f"###PAGE_CONTENT###\n{page_content}\n\n###METADATA###\n{metadata_str}\n"
|
|
85
|
+
path.write_text(content, encoding=encoding)
|
|
86
|
+
|
|
87
|
+
def to_langchain_documents(
|
|
88
|
+
self,
|
|
89
|
+
) -> list[Document]:
|
|
90
|
+
docs: list[Document] = []
|
|
91
|
+
for row in self.polars.iter_rows(named=True):
|
|
92
|
+
docs.append(
|
|
93
|
+
Document(
|
|
94
|
+
page_content=row["page_content"],
|
|
95
|
+
metadata=row["metadata"],
|
|
96
|
+
)
|
|
97
|
+
)
|
|
98
|
+
return docs
|
|
99
|
+
|
|
100
|
+
@classmethod
|
|
101
|
+
def from_documents(cls, docs: list[Document]) -> "TextDataset[TDataset]":
|
|
102
|
+
df = pl.DataFrame(
|
|
103
|
+
{
|
|
104
|
+
"page_content": [d.page_content for d in docs],
|
|
105
|
+
"metadata": [d.metadata or {} for d in docs],
|
|
106
|
+
}
|
|
107
|
+
)
|
|
108
|
+
return cls.from_polars(df)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from typing import Self
|
|
2
|
+
|
|
3
|
+
import polars as pl
|
|
4
|
+
|
|
5
|
+
from retrievalbase.dataset import Dataset, TextDataset
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PolarsDataset(Dataset[pl.DataFrame | pl.LazyFrame]):
    """``Dataset`` whose service is a Polars ``DataFrame`` or ``LazyFrame``."""

    def __init__(self, service: pl.DataFrame | pl.LazyFrame):
        super().__init__(service)

    @classmethod
    def from_polars(cls, df: pl.DataFrame | pl.LazyFrame) -> Self:
        """Wrap ``df`` directly as the service (no copy)."""
        return cls(df)

    def to_polars(self) -> pl.DataFrame:
        """Return an eager frame, collecting the service first when it is lazy."""
        service = self.service
        if not isinstance(service, pl.LazyFrame):
            return service
        return service.collect(background=False)  # ty: ignore[invalid-return-type]

    def to_lazy_polars(self) -> pl.LazyFrame:
        """Return the service as a lazy frame."""
        service = self.service
        return service if isinstance(service, pl.LazyFrame) else service.lazy()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class PolarsTextDataset(TextDataset[pl.DataFrame | pl.LazyFrame]):
    """``TextDataset`` whose service is a Polars ``DataFrame`` or ``LazyFrame``."""

    def __init__(self, service: pl.DataFrame | pl.LazyFrame):
        super().__init__(service)

    @classmethod
    def from_polars(cls, df: pl.DataFrame | pl.LazyFrame) -> Self:
        """Wrap ``df`` directly as the service (no copy); the schema is validated on construction."""
        return cls(df)

    def to_polars(self) -> pl.DataFrame:
        """Return an eager frame, collecting the service first when it is lazy."""
        service = self.service
        if not isinstance(service, pl.LazyFrame):
            return service
        return service.collect(background=False)  # ty: ignore[invalid-return-type]

    def to_lazy_polars(self) -> pl.LazyFrame:
        """Return the service as a lazy frame."""
        service = self.service
        return service if isinstance(service, pl.LazyFrame) else service.lazy()
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from retrievalbase.dataset.polars import TextDataset
|
|
5
|
+
from retrievalbase.dataset.settings import TextPreprocessorSettings, TokenCounterSettings
|
|
6
|
+
from retrievalbase.mixins import FromConfigMixin
|
|
7
|
+
from retrievalbase.types import TCTextPreprocessor as TCTextPreprocessor
|
|
8
|
+
from retrievalbase.types import TCTokenCounter as TCTokenCounter
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TextPreprocessor[TCTextPreprocessor: TextPreprocessorSettings](FromConfigMixin[TCTextPreprocessor], ABC):
|
|
12
|
+
def __init__(self, config: TCTextPreprocessor) -> None:
|
|
13
|
+
self.config = config
|
|
14
|
+
|
|
15
|
+
@abstractmethod
|
|
16
|
+
def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
|
|
17
|
+
raise NotImplementedError
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TokenCounter[TCTokenCounter: TokenCounterSettings](FromConfigMixin[TCTokenCounter], ABC):
|
|
21
|
+
def __init__(self, config: TCTokenCounter) -> None:
|
|
22
|
+
self.config = config
|
|
23
|
+
|
|
24
|
+
@abstractmethod
|
|
25
|
+
def count(self, text: str) -> int:
|
|
26
|
+
raise NotImplementedError()
|
|
27
|
+
|
|
28
|
+
def count_batch(self, texts: list[str]) -> list[int]:
|
|
29
|
+
return [self.count(t) for t in texts]
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import polars as pl
|
|
4
|
+
|
|
5
|
+
from retrievalbase.dataset import TextDataset
|
|
6
|
+
from retrievalbase.dataset.preprocess import TextPreprocessor, TokenCounter
|
|
7
|
+
from retrievalbase.dataset.settings import (
|
|
8
|
+
MaxTokenFilterSettings,
|
|
9
|
+
MinTokenFilterSettings,
|
|
10
|
+
PreprocessPipelineSettings,
|
|
11
|
+
QuantileTokenFilterSettings,
|
|
12
|
+
SigmaBandTokenFilterSettings,
|
|
13
|
+
)
|
|
14
|
+
from retrievalbase.utils import load_class
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MinTokenFilter(TextPreprocessor[MinTokenFilterSettings[Any]]):
    """Keep only rows whose ``page_content`` has at least ``config.min_tokens`` tokens."""

    def __init__(self, config: MinTokenFilterSettings[Any]):
        super().__init__(config)
        # The token counter class is resolved from its dotted ``module_path`` and built from its own settings.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Return a new dataset of the same class with rows having ``n_tokens >= min_tokens``."""
        # Explicit return_dtype: Polars otherwise has to infer the output type from the values
        # (and warns about it); counts are Python ints, so Int64 is exact.
        n_tokens = pl.col("page_content").map_elements(self.token_counter.count, return_dtype=pl.Int64)
        return ds.__class__.from_polars(
            ds.polars.with_columns(n_tokens.alias("n_tokens"))
            .filter(pl.col("n_tokens") >= self.config.min_tokens)
            .drop("n_tokens")
        )
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class MaxTokenFilter(TextPreprocessor[MaxTokenFilterSettings[Any]]):
    """Keep only rows whose ``page_content`` has at most ``config.max_tokens`` tokens."""

    def __init__(self, config: MaxTokenFilterSettings[Any]):
        super().__init__(config)
        # The token counter class is resolved from its dotted ``module_path`` and built from its own settings.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Return a new dataset of the same class with rows having ``n_tokens <= max_tokens``."""
        # Explicit return_dtype: Polars otherwise has to infer the output type from the values
        # (and warns about it); counts are Python ints, so Int64 is exact.
        n_tokens = pl.col("page_content").map_elements(self.token_counter.count, return_dtype=pl.Int64)
        return ds.__class__.from_polars(
            ds.polars.with_columns(n_tokens.alias("n_tokens"))
            .filter(pl.col("n_tokens") <= self.config.max_tokens)
            .drop("n_tokens")
        )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class QuantileTokenFilter(TextPreprocessor[QuantileTokenFilterSettings[Any]]):
    """Drop rows whose token count exceeds the ``config.q`` quantile of the dataset's token counts."""

    def __init__(self, config: QuantileTokenFilterSettings[Any]):
        super().__init__(config)
        # The token counter class is resolved from its dotted ``module_path`` and built from its own settings.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Return a new dataset of the same class keeping rows with ``n_tokens <= quantile(q)``."""
        # Explicit return_dtype: Polars otherwise has to infer the output type from the values
        # (and warns about it); counts are Python ints, so Int64 is exact.
        df = ds.polars.with_columns(
            pl.col("page_content").map_elements(self.token_counter.count, return_dtype=pl.Int64).alias("n_tokens")
        )
        cutoff = df.select(pl.col("n_tokens").quantile(self.config.q)).item()
        return ds.__class__.from_polars(df.filter(pl.col("n_tokens") <= cutoff).drop("n_tokens"))
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class SigmaBandTokenFilter(TextPreprocessor[SigmaBandTokenFilterSettings[Any]]):
    """Keep rows whose token count lies within ``mean ± z * std`` of the dataset's token counts."""

    def __init__(self, config: SigmaBandTokenFilterSettings[Any]):
        super().__init__(config)
        # The token counter class is resolved from its dotted ``module_path`` and built from its own settings.
        self.token_counter: TokenCounter[Any] = load_class(self.config.token_counter.module_path).from_config(
            self.config.token_counter
        )

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Return a new dataset of the same class with rows inside the (inclusive) sigma band.

        When fewer than two token counts are available, the sample standard deviation is null
        and no band can be computed, so every row is kept.
        """
        # Explicit return_dtype: Polars otherwise has to infer the output type from the values
        # (and warns about it); counts are Python ints, so Int64 is exact.
        df = ds.polars.with_columns(
            pl.col("page_content").map_elements(self.token_counter.count, return_dtype=pl.Int64).alias("n_tokens")
        )
        mu, sigma = df.select(
            pl.col("n_tokens").mean().alias("mu"),
            pl.col("n_tokens").std().alias("sigma"),
        ).row(0)

        # Guard: mean/std are None for empty input and std is None for a single value;
        # arithmetic on None would raise TypeError.
        if mu is None or sigma is None:
            return ds.__class__.from_polars(df.drop("n_tokens"))

        lower = mu - self.config.z * sigma
        upper = mu + self.config.z * sigma
        return ds.__class__.from_polars(df.filter(pl.col("n_tokens").is_between(lower, upper)).drop("n_tokens"))
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class PreprocessPipeline(TextPreprocessor[PreprocessPipelineSettings]):
    """Run the configured preprocessing steps in order, each consuming the previous step's output."""

    def __init__(self, config: PreprocessPipelineSettings):
        super().__init__(config)
        self.steps = self._load_steps()

    def _load_steps(self) -> list[TextPreprocessor[Any]]:
        """Instantiate every step class named by ``module_path`` with its own settings object."""
        return [load_class(step_settings.module_path)(step_settings) for step_settings in self.config.steps]

    def apply(self, ds: TextDataset[Any]) -> TextDataset[Any]:
        """Apply each step in sequence and return the final dataset (``ds`` itself if there are no steps)."""
        result = ds
        for preprocessor in self.steps:
            result = preprocessor.apply(result)
        return result
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from retrievalbase.dataset.preprocess import TokenCounter
|
|
2
|
+
from retrievalbase.dataset.settings import (
|
|
3
|
+
HeuristicTokenCounterSettings,
|
|
4
|
+
HuggingFaceTokenCounterSettings,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class HeuristicTokenCounter(TokenCounter[HeuristicTokenCounterSettings]):
    """Approximate token counter: ``len(text) / config.chars_per_token``, truncated toward zero."""

    def __init__(self, config: HeuristicTokenCounterSettings) -> None:
        super().__init__(config)

    def count(self, text: str) -> int:
        """Estimate the number of tokens in ``text``.

        Empty text counts as 0 tokens, so a min-token filter can drop empty documents.
        Any non-empty text counts as at least 1 token.
        """
        if not text:
            return 0
        return max(1, int(len(text) / self.config.chars_per_token))
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class HuggingFaceTokenCounter(TokenCounter[HuggingFaceTokenCounterSettings]):
    """Exact token counter using a Hugging Face tokenizer loaded with ``AutoTokenizer``."""

    def __init__(self, config: HuggingFaceTokenCounterSettings) -> None:
        super().__init__(config)
        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "HuggingFaceTokenCounter requires transformers.\nInstall with: pip install retrievalbase[transformers]"
            ) from e

        # NOTE(review): trust_remote_code=True lets the model repository execute its own Python
        # code on load; only point ``config.name`` at repositories you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.name,
            revision=self.config.revision,
            trust_remote_code=True,
        )

    def count(self, text: str) -> int:
        """Number of input ids produced for ``text`` (no truncation; special tokens per config)."""
        return len(
            self.tokenizer(  # ty: ignore[call-non-callable]
                text,
                add_special_tokens=self.config.add_special_tokens,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
            )["input_ids"]
        )

    def count_batch(self, texts: list[str]) -> list[int]:
        """Token counts for ``texts`` using one batched tokenizer call.

        No padding is requested, so each count matches what :meth:`count` returns for that text.
        """
        if not texts:
            return []
        encoded = self.tokenizer(  # ty: ignore[call-non-callable]
            texts,
            add_special_tokens=self.config.add_special_tokens,
            truncation=False,
            return_attention_mask=False,
            return_token_type_ids=False,
        )["input_ids"]
        return [len(ids) for ids in encoded]
|